libgav1: update snapshot to cl/263885228

+cl/264669395
significantly improves performance; this version now outperforms the
libaom-based decoder.

Bug: 130249450
Test: aosp_arm-eng aosp_arm64-eng aosp_x86-eng aosp_x86_64-eng build
Test: cts-tradefed run commandAndExit cts-dev --include-filter CtsMediaTestCases --module-arg "CtsMediaTestCases:include-filter:android.media.cts.DecoderTest#testAV1*" --module-arg "CtsMediaTestCases:include-filter:android.media.cts.AdaptivePlaybackTest#testAV1*" (bonito-userdebug)
Test: (externally) libaom/Argon/Allegro test vectors, cpu/memory performance, fuzzing

Change-Id: Ie23ebe0d761bf1896a7acab112b7f2574d7b1065
(cherry picked from commit 3ef2dfa5357309adbaa94c82511a3e9003926da2)
diff --git a/Android.bp b/Android.bp
index 946aa99..30eca18 100644
--- a/Android.bp
+++ b/Android.bp
@@ -40,13 +40,21 @@
         "libgav1/src/buffer_pool.cc",
         "libgav1/src/decoder.cc",
         "libgav1/src/decoder_impl.cc",
+        "libgav1/src/dsp/arm/average_blend_neon.cc",
+        "libgav1/src/dsp/arm/convolve_neon.cc",
+        "libgav1/src/dsp/arm/distance_weighted_blend_neon.cc",
+        "libgav1/src/dsp/arm/intra_edge_neon.cc",
         "libgav1/src/dsp/arm/intrapred_cfl_neon.cc",
         "libgav1/src/dsp/arm/intrapred_directional_neon.cc",
         "libgav1/src/dsp/arm/intrapred_filter_intra_neon.cc",
         "libgav1/src/dsp/arm/intrapred_neon.cc",
         "libgav1/src/dsp/arm/intrapred_smooth_neon.cc",
+        "libgav1/src/dsp/arm/inverse_transform_neon.cc",
         "libgav1/src/dsp/arm/loop_filter_neon.cc",
         "libgav1/src/dsp/arm/loop_restoration_neon.cc",
+        "libgav1/src/dsp/arm/mask_blend_neon.cc",
+        "libgav1/src/dsp/arm/obmc_neon.cc",
+        "libgav1/src/dsp/arm/warp_neon.cc",
         "libgav1/src/dsp/average_blend.cc",
         "libgav1/src/dsp/cdef.cc",
         "libgav1/src/dsp/constants.cc",
@@ -60,9 +68,12 @@
         "libgav1/src/dsp/inverse_transform.cc",
         "libgav1/src/dsp/loop_filter.cc",
         "libgav1/src/dsp/loop_restoration.cc",
-        "libgav1/src/dsp/mask_blending.cc",
+        "libgav1/src/dsp/mask_blend.cc",
         "libgav1/src/dsp/obmc.cc",
         "libgav1/src/dsp/warp.cc",
+        "libgav1/src/dsp/x86/average_blend_sse4.cc",
+        "libgav1/src/dsp/x86/convolve_sse4.cc",
+        "libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc",
         "libgav1/src/dsp/x86/intra_edge_sse4.cc",
         "libgav1/src/dsp/x86/intrapred_cfl_sse4.cc",
         "libgav1/src/dsp/x86/intrapred_smooth_sse4.cc",
@@ -70,6 +81,7 @@
         "libgav1/src/dsp/x86/inverse_transform_sse4.cc",
         "libgav1/src/dsp/x86/loop_filter_sse4.cc",
         "libgav1/src/dsp/x86/loop_restoration_sse4.cc",
+        "libgav1/src/dsp/x86/obmc_sse4.cc",
         "libgav1/src/internal_frame_buffer_list.cc",
         "libgav1/src/loop_filter_mask.cc",
         "libgav1/src/loop_restoration_info.cc",
@@ -88,7 +100,6 @@
         "libgav1/src/tile/bitstream/transform_size.cc",
         "libgav1/src/tile/prediction.cc",
         "libgav1/src/tile/tile.cc",
-        "libgav1/src/utils/allocator.cc",
         "libgav1/src/utils/bit_reader.cc",
         "libgav1/src/utils/block_parameters_holder.cc",
         "libgav1/src/utils/constants.cc",
@@ -97,7 +108,6 @@
         "libgav1/src/utils/logging.cc",
         "libgav1/src/utils/parameter_tree.cc",
         "libgav1/src/utils/raw_bit_reader.cc",
-        "libgav1/src/utils/scan.cc",
         "libgav1/src/utils/segmentation.cc",
         "libgav1/src/utils/segmentation_map.cc",
         "libgav1/src/utils/threadpool.cc",
diff --git a/libgav1/src/decoder.cc b/libgav1/src/decoder.cc
index b4251f2..8767887 100644
--- a/libgav1/src/decoder.cc
+++ b/libgav1/src/decoder.cc
@@ -1,9 +1,6 @@
 #include "src/decoder.h"
 
-#include <new>
-
 #include "src/decoder_impl.h"
-#include "src/utils/logging.h"
 
 namespace libgav1 {
 
@@ -13,15 +10,9 @@
 
 StatusCode Decoder::Init(const DecoderSettings* const settings) {
   if (initialized_) return kLibgav1StatusAlready;
-  if (settings != nullptr) {
-    if (settings->threads < 0) return kLibgav1StatusInvalidArgument;
-    settings_ = *settings;
-  }
-  impl_.reset(new (std::nothrow) DecoderImpl(&settings_));
-  if (impl_ == nullptr) {
-    LIBGAV1_DLOG(ERROR, "Failed to allocate DecoderImpl.");
-    return kLibgav1StatusOutOfMemory;
-  }
+  if (settings != nullptr) settings_ = *settings;
+  const StatusCode status = DecoderImpl::Create(&settings_, &impl_);
+  if (status != kLibgav1StatusOk) return status;
   initialized_ = true;
   return kLibgav1StatusOk;
 }
@@ -41,4 +32,7 @@
   return settings_.frame_parallel ? settings_.threads : 1;
 }
 
+// static.
+int Decoder::GetMaxBitdepth() { return DecoderImpl::GetMaxBitdepth(); }
+
 }  // namespace libgav1
diff --git a/libgav1/src/decoder.h b/libgav1/src/decoder.h
index 274b41c..b01107d 100644
--- a/libgav1/src/decoder.h
+++ b/libgav1/src/decoder.h
@@ -8,13 +8,14 @@
 #include "src/decoder_buffer.h"
 #include "src/decoder_settings.h"
 #include "src/status_code.h"
+#include "src/symbol_visibility.h"
 
 namespace libgav1 {
 
 // Forward declaration.
 class DecoderImpl;
 
-class Decoder {
+class LIBGAV1_PUBLIC Decoder {
  public:
   Decoder();
   ~Decoder();
@@ -55,6 +56,9 @@
   // is false, then this function will always return 1.
   int GetMaxAllowedFrames() const;
 
+  // Returns the maximum bitdepth that is supported by this decoder.
+  static int GetMaxBitdepth();
+
  private:
   bool initialized_ = false;
   DecoderSettings settings_;
diff --git a/libgav1/src/decoder_buffer.h b/libgav1/src/decoder_buffer.h
index 477acb7..6806603 100644
--- a/libgav1/src/decoder_buffer.h
+++ b/libgav1/src/decoder_buffer.h
@@ -4,6 +4,7 @@
 #include <cstdint>
 
 #include "src/frame_buffer.h"
+#include "src/symbol_visibility.h"
 
 // All the declarations in this file are part of the public ABI.
 
@@ -23,7 +24,7 @@
   kImageFormatMonochrome400
 };
 
-struct DecoderBuffer {
+struct LIBGAV1_PUBLIC DecoderBuffer {
   int NumPlanes() const {
     return (image_format == kImageFormatMonochrome400) ? 1 : 3;
   }
diff --git a/libgav1/src/decoder_impl.cc b/libgav1/src/decoder_impl.cc
index a01d2ee..dbfaf4c 100644
--- a/libgav1/src/decoder_impl.cc
+++ b/libgav1/src/decoder_impl.cc
@@ -1,10 +1,9 @@
 #include "src/decoder_impl.h"
 
 #include <algorithm>
+#include <atomic>
 #include <cassert>
-#include <condition_variable>  // NOLINT (unapproved c++11 header)
 #include <iterator>
-#include <mutex>  // NOLINT (unapproved c++11 header)
 #include <new>
 #include <utility>
 
@@ -16,6 +15,7 @@
 #include "src/post_filter.h"
 #include "src/prediction_mask.h"
 #include "src/quantizer.h"
+#include "src/utils/blocking_counter.h"
 #include "src/utils/common.h"
 #include "src/utils/logging.h"
 #include "src/utils/parameter_tree.h"
@@ -50,6 +50,24 @@
 
 }  // namespace
 
+// static
+StatusCode DecoderImpl::Create(const DecoderSettings* settings,
+                               std::unique_ptr<DecoderImpl>* output) {
+  if (settings->threads <= 0) {
+    LIBGAV1_DLOG(ERROR, "Invalid settings->threads: %d.", settings->threads);
+    return kLibgav1StatusInvalidArgument;
+  }
+  std::unique_ptr<DecoderImpl> impl(new (std::nothrow) DecoderImpl(settings));
+  if (impl == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Failed to allocate DecoderImpl.");
+    return kLibgav1StatusOutOfMemory;
+  }
+  const StatusCode status = impl->Init();
+  if (status != kLibgav1StatusOk) return status;
+  *output = std::move(impl);
+  return kLibgav1StatusOk;
+}
+
 DecoderImpl::DecoderImpl(const DecoderSettings* settings)
     : buffer_pool_(*settings), settings_(*settings) {
   dsp::DspInit();
@@ -66,17 +84,27 @@
   }
 }
 
+StatusCode DecoderImpl::Init() {
+  const int max_allowed_frames =
+      settings_.frame_parallel ? settings_.threads : 1;
+  assert(max_allowed_frames > 0);
+  if (!encoded_frames_.Init(max_allowed_frames)) {
+    LIBGAV1_DLOG(ERROR, "encoded_frames_.Init() failed.");
+    return kLibgav1StatusOutOfMemory;
+  }
+  return kLibgav1StatusOk;
+}
+
 StatusCode DecoderImpl::EnqueueFrame(const uint8_t* data, size_t size,
                                      int64_t user_private_data) {
   if (data == nullptr) {
     // This has to actually flush the decoder.
     return kLibgav1StatusOk;
   }
-  int max_allowed_frames = settings_.frame_parallel ? settings_.threads : 1;
-  if (encoded_frames_.size() >= static_cast<size_t>(max_allowed_frames)) {
+  if (encoded_frames_.Full()) {
     return kLibgav1StatusResourceExhausted;
   }
-  encoded_frames_.emplace_back(data, size, user_private_data);
+  encoded_frames_.Push(EncodedFrame(data, size, user_private_data));
   return kLibgav1StatusOk;
 }
 
@@ -96,19 +124,21 @@
   // We assume a call to DequeueFrame() indicates that the caller is no longer
   // using the previous output frame, so we can release it.
   ReleaseOutputFrame();
-  if (encoded_frames_.empty()) {
+  if (encoded_frames_.Empty()) {
     // No encoded frame to decode. Not an error.
     *out_ptr = nullptr;
     return kLibgav1StatusOk;
   }
-  const EncodedFrame encoded_frame = encoded_frames_[0];
-  encoded_frames_.erase(encoded_frames_.begin());
+  const EncodedFrame encoded_frame = encoded_frames_.Pop();
   std::unique_ptr<ObuParser> obu(new (std::nothrow) ObuParser(
       encoded_frame.data, encoded_frame.size, &state_));
   if (obu == nullptr) {
     LIBGAV1_DLOG(ERROR, "Failed to initialize OBU parser.");
     return kLibgav1StatusOutOfMemory;
   }
+  if (state_.has_sequence_header) {
+    obu->set_sequence_header(state_.sequence_header);
+  }
   RefCountedBufferPtrCleanup current_frame_cleanup(&state_.current_frame);
   RefCountedBufferPtr displayable_frame;
   StatusCode status;
@@ -119,14 +149,16 @@
       return kLibgav1StatusResourceExhausted;
     }
 
-    obu->set_sequence_header(state_.sequence_header);
     if (!obu->ParseOneFrame()) {
       LIBGAV1_DLOG(ERROR, "Failed to parse OBU.");
       return kLibgav1StatusUnknownError;
     }
-    if (std::find(obu->types().begin(), obu->types().end(),
-                  kObuSequenceHeader) != obu->types().end()) {
+    if (std::find_if(obu->obu_headers().begin(), obu->obu_headers().end(),
+                     [](const ObuHeader& obu_header) {
+                       return obu_header.type == kObuSequenceHeader;
+                     }) != obu->obu_headers().end()) {
       state_.sequence_header = obu->sequence_header();
+      state_.has_sequence_header = true;
     }
     if (!obu->frame_header().show_existing_frame) {
       if (obu->tile_groups().empty()) {
@@ -355,6 +387,9 @@
       obu->frame_header().rows4x4 + kMaxBlockHeight4x4,
       obu->frame_header().columns4x4 + kMaxBlockWidth4x4,
       obu->sequence_header().use_128x128_superblock);
+  if (!block_parameters_holder.Init()) {
+    return kLibgav1StatusOutOfMemory;
+  }
   const dsp::Dsp* const dsp =
       dsp::GetDspTable(obu->sequence_header().color_config.bitdepth);
   if (dsp == nullptr) {
@@ -384,8 +419,11 @@
   const uint8_t tile_size_bytes = obu->frame_header().tile_info.tile_size_bytes;
   const int tile_count = obu->tile_groups().back().end + 1;
   assert(tile_count >= 1);
-  std::vector<std::unique_ptr<Tile>> tiles;
-  tiles.reserve(tile_count);
+  Vector<std::unique_ptr<Tile>> tiles;
+  if (!tiles.reserve(tile_count)) {
+    LIBGAV1_DLOG(ERROR, "tiles.reserve(%d) failed.\n", tile_count);
+    return kLibgav1StatusOutOfMemory;
+  }
   if (!threading_strategy_.Reset(obu->frame_header(), settings_.threads)) {
     return kLibgav1StatusOutOfMemory;
   }
@@ -412,31 +450,60 @@
     }
   }
 
+  const bool do_cdef =
+      PostFilter::DoCdef(obu->frame_header(), settings_.post_filter_mask);
+  const int num_planes = obu->sequence_header().color_config.is_monochrome
+                             ? kMaxPlanesMonochrome
+                             : kMaxPlanes;
+  const bool do_restoration =
+      PostFilter::DoRestoration(obu->frame_header().loop_restoration,
+                                settings_.post_filter_mask, num_planes);
   if (threading_strategy_.post_filter_thread_pool() != nullptr &&
-      PostFilter::DoRestoration(
-          obu->frame_header().loop_restoration, settings_.post_filter_mask,
-          obu->sequence_header().color_config.is_monochrome
-              ? kMaxPlanesMonochrome
-              : kMaxPlanes)) {
-    const size_t threaded_loop_restoration_buffer_size =
-        PostFilter::kRestorationWindowWidth *
-        PostFilter::GetRestorationWindowHeight(
+      (do_cdef || do_restoration)) {
+    const int window_buffer_width = PostFilter::GetWindowBufferWidth(
+        threading_strategy_.post_filter_thread_pool(), obu->frame_header());
+    size_t threaded_window_buffer_size =
+        window_buffer_width *
+        PostFilter::GetWindowBufferHeight(
             threading_strategy_.post_filter_thread_pool(),
             obu->frame_header()) *
         (obu->sequence_header().color_config.bitdepth == 8 ? sizeof(uint8_t)
                                                            : sizeof(uint16_t));
-    if (threaded_loop_restoration_buffer_size_ <
-        threaded_loop_restoration_buffer_size) {
-      threaded_loop_restoration_buffer_.reset(
-          new (std::nothrow) uint8_t[threaded_loop_restoration_buffer_size]);
-      if (threaded_loop_restoration_buffer_ == nullptr) {
+    if (do_cdef && !do_restoration) {
+      // TODO(chengchen): for cdef U, V planes, if there's subsampling, we can
+      // use a smaller buffer.
+      threaded_window_buffer_size *= num_planes;
+    }
+    if (threaded_window_buffer_size_ < threaded_window_buffer_size) {
+      // threaded_window_buffer_ will be subdivided by PostFilter into windows
+      // of width 512 pixels. Each row in the window is filtered by a worker
+      // thread. To avoid false sharing, each 512-pixel row processed by one
+      // thread should not share a cache line with a row processed by another
+      // thread. So we align threaded_window_buffer_ to the cache line size.
+      // In addition, it is faster to memcpy from an aligned buffer.
+      //
+      // On Linux, the cache line size can be looked up with the command:
+      //   getconf LEVEL1_DCACHE_LINESIZE
+      //
+      // The cache line size should ideally be queried at run time. 64 is a
+      // common cache line size of x86 CPUs. Web searches showed the cache line
+      // size of ARM CPUs is 32 or 64 bytes. So aligning to 64-byte boundary
+      // will work for all CPUs that we care about, even though it is excessive
+      // for some ARM CPUs.
+      constexpr size_t kCacheLineSize = 64;
+      // To avoid false sharing, PostFilter's window width in bytes should also
+      // be a multiple of the cache line size. For simplicity, we check the
+      // window width in pixels.
+      assert(window_buffer_width % kCacheLineSize == 0);
+      threaded_window_buffer_ = MakeAlignedUniquePtr<uint8_t>(
+          kCacheLineSize, threaded_window_buffer_size);
+      if (threaded_window_buffer_ == nullptr) {
         LIBGAV1_DLOG(ERROR,
                      "Failed to allocate threaded loop restoration buffer.\n");
-        threaded_loop_restoration_buffer_size_ = 0;
+        threaded_window_buffer_size_ = 0;
         return kLibgav1StatusOutOfMemory;
       }
-      threaded_loop_restoration_buffer_size_ =
-          threaded_loop_restoration_buffer_size;
+      threaded_window_buffer_size_ = threaded_window_buffer_size;
     }
   }
 
@@ -445,9 +512,10 @@
       cdef_index, &loop_restoration_info, &block_parameters_holder,
       state_.current_frame->buffer(), dsp,
       threading_strategy_.post_filter_thread_pool(),
-      threaded_loop_restoration_buffer_.get(), settings_.post_filter_mask);
+      threaded_window_buffer_.get(), settings_.post_filter_mask);
   SymbolDecoderContext saved_symbol_decoder_context;
   int tile_index = 0;
+  BlockingCounterWithStatus pending_tiles(tile_count);
   for (const auto& tile_group : obu->tile_groups()) {
     size_t bytes_left = tile_group.data_size;
     size_t byte_offset = 0;
@@ -484,86 +552,91 @@
           prev_segment_ids, &post_filter, &block_parameters_holder, &cdef_index,
           &inter_transform_sizes_, dsp,
           threading_strategy_.row_thread_pool(tile_index++),
-          residual_buffer_pool_.get()));
+          residual_buffer_pool_.get(), &pending_tiles));
       if (tile == nullptr) {
         LIBGAV1_DLOG(ERROR, "Failed to allocate tile.");
         return kLibgav1StatusOutOfMemory;
       }
-      tiles.push_back(std::move(tile));
+      tiles.push_back_unchecked(std::move(tile));
 
       byte_offset += tile_size;
       bytes_left -= tile_size;
     }
   }
   assert(tiles.size() == static_cast<size_t>(tile_count));
+  bool tile_decoding_failed = false;
   if (threading_strategy_.tile_thread_pool() == nullptr) {
     for (const auto& tile_ptr : tiles) {
-      if (!tile_ptr->Decode()) {
-        LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile_ptr->number());
-        return kLibgav1StatusUnknownError;
+      if (!tile_decoding_failed) {
+        if (!tile_ptr->Decode(/*is_main_thread=*/true)) {
+          LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile_ptr->number());
+          tile_decoding_failed = true;
+        }
+      } else {
+        pending_tiles.Decrement(false);
       }
     }
   } else {
-    const int num_workers =
-        threading_strategy_.tile_thread_pool()->num_threads();
-    const int tiles_for_thread_pool =
-        tile_count * num_workers / (num_workers + 1);
-    // Make sure the current thread does some work.
-    assert(tiles_for_thread_pool < tile_count);
-    std::mutex mutex;
-    bool tile_decoding_failed = false;          // Guarded by |mutex|.
-    int pending_tiles = tiles_for_thread_pool;  // Guarded by |mutex|.
-    std::condition_variable pending_tiles_zero_condvar;
-    // Submit some tiles to the thread pool for decoding.
-    int i;
-    for (i = 0; i < tiles_for_thread_pool; ++i) {
-      auto* const tile = tiles[i].get();
+    const int num_workers = threading_strategy_.tile_thread_count();
+    BlockingCounterWithStatus pending_workers(num_workers);
+    std::atomic<int> tile_counter(0);
+    // Submit tile decoding jobs to the thread pool.
+    for (int i = 0; i < num_workers; ++i) {
       threading_strategy_.tile_thread_pool()->Schedule(
-          [&mutex, &tile_decoding_failed, tile, &pending_tiles,
-           &pending_tiles_zero_condvar]() {
-            const bool failed = !tile->Decode();
-            if (failed) {
-              LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile->number());
+          [&tiles, tile_count, &tile_counter, &pending_workers,
+           &pending_tiles]() {
+            bool failed = false;
+            int index;
+            while ((index = tile_counter.fetch_add(
+                        1, std::memory_order_relaxed)) < tile_count) {
+              if (!failed) {
+                const auto& tile_ptr = tiles[index];
+                if (!tile_ptr->Decode(/*is_main_thread=*/false)) {
+                  LIBGAV1_DLOG(ERROR, "Error decoding tile #%d",
+                               tile_ptr->number());
+                  failed = true;
+                }
+              } else {
+                pending_tiles.Decrement(false);
+              }
             }
-            std::lock_guard<std::mutex> lock(mutex);
-            tile_decoding_failed |= failed;
-            if (--pending_tiles == 0) {
-              // TODO(jzern): the mutex doesn't need to be locked to signal the
-              // condition.
-              pending_tiles_zero_condvar.notify_one();
-            }
+            pending_workers.Decrement(!failed);
           });
     }
-    // Decode the rest of the tiles on the current thread.
-    bool failed = false;
-    for (; i < tile_count; ++i) {
-      const auto& tile_ptr = tiles[i];
-      if (!tile_ptr->Decode()) {
-        LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile_ptr->number());
-        failed = true;
-        break;
+    // Have the current thread partake in tile decoding.
+    int index;
+    while ((index = tile_counter.fetch_add(1, std::memory_order_relaxed)) <
+           tile_count) {
+      if (!tile_decoding_failed) {
+        const auto& tile_ptr = tiles[index];
+        if (!tile_ptr->Decode(/*is_main_thread=*/true)) {
+          LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile_ptr->number());
+          tile_decoding_failed = true;
+        }
+      } else {
+        pending_tiles.Decrement(false);
       }
     }
-    // Wait for the thread pool to finish decoding the tiles.
-    std::unique_lock<std::mutex> lock(mutex);
-    tile_decoding_failed |= failed;
-    while (pending_tiles != 0) {
-      pending_tiles_zero_condvar.wait(lock);
-    }
-    if (tile_decoding_failed) return kLibgav1StatusUnknownError;
+    // Wait until all the workers are done. This ensures that all the tiles have
+    // been parsed.
+    tile_decoding_failed |= !pending_workers.Wait();
   }
+  // Wait until all the tiles have been decoded.
+  tile_decoding_failed |= !pending_tiles.Wait();
+
+  // At this point, all the tiles have been parsed and decoded and the
+  // threadpool will be empty.
+  if (tile_decoding_failed) return kLibgav1StatusUnknownError;
 
   if (obu->frame_header().enable_frame_end_update_cdf) {
     symbol_decoder_context_ = saved_symbol_decoder_context;
   }
   state_.current_frame->SetFrameContext(symbol_decoder_context_);
-  if (post_filter.DoDeblock() &&
-      !loop_filter_mask_.Build(
-          obu->sequence_header(), obu->frame_header(),
-          obu->tile_groups().front().start, obu->tile_groups().back().end,
-          &block_parameters_holder, inter_transform_sizes_)) {
-    LIBGAV1_DLOG(ERROR, "Error building deblocking filter masks.");
-    return kLibgav1StatusUnknownError;
+  if (post_filter.DoDeblock()) {
+    loop_filter_mask_.Build(obu->sequence_header(), obu->frame_header(),
+                            obu->tile_groups().front().start,
+                            obu->tile_groups().back().end,
+                            block_parameters_holder, inter_transform_sizes_);
   }
   if (!post_filter.ApplyFiltering()) {
     LIBGAV1_DLOG(ERROR, "Error applying in-loop filtering.");
diff --git a/libgav1/src/decoder_impl.h b/libgav1/src/decoder_impl.h
index 0920661..c4ea526 100644
--- a/libgav1/src/decoder_impl.h
+++ b/libgav1/src/decoder_impl.h
@@ -5,7 +5,6 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
-#include <vector>
 
 #include "src/buffer_pool.h"
 #include "src/decoder_buffer.h"
@@ -22,12 +21,17 @@
 #include "src/utils/block_parameters_holder.h"
 #include "src/utils/constants.h"
 #include "src/utils/memory.h"
+#include "src/utils/queue.h"
 #include "src/utils/segmentation_map.h"
 #include "src/utils/types.h"
 
 namespace libgav1 {
 
-struct EncodedFrame {
+struct EncodedFrame : public Allocable {
+  // The default constructor is invoked by the Queue<EncodedFrame>::Init()
+  // method. Queue<> does not use the default-constructed elements, so it is
+  // safe for the default constructor to not initialize the members.
+  EncodedFrame() = default;
   EncodedFrame(const uint8_t* data, size_t size, int64_t user_private_data)
       : data(data), size(size), user_private_data(user_private_data) {}
 
@@ -38,6 +42,8 @@
 
 struct DecoderState {
   ObuSequenceHeader sequence_header = {};
+  // If true, sequence_header is valid.
+  bool has_sequence_header = false;
   // reference_valid and reference_frame_id are used only if
   // sequence_header_.frame_id_numbers_present is true.
   // The reference_valid array is indexed by a reference picture slot number.
@@ -82,14 +88,26 @@
 class DecoderImpl : public Allocable {
  public:
   // The constructor saves a const reference to |*settings|. Therefore
-  // |*settings| must outlive the DecoderImpl object.
-  explicit DecoderImpl(const DecoderSettings* settings);
+  // |*settings| must outlive the DecoderImpl object. On success, |*output|
+  // contains a pointer to the newly-created DecoderImpl object. On failure,
+  // |*output| is not modified.
+  static StatusCode Create(const DecoderSettings* settings,
+                           std::unique_ptr<DecoderImpl>* output);
   ~DecoderImpl();
   StatusCode EnqueueFrame(const uint8_t* data, size_t size,
                           int64_t user_private_data);
   StatusCode DequeueFrame(const DecoderBuffer** out_ptr);
+  static constexpr int GetMaxBitdepth() {
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    return 10;
+#else
+    return 8;
+#endif
+  }
 
  private:
+  explicit DecoderImpl(const DecoderSettings* settings);
+  StatusCode Init();
   bool AllocateCurrentFrame(const ObuFrameHeader& frame_header);
   void ReleaseOutputFrame();
   // Populates buffer_ with values from |frame|. Adds a reference to |frame|
@@ -104,7 +122,7 @@
   // state_.current_frame, based on the refresh_frame_flags bitmask.
   void UpdateReferenceFrames(int refresh_frame_flags);
 
-  std::vector<EncodedFrame> encoded_frames_;
+  Queue<EncodedFrame> encoded_frames_;
   DecoderState state_;
   ThreadingStrategy threading_strategy_;
   SymbolDecoderContext symbol_decoder_context_;
@@ -117,8 +135,8 @@
 
   BufferPool buffer_pool_;
   std::unique_ptr<ResidualBufferPool> residual_buffer_pool_;
-  std::unique_ptr<uint8_t[]> threaded_loop_restoration_buffer_;
-  size_t threaded_loop_restoration_buffer_size_ = 0;
+  AlignedUniquePtr<uint8_t> threaded_window_buffer_;
+  size_t threaded_window_buffer_size_ = 0;
   Array2D<TransformSize> inter_transform_sizes_;
 
   LoopFilterMask loop_filter_mask_;
diff --git a/libgav1/src/decoder_settings.h b/libgav1/src/decoder_settings.h
index 4a3adce..051e5bb 100644
--- a/libgav1/src/decoder_settings.h
+++ b/libgav1/src/decoder_settings.h
@@ -11,7 +11,10 @@
 
 // Applications must populate this structure before creating a decoder instance.
 struct DecoderSettings {
-  // Number of threads to use when decoding. Defaults to 1.
+  // Number of threads to use when decoding. Must be greater than 0. The
+  // library will create at most |threads|-1 new threads; the calling thread is
+  // considered part of the library's thread count. Defaults to 1 (no new
+  // threads will be created).
   int threads = 1;
   // Do frame parallel decoding.
   bool frame_parallel = false;
diff --git a/libgav1/src/dsp/arm/average_blend_neon.cc b/libgav1/src/dsp/arm/average_blend_neon.cc
new file mode 100644
index 0000000..5f9c458
--- /dev/null
+++ b/libgav1/src/dsp/arm/average_blend_neon.cc
@@ -0,0 +1,149 @@
+#include "src/dsp/average_blend.h"
+#include "src/dsp/dsp.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+constexpr int kBitdepth8 = 8;
+constexpr int kInterPostRoundBit = 4;
+// An offset to cancel offsets used in compound predictor generation that
+// make intermediate computations non negative.
+const int16x8_t kCompoundRoundOffset =
+    vdupq_n_s16((2 << (kBitdepth8 + 4)) + (2 << (kBitdepth8 + 3)));
+
+inline void AverageBlend4Row(const uint16_t* prediction_0,
+                             const uint16_t* prediction_1, uint8_t* dest) {
+  const int16x4_t pred0 = vreinterpret_s16_u16(vld1_u16(prediction_0));
+  const int16x4_t pred1 = vreinterpret_s16_u16(vld1_u16(prediction_1));
+  int16x4_t res = vadd_s16(pred0, pred1);
+  res = vsub_s16(res, vget_low_s16(kCompoundRoundOffset));
+  StoreLo4(dest,
+           vqrshrun_n_s16(vcombine_s16(res, res), kInterPostRoundBit + 1));
+}
+
+inline void AverageBlend8Row(const uint16_t* prediction_0,
+                             const uint16_t* prediction_1, uint8_t* dest) {
+  const int16x8_t pred0 = vreinterpretq_s16_u16(vld1q_u16(prediction_0));
+  const int16x8_t pred1 = vreinterpretq_s16_u16(vld1q_u16(prediction_1));
+  int16x8_t res = vaddq_s16(pred0, pred1);
+  res = vsubq_s16(res, kCompoundRoundOffset);
+  vst1_u8(dest, vqrshrun_n_s16(res, kInterPostRoundBit + 1));
+}
+
+inline void AverageBlendLargeRow(const uint16_t* prediction_0,
+                                 const uint16_t* prediction_1, const int width,
+                                 uint8_t* dest) {
+  int x = 0;
+  do {
+    const int16x8_t pred_00 =
+        vreinterpretq_s16_u16(vld1q_u16(&prediction_0[x]));
+    const int16x8_t pred_01 =
+        vreinterpretq_s16_u16(vld1q_u16(&prediction_1[x]));
+    int16x8_t res0 = vaddq_s16(pred_00, pred_01);
+    res0 = vsubq_s16(res0, kCompoundRoundOffset);
+    const uint8x8_t res_out0 = vqrshrun_n_s16(res0, kInterPostRoundBit + 1);
+    const int16x8_t pred_10 =
+        vreinterpretq_s16_u16(vld1q_u16(&prediction_0[x + 8]));
+    const int16x8_t pred_11 =
+        vreinterpretq_s16_u16(vld1q_u16(&prediction_1[x + 8]));
+    int16x8_t res1 = vaddq_s16(pred_10, pred_11);
+    res1 = vsubq_s16(res1, kCompoundRoundOffset);
+    const uint8x8_t res_out1 = vqrshrun_n_s16(res1, kInterPostRoundBit + 1);
+    vst1q_u8(dest + x, vcombine_u8(res_out0, res_out1));
+    x += 16;
+  } while (x < width);
+}
+
+void AverageBlend_NEON(const uint16_t* prediction_0,
+                       const ptrdiff_t prediction_stride_0,
+                       const uint16_t* prediction_1,
+                       const ptrdiff_t prediction_stride_1, const int width,
+                       const int height, void* const dest,
+                       const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  int y = height;
+
+  if (width == 4) {
+    do {
+      AverageBlend4Row(prediction_0, prediction_1, dst);
+      dst += dest_stride;
+      prediction_0 += prediction_stride_0;
+      prediction_1 += prediction_stride_1;
+
+      AverageBlend4Row(prediction_0, prediction_1, dst);
+      dst += dest_stride;
+      prediction_0 += prediction_stride_0;
+      prediction_1 += prediction_stride_1;
+
+      y -= 2;
+    } while (y != 0);
+    return;
+  }
+
+  if (width == 8) {
+    do {
+      AverageBlend8Row(prediction_0, prediction_1, dst);
+      dst += dest_stride;
+      prediction_0 += prediction_stride_0;
+      prediction_1 += prediction_stride_1;
+
+      AverageBlend8Row(prediction_0, prediction_1, dst);
+      dst += dest_stride;
+      prediction_0 += prediction_stride_0;
+      prediction_1 += prediction_stride_1;
+
+      y -= 2;
+    } while (y != 0);
+    return;
+  }
+
+  do {
+    AverageBlendLargeRow(prediction_0, prediction_1, width, dst);
+    dst += dest_stride;
+    prediction_0 += prediction_stride_0;
+    prediction_1 += prediction_stride_1;
+
+    AverageBlendLargeRow(prediction_0, prediction_1, width, dst);
+    dst += dest_stride;
+    prediction_0 += prediction_stride_0;
+    prediction_1 += prediction_stride_1;
+
+    y -= 2;
+  } while (y != 0);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  dsp->average_blend = AverageBlend_NEON;
+}
+
+}  // namespace
+
+void AverageBlendInit_NEON() { Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void AverageBlendInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/average_blend_neon.h b/libgav1/src/dsp/arm/average_blend_neon.h
new file mode 100644
index 0000000..1225975
--- /dev/null
+++ b/libgav1/src/dsp/arm/average_blend_neon.h
@@ -0,0 +1,20 @@
+#ifndef LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::average_blend. This function is not thread-safe.
+void AverageBlendInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_DSP_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_
diff --git a/libgav1/src/dsp/arm/common_neon.h b/libgav1/src/dsp/arm/common_neon.h
index 6a28526..2ba35ed 100644
--- a/libgav1/src/dsp/arm/common_neon.h
+++ b/libgav1/src/dsp/arm/common_neon.h
@@ -10,6 +10,154 @@
 #include <cstdint>
 #include <cstring>
 
+#if 0
+#include <cstdio>
+
+constexpr bool kEnablePrintRegs = true;
+
+union DebugRegister {
+  int8_t i8[8];
+  int16_t i16[4];
+  int32_t i32[2];
+  uint8_t u8[8];
+  uint16_t u16[4];
+  uint32_t u32[2];
+};
+
+union DebugRegisterQ {
+  int8_t i8[16];
+  int16_t i16[8];
+  int32_t i32[4];
+  uint8_t u8[16];
+  uint16_t u16[8];
+  uint32_t u32[4];
+};
+
+// Quite useful macro for debugging. Left here for convenience.
+inline void PrintVect(const DebugRegister r, const char* const name, int size) {
+  int n;
+  if (kEnablePrintRegs) {
+    fprintf(stderr, "%s\t: ", name);
+    if (size == 8) {
+      for (n = 0; n < 8; ++n) fprintf(stderr, "%.2x ", r.u8[n]);
+    } else if (size == 16) {
+      for (n = 0; n < 4; ++n) fprintf(stderr, "%.4x ", r.u16[n]);
+    } else if (size == 32) {
+      for (n = 0; n < 2; ++n) fprintf(stderr, "%.8x ", r.u32[n]);
+    }
+    fprintf(stderr, "\n");
+  }
+}
+
+// Debugging macro for 128-bit types.
+inline void PrintVectQ(const DebugRegisterQ r, const char* const name,
+                       int size) {
+  int n;
+  if (kEnablePrintRegs) {
+    fprintf(stderr, "%s\t: ", name);
+    if (size == 8) {
+      for (n = 0; n < 16; ++n) fprintf(stderr, "%.2x ", r.u8[n]);
+    } else if (size == 16) {
+      for (n = 0; n < 8; ++n) fprintf(stderr, "%.4x ", r.u16[n]);
+    } else if (size == 32) {
+      for (n = 0; n < 4; ++n) fprintf(stderr, "%.8x ", r.u32[n]);
+    }
+    fprintf(stderr, "\n");
+  }
+}
+
+inline void PrintReg(const uint32x4_t val, const char* name) {
+  DebugRegisterQ r;
+  vst1q_u32(r.u32, val);
+  PrintVectQ(r, name, 32);
+}
+
+inline void PrintReg(const uint32x2_t val, const char* name) {
+  DebugRegister r;
+  vst1_u32(r.u32, val);
+  PrintVect(r, name, 32);
+}
+
+inline void PrintReg(const uint16x8_t val, const char* name) {
+  DebugRegisterQ r;
+  vst1q_u16(r.u16, val);
+  PrintVectQ(r, name, 16);
+}
+
+inline void PrintReg(const uint16x4_t val, const char* name) {
+  DebugRegister r;
+  vst1_u16(r.u16, val);
+  PrintVect(r, name, 16);
+}
+
+inline void PrintReg(const uint8x16_t val, const char* name) {
+  DebugRegisterQ r;
+  vst1q_u8(r.u8, val);
+  PrintVectQ(r, name, 8);
+}
+
+inline void PrintReg(const uint8x8_t val, const char* name) {
+  DebugRegister r;
+  vst1_u8(r.u8, val);
+  PrintVect(r, name, 8);
+}
+
+inline void PrintReg(const int32x4_t val, const char* name) {
+  DebugRegisterQ r;
+  vst1q_s32(r.i32, val);
+  PrintVectQ(r, name, 32);
+}
+
+inline void PrintReg(const int32x2_t val, const char* name) {
+  DebugRegister r;
+  vst1_s32(r.i32, val);
+  PrintVect(r, name, 32);
+}
+
+inline void PrintReg(const int16x8_t val, const char* name) {
+  DebugRegisterQ r;
+  vst1q_s16(r.i16, val);
+  PrintVectQ(r, name, 16);
+}
+
+inline void PrintReg(const int16x4_t val, const char* name) {
+  DebugRegister r;
+  vst1_s16(r.i16, val);
+  PrintVect(r, name, 16);
+}
+
+inline void PrintReg(const int8x16_t val, const char* name) {
+  DebugRegisterQ r;
+  vst1q_s8(r.i8, val);
+  PrintVectQ(r, name, 8);
+}
+
+inline void PrintReg(const int8x8_t val, const char* name) {
+  DebugRegister r;
+  vst1_s8(r.i8, val);
+  PrintVect(r, name, 8);
+}
+
+// Print an individual (non-vector) value in decimal format.
+inline void PrintReg(const int x, const char* name) {
+  if (kEnablePrintRegs) {
+    printf("%s: %d\n", name, x);
+  }
+}
+
+// Print an individual (non-vector) value in hexadecimal format.
+inline void PrintHex(const int x, const char* name) {
+  if (kEnablePrintRegs) {
+    printf("%s: %x\n", name, x);
+  }
+}
+
+#define PR(x) PrintReg(x, #x)
+#define PD(x) PrintReg(x, #x)
+#define PX(x) PrintHex(x, #x)
+
+#endif  // 0
+
 namespace libgav1 {
 namespace dsp {
 
@@ -64,9 +212,97 @@
   return vreinterpret_s8_u64(vshr_n_u64(vreinterpret_u64_s8(vector), shift));
 }
 
+// Shim vqtbl1_u8 for armv7.
+inline uint8x8_t VQTbl1U8(const uint8x16_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+  return vqtbl1_u8(a, index);
+#else
+  const uint8x8x2_t b = {vget_low_u8(a), vget_high_u8(a)};
+  return vtbl2_u8(b, index);
+#endif
+}
+
+//------------------------------------------------------------------------------
+// Interleave.
+
+// vzipN is exclusive to A64.
+inline uint8x8_t InterleaveLow8(const uint8x8_t a, const uint8x8_t b) {
+#if defined(__aarch64__)
+  return vzip1_u8(a, b);
+#else
+  // Discard |.val[1]|
+  return vzip_u8(a, b).val[0];
+#endif
+}
+
+inline uint8x8_t InterleaveLow32(const uint8x8_t a, const uint8x8_t b) {
+#if defined(__aarch64__)
+  return vreinterpret_u8_u32(
+      vzip1_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)));
+#else
+  // Discard |.val[1]|
+  return vreinterpret_u8_u32(
+      vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[0]);
+#endif
+}
+
+inline int8x8_t InterleaveLow32(const int8x8_t a, const int8x8_t b) {
+#if defined(__aarch64__)
+  return vreinterpret_s8_u32(
+      vzip1_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)));
+#else
+  // Discard |.val[1]|
+  return vreinterpret_s8_u32(
+      vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[0]);
+#endif
+}
+
+inline uint8x8_t InterleaveHigh32(const uint8x8_t a, const uint8x8_t b) {
+#if defined(__aarch64__)
+  return vreinterpret_u8_u32(
+      vzip2_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)));
+#else
+  // Discard |.val[0]|
+  return vreinterpret_u8_u32(
+      vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[1]);
+#endif
+}
+
+inline int8x8_t InterleaveHigh32(const int8x8_t a, const int8x8_t b) {
+#if defined(__aarch64__)
+  return vreinterpret_s8_u32(
+      vzip2_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)));
+#else
+  // Discard |.val[0]|
+  return vreinterpret_s8_u32(
+      vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[1]);
+#endif
+}
+
 //------------------------------------------------------------------------------
 // Transpose.
 
+// Transpose 32 bit elements such that:
+// a: 00 01
+// b: 02 03
+// returns
+// val[0]: 00 02
+// val[1]: 01 03
+inline uint8x8x2_t Interleave32(const uint8x8_t a, const uint8x8_t b) {
+  const uint32x2_t a_32 = vreinterpret_u32_u8(a);
+  const uint32x2_t b_32 = vreinterpret_u32_u8(b);
+  const uint32x2x2_t c = vtrn_u32(a_32, b_32);
+  const uint8x8x2_t d = {vreinterpret_u8_u32(c.val[0]),
+                         vreinterpret_u8_u32(c.val[1])};
+  return d;
+}
+
+// Swap high and low 32 bit elements.
+inline uint8x8_t Transpose32(const uint8x8_t a) {
+  const uint32x2_t b = vrev64_u32(vreinterpret_u32_u8(a));
+  return vreinterpret_u8_u32(b);
+}
+
 // Implement vtrnq_s64().
 // Input:
 // a0: 00 01 02 03 04 05 06 07
@@ -83,6 +319,15 @@
   return b0;
 }
 
+inline uint16x8x2_t VtrnqU64(uint32x4_t a0, uint32x4_t a1) {
+  uint16x8x2_t b0;
+  b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)),
+                           vreinterpret_u16_u32(vget_low_u32(a1)));
+  b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)),
+                           vreinterpret_u16_u32(vget_high_u32(a1)));
+  return b0;
+}
+
 // Input:
 // a: 00 01 02 03 10 11 12 13
 // b: 20 21 22 23 30 31 32 33
@@ -129,6 +374,102 @@
 }
 
 // Input:
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// a[4]: 40 41 42 43 44 45 46 47
+// a[5]: 50 51 52 53 54 55 56 57
+// a[6]: 60 61 62 63 64 65 66 67
+// a[7]: 70 71 72 73 74 75 76 77
+
+// Output:
+// a[0]: 00 10 20 30 40 50 60 70
+// a[1]: 01 11 21 31 41 51 61 71
+// a[2]: 02 12 22 32 42 52 62 72
+// a[3]: 03 13 23 33 43 53 63 73
+// a[4]: 04 14 24 34 44 54 64 74
+// a[5]: 05 15 25 35 45 55 65 75
+// a[6]: 06 16 26 36 46 56 66 76
+// a[7]: 07 17 27 37 47 57 67 77
+inline void Transpose8x8(int8x8_t a[8]) {
+  // Swap 8 bit elements. Goes from:
+  // a[0]: 00 01 02 03 04 05 06 07
+  // a[1]: 10 11 12 13 14 15 16 17
+  // a[2]: 20 21 22 23 24 25 26 27
+  // a[3]: 30 31 32 33 34 35 36 37
+  // a[4]: 40 41 42 43 44 45 46 47
+  // a[5]: 50 51 52 53 54 55 56 57
+  // a[6]: 60 61 62 63 64 65 66 67
+  // a[7]: 70 71 72 73 74 75 76 77
+  // to:
+  // b0.val[0]: 00 10 02 12 04 14 06 16  40 50 42 52 44 54 46 56
+  // b0.val[1]: 01 11 03 13 05 15 07 17  41 51 43 53 45 55 47 57
+  // b1.val[0]: 20 30 22 32 24 34 26 36  60 70 62 72 64 74 66 76
+  // b1.val[1]: 21 31 23 33 25 35 27 37  61 71 63 73 65 75 67 77
+  const int8x16x2_t b0 =
+      vtrnq_s8(vcombine_s8(a[0], a[4]), vcombine_s8(a[1], a[5]));
+  const int8x16x2_t b1 =
+      vtrnq_s8(vcombine_s8(a[2], a[6]), vcombine_s8(a[3], a[7]));
+
+  // Swap 16 bit elements resulting in:
+  // c0.val[0]: 00 10 20 30 04 14 24 34  40 50 60 70 44 54 64 74
+  // c0.val[1]: 02 12 22 32 06 16 26 36  42 52 62 72 46 56 66 76
+  // c1.val[0]: 01 11 21 31 05 15 25 35  41 51 61 71 45 55 65 75
+  // c1.val[1]: 03 13 23 33 07 17 27 37  43 53 63 73 47 57 67 77
+  const int16x8x2_t c0 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[0]),
+                                   vreinterpretq_s16_s8(b1.val[0]));
+  const int16x8x2_t c1 = vtrnq_s16(vreinterpretq_s16_s8(b0.val[1]),
+                                   vreinterpretq_s16_s8(b1.val[1]));
+
+  // Unzip 32 bit elements resulting in:
+  // d0.val[0]: 00 10 20 30 40 50 60 70  01 11 21 31 41 51 61 71
+  // d0.val[1]: 04 14 24 34 44 54 64 74  05 15 25 35 45 55 65 75
+  // d1.val[0]: 02 12 22 32 42 52 62 72  03 13 23 33 43 53 63 73
+  // d1.val[1]: 06 16 26 36 46 56 66 76  07 17 27 37 47 57 67 77
+  const int32x4x2_t d0 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[0]),
+                                   vreinterpretq_s32_s16(c1.val[0]));
+  const int32x4x2_t d1 = vuzpq_s32(vreinterpretq_s32_s16(c0.val[1]),
+                                   vreinterpretq_s32_s16(c1.val[1]));
+
+  a[0] = vreinterpret_s8_s32(vget_low_s32(d0.val[0]));
+  a[1] = vreinterpret_s8_s32(vget_high_s32(d0.val[0]));
+  a[2] = vreinterpret_s8_s32(vget_low_s32(d1.val[0]));
+  a[3] = vreinterpret_s8_s32(vget_high_s32(d1.val[0]));
+  a[4] = vreinterpret_s8_s32(vget_low_s32(d0.val[1]));
+  a[5] = vreinterpret_s8_s32(vget_high_s32(d0.val[1]));
+  a[6] = vreinterpret_s8_s32(vget_low_s32(d1.val[1]));
+  a[7] = vreinterpret_s8_s32(vget_high_s32(d1.val[1]));
+}
+
+// Unsigned.
+inline void Transpose8x8(uint8x8_t a[8]) {
+  const uint8x16x2_t b0 =
+      vtrnq_u8(vcombine_u8(a[0], a[4]), vcombine_u8(a[1], a[5]));
+  const uint8x16x2_t b1 =
+      vtrnq_u8(vcombine_u8(a[2], a[6]), vcombine_u8(a[3], a[7]));
+
+  const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
+                                    vreinterpretq_u16_u8(b1.val[0]));
+  const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
+                                    vreinterpretq_u16_u8(b1.val[1]));
+
+  const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]),
+                                    vreinterpretq_u32_u16(c1.val[0]));
+  const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]),
+                                    vreinterpretq_u32_u16(c1.val[1]));
+
+  a[0] = vreinterpret_u8_u32(vget_low_u32(d0.val[0]));
+  a[1] = vreinterpret_u8_u32(vget_high_u32(d0.val[0]));
+  a[2] = vreinterpret_u8_u32(vget_low_u32(d1.val[0]));
+  a[3] = vreinterpret_u8_u32(vget_high_u32(d1.val[0]));
+  a[4] = vreinterpret_u8_u32(vget_low_u32(d0.val[1]));
+  a[5] = vreinterpret_u8_u32(vget_high_u32(d0.val[1]));
+  a[6] = vreinterpret_u8_u32(vget_low_u32(d1.val[1]));
+  a[7] = vreinterpret_u8_u32(vget_high_u32(d1.val[1]));
+}
+
+// Input:
 // a0: 00 01 02 03 04 05 06 07
 // a1: 10 11 12 13 14 15 16 17
 // a2: 20 21 22 23 24 25 26 27
@@ -180,6 +521,87 @@
 }
 
 // Input:
+// a0: 00 01 02 03 04 05 06 07
+// a1: 10 11 12 13 14 15 16 17
+// a2: 20 21 22 23 24 25 26 27
+// a3: 30 31 32 33 34 35 36 37
+// a4: 40 41 42 43 44 45 46 47
+// a5: 50 51 52 53 54 55 56 57
+// a6: 60 61 62 63 64 65 66 67
+// a7: 70 71 72 73 74 75 76 77
+
+// Output:
+// a0: 00 10 20 30 40 50 60 70
+// a1: 01 11 21 31 41 51 61 71
+// a2: 02 12 22 32 42 52 62 72
+// a3: 03 13 23 33 43 53 63 73
+// a4: 04 14 24 34 44 54 64 74
+// a5: 05 15 25 35 45 55 65 75
+// a6: 06 16 26 36 46 56 66 76
+// a7: 07 17 27 37 47 57 67 77
+// TODO(johannkoenig): Switch users of the above transpose to this one.
+inline void Transpose8x8(int16x8_t a[8]) {
+  const int16x8x2_t b0 = vtrnq_s16(a[0], a[1]);
+  const int16x8x2_t b1 = vtrnq_s16(a[2], a[3]);
+  const int16x8x2_t b2 = vtrnq_s16(a[4], a[5]);
+  const int16x8x2_t b3 = vtrnq_s16(a[6], a[7]);
+
+  const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
+                                   vreinterpretq_s32_s16(b1.val[0]));
+  const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
+                                   vreinterpretq_s32_s16(b1.val[1]));
+  const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]),
+                                   vreinterpretq_s32_s16(b3.val[0]));
+  const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
+                                   vreinterpretq_s32_s16(b3.val[1]));
+
+  const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]);
+  const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]);
+  const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]);
+  const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]);
+
+  a[0] = d0.val[0];
+  a[1] = d1.val[0];
+  a[2] = d2.val[0];
+  a[3] = d3.val[0];
+  a[4] = d0.val[1];
+  a[5] = d1.val[1];
+  a[6] = d2.val[1];
+  a[7] = d3.val[1];
+}
+
+// Unsigned.
+inline void Transpose8x8(uint16x8_t a[8]) {
+  const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]);
+  const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]);
+  const uint16x8x2_t b2 = vtrnq_u16(a[4], a[5]);
+  const uint16x8x2_t b3 = vtrnq_u16(a[6], a[7]);
+
+  const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]),
+                                    vreinterpretq_u32_u16(b1.val[0]));
+  const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]),
+                                    vreinterpretq_u32_u16(b1.val[1]));
+  const uint32x4x2_t c2 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[0]),
+                                    vreinterpretq_u32_u16(b3.val[0]));
+  const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]),
+                                    vreinterpretq_u32_u16(b3.val[1]));
+
+  const uint16x8x2_t d0 = VtrnqU64(c0.val[0], c2.val[0]);
+  const uint16x8x2_t d1 = VtrnqU64(c1.val[0], c3.val[0]);
+  const uint16x8x2_t d2 = VtrnqU64(c0.val[1], c2.val[1]);
+  const uint16x8x2_t d3 = VtrnqU64(c1.val[1], c3.val[1]);
+
+  a[0] = d0.val[0];
+  a[1] = d1.val[0];
+  a[2] = d2.val[0];
+  a[3] = d3.val[0];
+  a[4] = d0.val[1];
+  a[5] = d1.val[1];
+  a[6] = d2.val[1];
+  a[7] = d3.val[1];
+}
+
+// Input:
 // i0: 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f
 // i1: 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f
 // i2: 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f
@@ -257,6 +679,49 @@
   *o15 = vget_high_u8(vreinterpretq_u8_u32(d3.val[1]));
 }
 
+// TODO(johannkoenig): Replace usage of the above transpose with this one.
+inline void Transpose16x8(const uint8x16_t input[8], uint8x8_t output[16]) {
+  const uint8x16x2_t b0 = vtrnq_u8(input[0], input[1]);
+  const uint8x16x2_t b1 = vtrnq_u8(input[2], input[3]);
+  const uint8x16x2_t b2 = vtrnq_u8(input[4], input[5]);
+  const uint8x16x2_t b3 = vtrnq_u8(input[6], input[7]);
+
+  const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
+                                    vreinterpretq_u16_u8(b1.val[0]));
+  const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
+                                    vreinterpretq_u16_u8(b1.val[1]));
+  const uint16x8x2_t c2 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[0]),
+                                    vreinterpretq_u16_u8(b3.val[0]));
+  const uint16x8x2_t c3 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[1]),
+                                    vreinterpretq_u16_u8(b3.val[1]));
+
+  const uint32x4x2_t d0 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[0]),
+                                    vreinterpretq_u32_u16(c2.val[0]));
+  const uint32x4x2_t d1 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[1]),
+                                    vreinterpretq_u32_u16(c2.val[1]));
+  const uint32x4x2_t d2 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[0]),
+                                    vreinterpretq_u32_u16(c3.val[0]));
+  const uint32x4x2_t d3 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[1]),
+                                    vreinterpretq_u32_u16(c3.val[1]));
+
+  output[0] = vget_low_u8(vreinterpretq_u8_u32(d0.val[0]));
+  output[1] = vget_low_u8(vreinterpretq_u8_u32(d2.val[0]));
+  output[2] = vget_low_u8(vreinterpretq_u8_u32(d1.val[0]));
+  output[3] = vget_low_u8(vreinterpretq_u8_u32(d3.val[0]));
+  output[4] = vget_low_u8(vreinterpretq_u8_u32(d0.val[1]));
+  output[5] = vget_low_u8(vreinterpretq_u8_u32(d2.val[1]));
+  output[6] = vget_low_u8(vreinterpretq_u8_u32(d1.val[1]));
+  output[7] = vget_low_u8(vreinterpretq_u8_u32(d3.val[1]));
+  output[8] = vget_high_u8(vreinterpretq_u8_u32(d0.val[0]));
+  output[9] = vget_high_u8(vreinterpretq_u8_u32(d2.val[0]));
+  output[10] = vget_high_u8(vreinterpretq_u8_u32(d1.val[0]));
+  output[11] = vget_high_u8(vreinterpretq_u8_u32(d3.val[0]));
+  output[12] = vget_high_u8(vreinterpretq_u8_u32(d0.val[1]));
+  output[13] = vget_high_u8(vreinterpretq_u8_u32(d2.val[1]));
+  output[14] = vget_high_u8(vreinterpretq_u8_u32(d1.val[1]));
+  output[15] = vget_high_u8(vreinterpretq_u8_u32(d3.val[1]));
+}
+
 }  // namespace dsp
 }  // namespace libgav1
 
diff --git a/libgav1/src/dsp/arm/convolve_neon.cc b/libgav1/src/dsp/arm/convolve_neon.cc
new file mode 100644
index 0000000..4126c41
--- /dev/null
+++ b/libgav1/src/dsp/arm/convolve_neon.cc
@@ -0,0 +1,1694 @@
+#include "src/dsp/convolve.h"
+#include "src/dsp/dsp.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+constexpr int kBitdepth8 = 8;
+constexpr int kIntermediateStride = kMaxSuperBlockSizeInPixels;
+constexpr int kSubPixelMask = (1 << kSubPixelBits) - 1;
+constexpr int kHorizontalOffset = 3;
+constexpr int kVerticalOffset = 3;
+constexpr int kInterRoundBitsVertical = 11;
+
+int GetFilterIndex(const int filter_index, const int length) {
+  if (length <= 4) {
+    if (filter_index == kInterpolationFilterEightTap ||
+        filter_index == kInterpolationFilterEightTapSharp) {
+      return 4;
+    }
+    if (filter_index == kInterpolationFilterEightTapSmooth) {
+      return 5;
+    }
+  }
+  return filter_index;
+}
+
+inline int16x8_t ZeroExtend(const uint8x8_t in) {
+  return vreinterpretq_s16_u16(vmovl_u8(in));
+}
+
+inline void Load8x8(const uint8_t* s, const ptrdiff_t p, int16x8_t* dst) {
+  dst[0] = ZeroExtend(vld1_u8(s));
+  s += p;
+  dst[1] = ZeroExtend(vld1_u8(s));
+  s += p;
+  dst[2] = ZeroExtend(vld1_u8(s));
+  s += p;
+  dst[3] = ZeroExtend(vld1_u8(s));
+  s += p;
+  dst[4] = ZeroExtend(vld1_u8(s));
+  s += p;
+  dst[5] = ZeroExtend(vld1_u8(s));
+  s += p;
+  dst[6] = ZeroExtend(vld1_u8(s));
+  s += p;
+  dst[7] = ZeroExtend(vld1_u8(s));
+}
+
+// Multiply every entry in |src[]| by the corresponding lane in |taps| and sum.
+// The sum of the entries in |taps| is always 128. In some situations negative
+// values are used. This creates a situation where the positive taps sum to more
+// than 128. An example is:
+// {-4, 10, -24, 100, 60, -20, 8, -2}
+// The negative taps never sum to < -128
+// The center taps are always positive. The remaining positive taps never sum
+// to > 128.
+// Summing these naively can overflow int16_t. This can be avoided by adding the
+// center taps last and saturating the result.
+// We do not need to expand to int32_t because later in the function the value
+// is shifted by |kFilterBits| (7) and saturated to uint8_t. This means any
+// value over 255 << 7 (32576 because of rounding) is clamped.
+template <int num_taps>
+int16x8_t SumTaps(const int16x8_t* const src, const int16x8_t taps) {
+  int16x8_t sum;
+  if (num_taps == 8) {
+    const int16x4_t taps_lo = vget_low_s16(taps);
+    const int16x4_t taps_hi = vget_high_s16(taps);
+    sum = vmulq_lane_s16(src[0], taps_lo, 0);
+    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 1);
+    sum = vmlaq_lane_s16(sum, src[2], taps_lo, 2);
+    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 1);
+    sum = vmlaq_lane_s16(sum, src[6], taps_hi, 2);
+    sum = vmlaq_lane_s16(sum, src[7], taps_hi, 3);
+
+    // Center taps.
+    sum = vqaddq_s16(sum, vmulq_lane_s16(src[3], taps_lo, 3));
+    sum = vqaddq_s16(sum, vmulq_lane_s16(src[4], taps_hi, 0));
+  } else if (num_taps == 6) {
+    const int16x4_t taps_lo = vget_low_s16(taps);
+    const int16x4_t taps_hi = vget_high_s16(taps);
+    sum = vmulq_lane_s16(src[0], taps_lo, 1);
+    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 2);
+    sum = vmlaq_lane_s16(sum, src[4], taps_hi, 1);
+    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 2);
+
+    // Center taps.
+    sum = vqaddq_s16(sum, vmulq_lane_s16(src[2], taps_lo, 3));
+    sum = vqaddq_s16(sum, vmulq_lane_s16(src[3], taps_hi, 0));
+  } else if (num_taps == 4) {
+    const int16x4_t taps_lo = vget_low_s16(taps);
+    sum = vmulq_lane_s16(src[0], taps_lo, 0);
+    sum = vmlaq_lane_s16(sum, src[3], taps_lo, 3);
+
+    // Center taps.
+    sum = vqaddq_s16(sum, vmulq_lane_s16(src[1], taps_lo, 1));
+    sum = vqaddq_s16(sum, vmulq_lane_s16(src[2], taps_lo, 2));
+  } else {
+    assert(num_taps == 2);
+    // All the taps are positive so there is no concern regarding saturation.
+    const int16x4_t taps_lo = vget_low_s16(taps);
+    sum = vmulq_lane_s16(src[0], taps_lo, 1);
+    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 2);
+  }
+
+  return sum;
+}
+
+// Add an offset to ensure the sum is positive and it fits within uint16_t.
+template <int num_taps>
+uint16x8_t SumTaps8To16(const int16x8_t* const src, const int16x8_t taps) {
+  // The worst case sum of negative taps is -56. The worst case sum of positive
+  // taps is 184. With the single pass versions of the Convolve we could safely
+  // saturate to int16_t because it outranged the final shift and narrow to
+  // uint8_t. For the 2D Convolve the intermediate values are 16 bits so we
+  // don't have that option.
+  // 184 * 255 = 46920 which is greater than int16_t can hold, but not uint16_t.
+  // The minimum value we need to handle is -56 * 255 = -14280.
+  // By offsetting the sum with 1 << 14 = 16384 we ensure that the sum is never
+  // negative and that 46920 + 16384 = 63304 fits comfortably in uint16_t. This
+  // allows us to use 16 bit registers instead of 32 bit registers.
+  // When considering the bit operations it is safe to ignore signedness. Due to
+  // the magic of 2's complement and well defined rollover rules the bit
+  // representations are equivalent.
+  const int16x4_t taps_lo = vget_low_s16(taps);
+  const int16x4_t taps_hi = vget_high_s16(taps);
+  // |offset| == 1 << (bitdepth + kFilterBits - 1);
+  int16x8_t sum = vdupq_n_s16(1 << 14);
+  if (num_taps == 8) {
+    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
+    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 1);
+    sum = vmlaq_lane_s16(sum, src[2], taps_lo, 2);
+    sum = vmlaq_lane_s16(sum, src[3], taps_lo, 3);
+    sum = vmlaq_lane_s16(sum, src[4], taps_hi, 0);
+    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 1);
+    sum = vmlaq_lane_s16(sum, src[6], taps_hi, 2);
+    sum = vmlaq_lane_s16(sum, src[7], taps_hi, 3);
+  } else if (num_taps == 6) {
+    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 1);
+    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 2);
+    sum = vmlaq_lane_s16(sum, src[2], taps_lo, 3);
+    sum = vmlaq_lane_s16(sum, src[3], taps_hi, 0);
+    sum = vmlaq_lane_s16(sum, src[4], taps_hi, 1);
+    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 2);
+  } else if (num_taps == 4) {
+    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 2);
+    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 3);
+    sum = vmlaq_lane_s16(sum, src[2], taps_hi, 0);
+    sum = vmlaq_lane_s16(sum, src[3], taps_hi, 1);
+  } else if (num_taps == 2) {
+    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 3);
+    sum = vmlaq_lane_s16(sum, src[1], taps_hi, 0);
+  }
+
+  // This is guaranteed to be positive. Convert it for the final shift.
+  return vreinterpretq_u16_s16(sum);
+}
+
+// Process 16 bit inputs and output 32 bits.
+template <int num_taps>
+uint32x4x2_t Sum2DVerticalTaps(const int16x8_t* const src,
+                               const int16x8_t taps) {
+  // In order to get the rollover correct with the lengthening instruction we
+  // need to treat these as signed so that they sign extend properly.
+  const int16x4_t taps_lo = vget_low_s16(taps);
+  const int16x4_t taps_hi = vget_high_s16(taps);
+  // An offset to guarantee the sum is non negative. Captures 56 * -4590 =
+  // -257040 (worst case negative value from the horizontal pass). It should be
+  // possible to use 1 << 18 (262144) instead of 1 << 19 but there probably
+  // isn't any benefit.
+  // |offset_bits| = bitdepth + 2 * kFilterBits - kInterRoundBitsHorizontal
+  // == 19.
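+  // i.e. 8 + 2 * 7 - 3 == 19, so the offset 1 << 19 == 524288 covers the
+  // worst case magnitude of 257040.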
+  int32x4_t sum_lo = vdupq_n_s32(1 << 19);
+  int32x4_t sum_hi = sum_lo;
+  if (num_taps == 8) {
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[0]), taps_lo, 0);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[0]), taps_lo, 0);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 1);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 1);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 2);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 2);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_lo, 3);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_lo, 3);
+
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 0);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 0);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 1);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 1);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[6]), taps_hi, 2);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[6]), taps_hi, 2);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[7]), taps_hi, 3);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[7]), taps_hi, 3);
+  } else if (num_taps == 6) {
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[0]), taps_lo, 1);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[0]), taps_lo, 1);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 2);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 2);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 3);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 3);
+
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 0);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 0);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 1);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 1);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 2);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 2);
+  } else if (num_taps == 4) {
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[0]), taps_lo, 2);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[0]), taps_lo, 2);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 3);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 3);
+
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_hi, 0);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_hi, 0);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 1);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 1);
+  } else if (num_taps == 2) {
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[0]), taps_lo, 3);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[0]), taps_lo, 3);
+
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_hi, 0);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_hi, 0);
+  }
+
+  // This is guaranteed to be positive. Convert it for the final shift.
+  const uint32x4x2_t return_val = {vreinterpretq_u32_s32(sum_lo),
+                                   vreinterpretq_u32_s32(sum_hi)};
+  return return_val;
+}
+
+template <int num_taps>
+void Filter2DVertical(const uint16_t* src, const ptrdiff_t src_stride,
+                      uint8_t* dst, const ptrdiff_t dst_stride, const int width,
+                      const int height, const int16x8_t taps) {
+  constexpr int next_row = num_taps - 1;
+
+  int x = 0;
+  do {
+    int16x8_t srcs[8];
+    srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src + x));
+    if (num_taps >= 4) {
+      srcs[1] = vreinterpretq_s16_u16(vld1q_u16(src + x + src_stride));
+      srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src + x + 2 * src_stride));
+      if (num_taps >= 6) {
+        srcs[3] = vreinterpretq_s16_u16(vld1q_u16(src + x + 3 * src_stride));
+        srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src + x + 4 * src_stride));
+        if (num_taps == 8) {
+          srcs[5] = vreinterpretq_s16_u16(vld1q_u16(src + x + 5 * src_stride));
+          srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src + x + 6 * src_stride));
+        }
+      }
+    }
+
+    int y = 0;
+    do {
+      srcs[next_row] = vreinterpretq_s16_u16(
+          vld1q_u16(src + x + (y + next_row) * src_stride));
+
+      const uint32x4x2_t sums = Sum2DVerticalTaps<num_taps>(srcs, taps);
+      const uint16x8_t first_shift =
+          vcombine_u16(vqrshrn_n_u32(sums.val[0], kInterRoundBitsVertical),
+                       vqrshrn_n_u32(sums.val[1], kInterRoundBitsVertical));
+      // |single_round_offset| == (1 << bitdepth) + (1 << (bitdepth - 1)) ==
+      // 384
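+      // Each intermediate row carries a propagated offset of 1 << 11 and the
+      // filter taps sum to 128, so together with the 1 << 19 vertical offset
+      // the accumulated bias after the shift is 384, which is removed here.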
+      const uint8x8_t results =
+          vqmovn_u16(vqsubq_u16(first_shift, vdupq_n_u16(384)));
+
+      vst1_u8(dst + x + y * dst_stride, results);
+
+      srcs[0] = srcs[1];
+      if (num_taps >= 4) {
+        srcs[1] = srcs[2];
+        srcs[2] = srcs[3];
+        if (num_taps >= 6) {
+          srcs[3] = srcs[4];
+          srcs[4] = srcs[5];
+          if (num_taps == 8) {
+            srcs[5] = srcs[6];
+            srcs[6] = srcs[7];
+          }
+        }
+      }
+    } while (++y < height);
+    x += 8;
+  } while (x < width);
+}
+
+void Convolve2D_NEON(const void* const reference,
+                     const ptrdiff_t reference_stride,
+                     const int horizontal_filter_index,
+                     const int vertical_filter_index,
+                     const uint8_t /*inter_round_bits_vertical*/,
+                     const int subpixel_x, const int subpixel_y,
+                     const int /*step_x*/, const int /*step_y*/,
+                     const int width, const int height, void* prediction,
+                     const ptrdiff_t pred_stride) {
+  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+  int horizontal_taps, horizontal_taps_start, vertical_taps,
+      vertical_taps_start;
+
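+  // Map the filter index to the number of taps and the first nonzero tap
+  // within the 8 tap filter array: indices 0 and 1 use 6 taps starting at 1,
+  // index 2 uses all 8 taps, index 3 uses 2 taps starting at 3, and indices
+  // greater than 3 use 4 taps starting at 2.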
+  if (horiz_filter_index < 2) {
+    horizontal_taps = 6;
+    horizontal_taps_start = 1;
+  } else if (horiz_filter_index == 2) {
+    horizontal_taps = 8;
+    horizontal_taps_start = 0;
+  } else if (horiz_filter_index == 3) {
+    horizontal_taps = 2;
+    horizontal_taps_start = 3;
+  } else /* if (horiz_filter_index > 3) */ {
+    horizontal_taps = 4;
+    horizontal_taps_start = 2;
+  }
+
+  if (vert_filter_index < 2) {
+    vertical_taps = 6;
+    vertical_taps_start = 1;
+  } else if (vert_filter_index == 2) {
+    vertical_taps = 8;
+    vertical_taps_start = 0;
+  } else if (vert_filter_index == 3) {
+    vertical_taps = 2;
+    vertical_taps_start = 3;
+  } else /* if (vert_filter_index > 3) */ {
+    vertical_taps = 4;
+    vertical_taps_start = 2;
+  }
+
+  // Neon processes the horizontal pass in blocks of 8x8, so it computes a few
+  // more rows of context than it strictly needs.
+  const int intermediate_height = height + vertical_taps - 1;
+  // The output of the horizontal filter is guaranteed to fit in 16 bits.
+  uint16_t intermediate_result[kMaxSuperBlockSizeInPixels *
+                               (kMaxSuperBlockSizeInPixels + kSubPixelTaps)];
+  const int intermediate_stride = width;
+  const int max_pixel_value = 255;
+
+  if (width > 4) {
+    // Horizontal filter.
+    const int horiz_filter_id = (subpixel_x >> 6) & kSubPixelMask;
+    const int16x8_t horiz_taps =
+        vld1q_s16(kSubPixelFilters[horiz_filter_index][horiz_filter_id]);
+
+    uint16_t* intermediate = intermediate_result;
+    const ptrdiff_t src_stride = reference_stride;
+    // Offset for 8 tap horizontal filter and |vertical_taps|.
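+    // e.g. with 8 vertical taps this backs up (8 / 2) - 1 == 3 rows of
+    // context; kHorizontalOffset backs up the columns of left context needed
+    // by the horizontal filter.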
+    const auto* src = static_cast<const uint8_t*>(reference) -
+                      ((vertical_taps / 2) - 1) * src_stride -
+                      kHorizontalOffset;
+    int y = 0;
+    do {
+      int x = 0;
+      do {
+        uint8x16_t temp[8];
+        uint8x8_t input[16];
+        for (int i = 0; i < 8; ++i) {
+          temp[i] = vld1q_u8(src + 0 + x + i * src_stride);
+        }
+        // TODO(johannkoenig): It should be possible to get the transpose
+        // started with vld2().
+        Transpose16x8(temp, input);
+        int16x8_t input16[16];
+        for (int i = 0; i < 16; ++i) {
+          input16[i] = ZeroExtend(input[i]);
+        }
+
+        // TODO(johannkoenig): Explore moving the branch outside the main loop.
+        uint16x8_t output[8];
+        if (horizontal_taps == 8) {
+          for (int i = 0; i < 8; ++i) {
+            const uint16x8_t neon_sums =
+                SumTaps8To16<8>(input16 + i, horiz_taps);
+            output[i] = vrshrq_n_u16(neon_sums, kInterRoundBitsHorizontal);
+          }
+        } else if (horizontal_taps == 6) {
+          for (int i = 0; i < 8; ++i) {
+            const uint16x8_t neon_sums =
+                SumTaps8To16<6>(input16 + i + 1, horiz_taps);
+            output[i] = vrshrq_n_u16(neon_sums, kInterRoundBitsHorizontal);
+          }
+        } else {  // |horizontal_taps| == 2
+          for (int i = 0; i < 8; ++i) {
+            const uint16x8_t neon_sums =
+                SumTaps8To16<2>(input16 + i + 3, horiz_taps);
+            output[i] = vrshrq_n_u16(neon_sums, kInterRoundBitsHorizontal);
+          }
+        }
+
+        Transpose8x8(output);
+        for (int i = 0; i < 8; ++i) {
+          vst1q_u16(intermediate + x + i * intermediate_stride, output[i]);
+        }
+        x += 8;
+      } while (x < width);
+      src += src_stride << 3;
+      intermediate += intermediate_stride << 3;
+      y += 8;
+    } while (y < intermediate_height);
+
+    // Vertical filter.
+    auto* dest = static_cast<uint8_t*>(prediction);
+    const ptrdiff_t dest_stride = pred_stride;
+    const int filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
+    const int16x8_t taps =
+        vld1q_s16(kSubPixelFilters[vert_filter_index][filter_id]);
+
+    if (vertical_taps == 8) {
+      Filter2DVertical<8>(intermediate_result, intermediate_stride, dest,
+                          dest_stride, width, height, taps);
+    } else if (vertical_taps == 6) {
+      Filter2DVertical<6>(intermediate_result, intermediate_stride, dest,
+                          dest_stride, width, height, taps);
+    } else if (vertical_taps == 4) {
+      Filter2DVertical<4>(intermediate_result, intermediate_stride, dest,
+                          dest_stride, width, height, taps);
+    } else {  // |vertical_taps| == 2
+      Filter2DVertical<2>(intermediate_result, intermediate_stride, dest,
+                          dest_stride, width, height, taps);
+    }
+  } else {
+    // Horizontal filter.
+    // Filter types used for width <= 4 are different from those for width > 4.
+    // When width > 4, the valid filter index range is always [0, 3].
+    // When width <= 4, the valid filter index range is always [3, 5].
+    // Similarly for height.
+    uint16_t* intermediate = intermediate_result;
+    const ptrdiff_t src_stride = reference_stride;
+    const auto* src = static_cast<const uint8_t*>(reference) -
+                      ((vertical_taps / 2) - 1) * src_stride -
+                      ((horizontal_taps / 2) - 1);
+    auto* dest = static_cast<uint8_t*>(prediction);
+    const ptrdiff_t dest_stride = pred_stride;
+    int filter_id = (subpixel_x >> 6) & kSubPixelMask;
+    for (int y = 0; y < intermediate_height; ++y) {
+      for (int x = 0; x < width; ++x) {
+        // An offset to guarantee the sum is non negative.
+        int sum = 1 << 14;
+        for (int k = 0; k < horizontal_taps; ++k) {
+          const int tap = k + horizontal_taps_start;
+          sum +=
+              kSubPixelFilters[horiz_filter_index][filter_id][tap] * src[x + k];
+        }
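+        // 3 == kInterRoundBitsHorizontal for 8 bit input.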
+        intermediate[x] = static_cast<int16_t>(RightShiftWithRounding(sum, 3));
+      }
+      src += src_stride;
+      intermediate += intermediate_stride;
+    }
+    // Vertical filter.
+    intermediate = intermediate_result;
+    filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
+    for (int y = 0; y < height; ++y) {
+      for (int x = 0; x < width; ++x) {
+        // An offset to guarantee the sum is non negative.
+        int sum = 1 << 19;
+        for (int k = 0; k < vertical_taps; ++k) {
+          const int tap = k + vertical_taps_start;
+          sum += kSubPixelFilters[vert_filter_index][filter_id][tap] *
+                 intermediate[k * intermediate_stride + x];
+        }
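+        // ((1 << 19) + 128 * (1 << 11)) >> 11 == 384; subtracting it below
+        // removes the offsets added in both passes.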
+        dest[x] = static_cast<uint8_t>(
+            Clip3(RightShiftWithRounding(sum, 11) - 384, 0, max_pixel_value));
+      }
+      dest += dest_stride;
+      intermediate += intermediate_stride;
+    }
+  }
+}
+
+template <int tap_lane0, int tap_lane1>
+inline int16x8_t CombineFilterTapsLong(const int16x8_t sum,
+                                       const int16x8_t src0, int16x8_t src1,
+                                       int16x4_t taps0, int16x4_t taps1) {
+  int32x4_t sum_lo = vmovl_s16(vget_low_s16(sum));
+  int32x4_t sum_hi = vmovl_s16(vget_high_s16(sum));
+  const int16x8_t product0 = vmulq_lane_s16(src0, taps0, tap_lane0);
+  const int16x8_t product1 = vmulq_lane_s16(src1, taps1, tap_lane1);
+  const int32x4_t center_vals_lo =
+      vaddl_s16(vget_low_s16(product0), vget_low_s16(product1));
+  const int32x4_t center_vals_hi =
+      vaddl_s16(vget_high_s16(product0), vget_high_s16(product1));
+
+  sum_lo = vaddq_s32(sum_lo, center_vals_lo);
+  sum_hi = vaddq_s32(sum_hi, center_vals_hi);
+  return vcombine_s16(vrshrn_n_s32(sum_lo, 3), vrshrn_n_s32(sum_hi, 3));
+}
+
+// TODO(b/133525024): Replace usage of this function with version that uses
+// unsigned trick, once cl/263050071 is submitted.
+template <int num_taps>
+inline int16x8_t SumTapsCompound(const int16x8_t* const src,
+                                 const int16x8_t taps) {
+  int16x8_t sum = vdupq_n_s16(1 << (kBitdepth8 + kFilterBits - 1));
+  if (num_taps == 8) {
+    const int16x4_t taps_lo = vget_low_s16(taps);
+    const int16x4_t taps_hi = vget_high_s16(taps);
+    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
+    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 1);
+    sum = vmlaq_lane_s16(sum, src[2], taps_lo, 2);
+    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 1);
+    sum = vmlaq_lane_s16(sum, src[6], taps_hi, 2);
+    sum = vmlaq_lane_s16(sum, src[7], taps_hi, 3);
+
+    // Center taps may sum to as much as 160, which pollutes the sign bit in
+    // int16 types.
+    sum = CombineFilterTapsLong<3, 0>(sum, src[3], src[4], taps_lo, taps_hi);
+  } else if (num_taps == 6) {
+    const int16x4_t taps_lo = vget_low_s16(taps);
+    const int16x4_t taps_hi = vget_high_s16(taps);
+    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
+    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 1);
+    sum = vmlaq_lane_s16(sum, src[4], taps_hi, 0);
+    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 1);
+
+    // Center taps in filter 0 may sum to as much as 148, which pollutes the
+    // sign bit in int16 types. This is not true of filter 1.
+    sum = CombineFilterTapsLong<2, 3>(sum, src[2], src[3], taps_lo, taps_lo);
+  } else if (num_taps == 4) {
+    const int16x4_t taps_lo = vget_low_s16(taps);
+    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
+    sum = vmlaq_lane_s16(sum, src[3], taps_lo, 3);
+
+    // Center taps.
+    sum = vqaddq_s16(sum, vmulq_lane_s16(src[1], taps_lo, 1));
+    sum = vrshrq_n_s16(vqaddq_s16(sum, vmulq_lane_s16(src[2], taps_lo, 2)),
+                       kInterRoundBitsHorizontal);
+  } else {
+    assert(num_taps == 2);
+    // All the taps are positive so there is no concern regarding saturation.
+    const int16x4_t taps_lo = vget_low_s16(taps);
+    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
+    sum = vrshrq_n_s16(vmlaq_lane_s16(sum, src[1], taps_lo, 1),
+                       kInterRoundBitsHorizontal);
+  }
+  return sum;
+}
+
+// |grade_x| determines an upper limit on how many whole-pixel steps will be
+// realized with 8 |step_x| increments.
+template <int filter_index, int num_taps, int grade_x>
+inline void ConvolveHorizontalScaled_NEON(const uint8_t* src,
+                                          const ptrdiff_t src_stride,
+                                          const int width, const int subpixel_x,
+                                          const int step_x,
+                                          const int intermediate_height,
+                                          int16_t* dst) {
+  const int dst_stride = kMaxSuperBlockSizeInPixels;
+  const int kernel_offset = (8 - num_taps) / 2;
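+  // |kernel_offset| is the position of the first nonzero tap within the 8 tap
+  // filter array, e.g. 1 for the 6 tap filters.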
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  int y = intermediate_height;
+  do {  // y > 0
+    int p = subpixel_x;
+    int prev_p = p;
+    int x = 0;
+    int16x8_t s[(grade_x + 1) * 8];
+    const uint8_t* src_x =
+        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+    Load8x8(src_x, src_stride, s);
+    Transpose8x8(s);
+    if (grade_x > 1) {
+      Load8x8(src_x + 8, src_stride, &s[8]);
+      Transpose8x8(&s[8]);
+    }
+
+    do {  // x < width
+      int16x8_t result[8];
+      src_x = &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+      // process 8 src_x steps
+      Load8x8(src_x + 8, src_stride, &s[8]);
+      Transpose8x8(&s[8]);
+      if (grade_x > 1) {
+        Load8x8(src_x + 16, src_stride, &s[16]);
+        Transpose8x8(&s[16]);
+      }
+      // Remainder after whole index increments.
+      int pixel_offset = p & ((1 << kScaleSubPixelBits) - 1);
+      for (int z = 0; z < 8; ++z) {
+        const int16x8_t filter = vld1q_s16(
+            &kSubPixelFilters[filter_index][(p >> 6) & 0xF][kernel_offset]);
+        result[z] = SumTapsCompound<num_taps>(
+            &s[pixel_offset >> kScaleSubPixelBits], filter);
+        pixel_offset += step_x;
+        p += step_x;
+      }
+
+      // Transpose the 8x8 filtered values back to dst.
+      Transpose8x8(result);
+
+      vst1q_s16(&dst[x + 0 * dst_stride], result[0]);
+      vst1q_s16(&dst[x + 1 * dst_stride], result[1]);
+      vst1q_s16(&dst[x + 2 * dst_stride], result[2]);
+      vst1q_s16(&dst[x + 3 * dst_stride], result[3]);
+      vst1q_s16(&dst[x + 4 * dst_stride], result[4]);
+      vst1q_s16(&dst[x + 5 * dst_stride], result[5]);
+      vst1q_s16(&dst[x + 6 * dst_stride], result[6]);
+      vst1q_s16(&dst[x + 7 * dst_stride], result[7]);
+
+      for (int i = 0; i < 8; ++i) {
+        s[i] =
+            s[(p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits) + i];
+        if (grade_x > 1) {
+          s[i + 8] = s[(p >> kScaleSubPixelBits) -
+                       (prev_p >> kScaleSubPixelBits) + i + 8];
+        }
+      }
+
+      prev_p = p;
+      x += 8;
+    } while (x < width);
+
+    src += src_stride * 8;
+    dst += dst_stride * 8;
+    y -= 8;
+  } while (y > 0);
+}
+
+inline uint8x16_t GetPositive2TapFilter(const int tap_index) {
+  assert(tap_index < 2);
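+  // Each row holds one tap position across the 16 subpixel phases so a single
+  // table lookup can select the per-lane tap for each x.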
+  constexpr uint8_t kSubPixel2TapFilterColumns[2][16] = {
+      {128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8},
+      {0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120}};
+
+  return vld1q_u8(kSubPixel2TapFilterColumns[tap_index]);
+}
+
+inline void ConvolveKernelHorizontal2Tap(const uint8_t* src,
+                                         const ptrdiff_t src_stride,
+                                         const int width, const int subpixel_x,
+                                         const int step_x,
+                                         const int intermediate_height,
+                                         int16_t* intermediate) {
+  const int kIntermediateStride = kMaxSuperBlockSizeInPixels;
+  // Account for the 0-taps that precede the 2 nonzero taps.
+  const int kernel_offset = 3;
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const int step_x8 = step_x << 3;
+  const uint8x16_t filter_taps0 = GetPositive2TapFilter(0);
+  const uint8x16_t filter_taps1 = GetPositive2TapFilter(1);
+  const uint16x8_t sum = vdupq_n_u16(1 << (kBitdepth8 + kFilterBits - 1));
+  uint16x8_t index_steps = vmulq_n_u16(vmovl_u8(vcreate_u8(0x0706050403020100)),
+                                       static_cast<uint16_t>(step_x));
+
+  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+  for (int x = 0, p = subpixel_x; x < width; x += 8, p += step_x8) {
+    const uint8_t* src_x =
+        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+    int16_t* intermediate_x = intermediate + x;
+    // Only add steps to the 10-bit truncated p to avoid overflow.
+    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
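+    // The 4 bit filter phase sits in bits [6, 9] of each position, hence the
+    // shift by 6 and the mask by kSubPixelMask.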
+    const uint8x8_t filter_indices =
+        vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
+    // This is a special case. The 2-tap filter has no negative taps, so we
+    // can use unsigned values.
+    // For each x, a lane of tapsK has
+    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+    // on x.
+    const uint8x8_t taps0 = VQTbl1U8(filter_taps0, filter_indices);
+    const uint8x8_t taps1 = VQTbl1U8(filter_taps1, filter_indices);
+    for (int y = 0; y < intermediate_height; ++y) {
+      // Load a pool of samples to select from using stepped indices.
+      uint8x16_t src_vals = vld1q_u8(src_x);
+      const uint8x8_t src_indices =
+          vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+
+      // For each x, a lane of srcK contains src_x[k].
+      const uint8x8_t src0 = VQTbl1U8(src_vals, src_indices);
+      const uint8x8_t src1 =
+          VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)));
+
+      const uint16x8_t product0 = vmlal_u8(sum, taps0, src0);
+      // product0 + product1
+      const uint16x8_t result = vmlal_u8(product0, taps1, src1);
+
+      vst1q_s16(intermediate_x, vreinterpretq_s16_u16(vrshrq_n_u16(result, 3)));
+      src_x += src_stride;
+      intermediate_x += kIntermediateStride;
+    }
+  }
+}
+
+inline uint8x16_t GetPositive4TapFilter(const int tap_index) {
+  assert(tap_index < 4);
+  constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
+      {0, 30, 26, 22, 20, 18, 16, 14, 12, 12, 10, 8, 6, 4, 4, 2},
+      {128, 62, 62, 62, 60, 58, 56, 54, 52, 48, 46, 44, 42, 40, 36, 34},
+      {0, 34, 36, 40, 42, 44, 46, 48, 52, 54, 56, 58, 60, 62, 62, 62},
+      {0, 2, 4, 4, 6, 8, 10, 12, 12, 14, 16, 18, 20, 22, 26, 30}};
+
+  uint8x16_t filter_taps =
+      vld1q_u8(kSubPixel4TapPositiveFilterColumns[tap_index]);
+  return filter_taps;
+}
+
+// This filter is only possible when width <= 4.
+inline void ConvolveKernelHorizontalPositive4Tap(
+    const uint8_t* src, const ptrdiff_t src_stride, const int subpixel_x,
+    const int step_x, const int intermediate_height, int16_t* intermediate) {
+  const int kernel_offset = 2;
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+  const uint8x16_t filter_taps0 = GetPositive4TapFilter(0);
+  const uint8x16_t filter_taps1 = GetPositive4TapFilter(1);
+  const uint8x16_t filter_taps2 = GetPositive4TapFilter(2);
+  const uint8x16_t filter_taps3 = GetPositive4TapFilter(3);
+  uint16x8_t index_steps = vmulq_n_u16(vmovl_u8(vcreate_u8(0x0706050403020100)),
+                                       static_cast<uint16_t>(step_x));
+  int p = subpixel_x;
+  const uint16x8_t base = vdupq_n_u16(1 << (kBitdepth8 + kFilterBits - 1));
+  // The first filter (filter_id 0) is special: just a single 128 tap on the
+  // center.
+  const uint8_t* src_x =
+      &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+  // Only add steps to the 10-bit truncated p to avoid overflow.
+  const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+  const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+  const uint8x8_t filter_indices =
+      vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
+  // Note that filter_id depends on x.
+  // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k].
+  const uint8x8_t taps0 = VQTbl1U8(filter_taps0, filter_indices);
+  const uint8x8_t taps1 = VQTbl1U8(filter_taps1, filter_indices);
+  const uint8x8_t taps2 = VQTbl1U8(filter_taps2, filter_indices);
+  const uint8x8_t taps3 = VQTbl1U8(filter_taps3, filter_indices);
+
+  const uint8x8_t src_indices =
+      vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+  for (int y = 0; y < intermediate_height; ++y) {
+    // Load a pool of samples to select from using stepped index vectors.
+    uint8x16_t src_vals = vld1q_u8(src_x);
+
+    // For each x, srcK contains src_x[k], k = 0, 1, 2, 3.
+    // Whereas taps come from different arrays, src pixels are drawn from the
+    // same contiguous line.
+    const uint8x8_t src0 = VQTbl1U8(src_vals, src_indices);
+    const uint8x8_t src1 =
+        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)));
+    const uint8x8_t src2 =
+        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(2)));
+    const uint8x8_t src3 =
+        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(3)));
+
+    uint16x8_t sum = vmlal_u8(base, taps0, src0);
+    sum = vmlal_u8(sum, taps1, src1);
+    sum = vmlal_u8(sum, taps2, src2);
+    sum = vmlal_u8(sum, taps3, src3);
+
+    vst1_s16(intermediate,
+             vreinterpret_s16_u16(vrshr_n_u16(vget_low_u16(sum), 3)));
+
+    src_x += src_stride;
+    intermediate += kIntermediateStride;
+  }
+}
+
+inline uint8x16_t GetSigned4TapFilter(const int tap_index) {
+  assert(tap_index < 4);
+  // The first and fourth taps of each filter are negative. However
+  // 128 does not fit in an 8-bit signed integer. Thus we use subtraction to
+  // keep everything unsigned.
+  constexpr uint8_t kSubPixel4TapSignedFilterColumns[4][16] = {
+      {0, 4, 8, 10, 12, 12, 14, 12, 12, 10, 10, 10, 8, 6, 4, 2},
+      {128, 126, 122, 116, 110, 102, 94, 84, 76, 66, 58, 48, 38, 28, 18, 8},
+      {0, 8, 18, 28, 38, 48, 58, 66, 76, 84, 94, 102, 110, 116, 122, 126},
+      {0, 2, 4, 6, 8, 10, 10, 10, 12, 12, 14, 12, 12, 10, 8, 4}};
+
+  uint8x16_t filter_taps =
+      vld1q_u8(kSubPixel4TapSignedFilterColumns[tap_index]);
+  return filter_taps;
+}
+
+// This filter is only possible when width <= 4.
+inline void ConvolveKernelHorizontalSigned4Tap(
+    const uint8_t* src, const ptrdiff_t src_stride, const int subpixel_x,
+    const int step_x, const int intermediate_height, int16_t* intermediate) {
+  const int kernel_offset = 2;
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+  const uint8x16_t filter_taps0 = GetSigned4TapFilter(0);
+  const uint8x16_t filter_taps1 = GetSigned4TapFilter(1);
+  const uint8x16_t filter_taps2 = GetSigned4TapFilter(2);
+  const uint8x16_t filter_taps3 = GetSigned4TapFilter(3);
+  const uint16x8_t index_steps = vmulq_n_u16(vmovl_u8(vcreate_u8(0x03020100)),
+                                             static_cast<uint16_t>(step_x));
+
+  const uint16x8_t base = vdupq_n_u16(1 << (kBitdepth8 + kFilterBits - 1));
+  int p = subpixel_x;
+  const uint8_t* src_x =
+      &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+  // Only add steps to the 10-bit truncated p to avoid overflow.
+  const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+  const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+  const uint8x8_t filter_indices =
+      vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
+  // Note that filter_id depends on x.
+  // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k].
+  const uint8x8_t taps0 = VQTbl1U8(filter_taps0, filter_indices);
+  const uint8x8_t taps1 = VQTbl1U8(filter_taps1, filter_indices);
+  const uint8x8_t taps2 = VQTbl1U8(filter_taps2, filter_indices);
+  const uint8x8_t taps3 = VQTbl1U8(filter_taps3, filter_indices);
+  for (int y = 0; y < intermediate_height; ++y) {
+    // Load a pool of samples to select from using stepped indices.
+    uint8x16_t src_vals = vld1q_u8(src_x);
+    const uint8x8_t src_indices =
+        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+
+    // For each x, srcK contains src_x[k], k = 0, 1, 2, 3.
+    // Whereas taps come from different arrays, src pixels are drawn from the
+    // same contiguous line.
+    const uint8x8_t src0 = VQTbl1U8(src_vals, src_indices);
+    const uint8x8_t src1 =
+        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)));
+    const uint8x8_t src2 =
+        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(2)));
+    const uint8x8_t src3 =
+        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(3)));
+
+    // Offsetting by |base| guarantees the sum stays positive.
+    uint16x8_t sum = vmlsl_u8(base, taps0, src0);
+    sum = vmlal_u8(sum, taps1, src1);
+    sum = vmlal_u8(sum, taps2, src2);
+    sum = vmlsl_u8(sum, taps3, src3);
+
+    vst1_s16(intermediate,
+             vreinterpret_s16_u16(vrshr_n_u16(vget_low_u16(sum), 3)));
+    src_x += src_stride;
+    intermediate += kIntermediateStride;
+  }
+}
+
+void ConvolveCompoundScale2D_NEON(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int horizontal_filter_index, const int vertical_filter_index,
+    const uint8_t inter_round_bits_vertical, const int subpixel_x,
+    const int subpixel_y, const int step_x, const int step_y, const int width,
+    const int height, void* prediction, const ptrdiff_t pred_stride) {
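+  // The last vertical step position rounded up to a whole source row, plus
+  // kSubPixelTaps rows of filter context.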
+  const int intermediate_height =
+      (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
+       kScaleSubPixelBits) +
+      kSubPixelTaps;
+  // TODO(b/133525024): Decide whether it's worth branching to a special case
+  // when step_x or step_y is 1024.
+  assert(step_x <= 2048);
+  // The output of the horizontal filter, i.e. the intermediate_result, is
+  // guaranteed to fit in int16_t.
+  int16_t intermediate_result[kMaxSuperBlockSizeInPixels *
+                              (2 * kMaxSuperBlockSizeInPixels + 8)];
+
+  // Horizontal filter.
+  // Filter types used for width <= 4 are different from those for width > 4.
+  // When width > 4, the valid filter index range is always [0, 3].
+  // When width <= 4, the valid filter index range is always [3, 5].
+  // Similarly for height.
+  const int kIntermediateStride = kMaxSuperBlockSizeInPixels;
+  int filter_index = GetFilterIndex(horizontal_filter_index, width);
+  int16_t* intermediate = intermediate_result;
+  const auto* src = static_cast<const uint8_t*>(reference);
+  const ptrdiff_t src_stride = reference_stride;
+  auto* dest = static_cast<uint16_t*>(prediction);
+  switch (filter_index) {
+    case 0:
+      if (step_x < 1024) {
+        ConvolveHorizontalScaled_NEON<0, 6, 1>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      } else {
+        ConvolveHorizontalScaled_NEON<0, 6, 2>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      }
+      break;
+    case 1:
+      if (step_x < 1024) {
+        ConvolveHorizontalScaled_NEON<1, 6, 1>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      } else {
+        ConvolveHorizontalScaled_NEON<1, 6, 2>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      }
+      break;
+    case 2:
+      if (step_x <= 1024) {
+        ConvolveHorizontalScaled_NEON<2, 8, 1>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      } else {
+        ConvolveHorizontalScaled_NEON<2, 8, 2>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      }
+      break;
+    case 3:
+      ConvolveKernelHorizontal2Tap(src, src_stride, width, subpixel_x, step_x,
+                                   intermediate_height, intermediate);
+      break;
+    case 4:
+      assert(width <= 4);
+      ConvolveKernelHorizontalSigned4Tap(src, src_stride, subpixel_x, step_x,
+                                         intermediate_height, intermediate);
+      break;
+    default:
+      assert(filter_index == 5);
+      ConvolveKernelHorizontalPositive4Tap(src, src_stride, subpixel_x, step_x,
+                                           intermediate_height, intermediate);
+  }
+  // Vertical filter.
+  filter_index = GetFilterIndex(vertical_filter_index, height);
+  intermediate = intermediate_result;
+  const int offset_bits = kBitdepth8 + 2 * kFilterBits - 3;
+  for (int y = 0, p = subpixel_y & 1023; y < height; ++y, p += step_y) {
+    const int filter_id = (p >> 6) & kSubPixelMask;
+    for (int x = 0; x < width; ++x) {
+      // An offset to guarantee the sum is non negative.
+      int sum = 1 << offset_bits;
+      for (int k = 0; k < kSubPixelTaps; ++k) {
+        sum +=
+            kSubPixelFilters[filter_index][filter_id][k] *
+            intermediate[((p >> kScaleSubPixelBits) + k) * kIntermediateStride +
+                         x];
+      }
+      assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
+      dest[x] = static_cast<uint16_t>(
+          RightShiftWithRounding(sum, inter_round_bits_vertical));
+    }
+    dest += pred_stride;
+  }
+}
+
+void ConvolveHorizontal_NEON(const void* const reference,
+                             const ptrdiff_t reference_stride,
+                             const int horizontal_filter_index,
+                             const int /*vertical_filter_index*/,
+                             const uint8_t /*inter_round_bits_vertical*/,
+                             const int subpixel_x, const int /*subpixel_y*/,
+                             const int /*step_x*/, const int /*step_y*/,
+                             const int width, const int height,
+                             void* prediction, const ptrdiff_t pred_stride) {
+  // For 8 (and 10) bit calculations |inter_round_bits_horizontal| is 3.
+  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
+  // Set |src| to the outermost tap.
+  const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
+  const ptrdiff_t src_stride = reference_stride;
+  auto* dest = static_cast<uint8_t*>(prediction);
+  const ptrdiff_t dest_stride = pred_stride;
+  const int filter_id = (subpixel_x >> 6) & kSubPixelMask;
+  const int block_output_height = std::min(height, 8);
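+  // Adding 4 before the single shift by kFilterBits compensates for the
+  // rounding the skipped first stage shift (by 3) would have applied; see the
+  // comment in the width <= 4 path below.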
+  const int16x8_t four = vdupq_n_s16(4);
+
+  int16x8_t taps;
+  if (filter_index < 3) {
+    // 6 and 8 tap filters.
+    taps = vld1q_s16(kSubPixelFilters[filter_index][filter_id]);
+  } else {
+    // The 2 tap filter only uses the lower half of |taps|.
+    taps = vld1q_s16(kSubPixelFilters[filter_index][filter_id] + 2);
+  }
+
+  // TODO(johannkoenig): specialize small |height| variants so we don't
+  // overread |reference|.
+  if (width > 4 && height > 4) {
+    int y = 0;
+    do {
+      // This was intended to load and transpose 16 values before the |width|
+      // loop. At the end of the loop it would keep 8 of those values and only
+      // load and transpose 8 additional values. Unfortunately the approach did
+      // not appear to provide any benefit.
+      int x = 0;
+      do {
+        uint8x16_t temp[8];
+        uint8x8_t input[16];
+        for (int i = 0; i < 8; ++i) {
+          temp[i] = vld1q_u8(src + x + i * src_stride);
+        }
+        // TODO(johannkoenig): It should be possible to get the transpose
+        // started with vld4().
+        Transpose16x8(temp, input);
+        int16x8_t input16[16];
+        for (int i = 0; i < 16; ++i) {
+          input16[i] = ZeroExtend(input[i]);
+        }
+
+        // This does not handle |filter_index| > 3 because those 4 tap filters
+        // are only used when |width| <= 4.
+        // TODO(johannkoenig): Explore moving the branch outside the main loop.
+        uint8x8_t output[8];
+        if (filter_index == 2) {  // 8 taps.
+          for (int i = 0; i < 8; ++i) {
+            const int16x8_t neon_sums = SumTaps<8>(input16 + i, taps);
+            output[i] =
+                vqrshrun_n_s16(vqaddq_s16(neon_sums, four), kFilterBits);
+          }
+        } else if (filter_index < 2) {  // 6 taps.
+          for (int i = 0; i < 8; ++i) {
+            const int16x8_t neon_sums = SumTaps<6>(input16 + i + 1, taps);
+            output[i] =
+                vqrshrun_n_s16(vqaddq_s16(neon_sums, four), kFilterBits);
+          }
+        } else {  // |filter_index| == 3. 2 taps.
+          for (int i = 0; i < 8; ++i) {
+            const int16x8_t neon_sums = SumTaps<2>(input16 + i + 3, taps);
+            output[i] =
+                vqrshrun_n_s16(vqaddq_s16(neon_sums, four), kFilterBits);
+          }
+        }
+
+        Transpose8x8(output);
+
+        int i = 0;
+        do {
+          vst1_u8(dest + x + i * dest_stride, output[i]);
+        } while (++i < block_output_height);
+        x += 8;
+      } while (x < width);
+      y += 8;
+      src += 8 * src_stride;
+      dest += 8 * dest_stride;
+    } while (y < height);
+  } else {
+    // TODO(johannkoenig): Investigate 2xH and 4xH. During the original
+    // implementation 4x2 was slower than C, 4x4 reached parity, and 4x8
+    // was < 20% faster.
+    for (int y = 0; y < height; ++y) {
+      for (int x = 0; x < width; ++x) {
+        int sum = 0;
+        for (int k = 0; k < kSubPixelTaps; ++k) {
+          sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
+        }
+        // We can combine the shifts if we compensate for the skipped rounding.
+        // ((sum + 4 >> 3) + 8) >> 4 == (sum + 64 + 4) >> 7;
+        dest[x] = static_cast<uint8_t>(
+            Clip3(RightShiftWithRounding(sum + 4, kFilterBits), 0, 255));
+      }
+      src += src_stride;
+      dest += dest_stride;
+    }
+  }
+}
+
+template <int min_width, int num_taps>
+void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride,
+                    uint8_t* dst, const ptrdiff_t dst_stride, const int width,
+                    const int height, const int16x8_t taps) {
+  constexpr int next_row = num_taps - 1;
+  // |src| points to the outermost tap of the first value. When doing fewer than
+  // 8 taps it needs to be adjusted.
+  if (num_taps == 6) {
+    src += src_stride;
+  } else if (num_taps == 4) {
+    src += 2 * src_stride;
+  } else if (num_taps == 2) {
+    src += 3 * src_stride;
+  }
+
+  int x = 0;
+  do {
+    int16x8_t srcs[8];
+    srcs[0] = ZeroExtend(vld1_u8(src + x));
+    if (num_taps >= 4) {
+      srcs[1] = ZeroExtend(vld1_u8(src + x + src_stride));
+      srcs[2] = ZeroExtend(vld1_u8(src + x + 2 * src_stride));
+      if (num_taps >= 6) {
+        srcs[3] = ZeroExtend(vld1_u8(src + x + 3 * src_stride));
+        srcs[4] = ZeroExtend(vld1_u8(src + x + 4 * src_stride));
+        if (num_taps == 8) {
+          srcs[5] = ZeroExtend(vld1_u8(src + x + 5 * src_stride));
+          srcs[6] = ZeroExtend(vld1_u8(src + x + 6 * src_stride));
+        }
+      }
+    }
+
+    int y = 0;
+    do {
+      srcs[next_row] =
+          ZeroExtend(vld1_u8(src + x + (y + next_row) * src_stride));
+
+      const int16x8_t sums = SumTaps<num_taps>(srcs, taps);
+      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits);
+
+      if (min_width == 4) {
+        StoreLo4(dst + x + y * dst_stride, results);
+      } else {
+        vst1_u8(dst + x + y * dst_stride, results);
+      }
+
+      srcs[0] = srcs[1];
+      if (num_taps >= 4) {
+        srcs[1] = srcs[2];
+        srcs[2] = srcs[3];
+        if (num_taps >= 6) {
+          srcs[3] = srcs[4];
+          srcs[4] = srcs[5];
+          if (num_taps == 8) {
+            srcs[5] = srcs[6];
+            srcs[6] = srcs[7];
+          }
+        }
+      }
+    } while (++y < height);
+    x += 8;
+  } while (x < width);
+}
+
+// This function is a simplified version of Convolve2D_C.
+// It is called in single prediction mode, where only vertical filtering is
+// required.
+// The output is the single prediction of the block, clipped to valid pixel
+// range.
+void ConvolveVertical_NEON(const void* const reference,
+                           const ptrdiff_t reference_stride,
+                           const int /*horizontal_filter_index*/,
+                           const int vertical_filter_index,
+                           const uint8_t /*inter_round_bits_vertical*/,
+                           const int /*subpixel_x*/, const int subpixel_y,
+                           const int /*step_x*/, const int /*step_y*/,
+                           const int width, const int height, void* prediction,
+                           const ptrdiff_t pred_stride) {
+  const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* src =
+      static_cast<const uint8_t*>(reference) - kVerticalOffset * src_stride;
+  auto* dest = static_cast<uint8_t*>(prediction);
+  const ptrdiff_t dest_stride = pred_stride;
+  const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
+  // First filter is always a copy.
+  if (filter_id == 0) {
+    // Point |src| at the actual values rather than the start of the context.
+    src = static_cast<const uint8_t*>(reference);
+    int y = 0;
+    do {
+      memcpy(dest, src, width * sizeof(src[0]));
+      src += src_stride;
+      dest += dest_stride;
+    } while (++y < height);
+    return;
+  }
+
+  // Break up by # of taps
+  // |filter_index| taps  enum InterpolationFilter
+  //        0       6     kInterpolationFilterEightTap
+  //        1       6     kInterpolationFilterEightTapSmooth
+  //        2       8     kInterpolationFilterEightTapSharp
+  //        3       2     kInterpolationFilterBilinear
+  //        4       4     kInterpolationFilterSwitchable
+  //        5       4     !!! SECRET FILTER !!! only for Wx4.
+  if (width >= 4) {
+    if (filter_index == 2) {  // 8 tap.
+      const int16x8_t taps =
+          vld1q_s16(kSubPixelFilters[filter_index][filter_id]);
+      if (width == 4) {
+        FilterVertical<4, 8>(src, src_stride, dest, dest_stride, width, height,
+                             taps);
+      } else {
+        FilterVertical<8, 8>(src, src_stride, dest, dest_stride, width, height,
+                             taps);
+      }
+    } else if (filter_index < 2) {  // 6 tap.
+      const int16x8_t taps =
+          vld1q_s16(kSubPixelFilters[filter_index][filter_id]);
+      if (width == 4) {
+        FilterVertical<4, 6>(src, src_stride, dest, dest_stride, width, height,
+                             taps);
+      } else {
+        FilterVertical<8, 6>(src, src_stride, dest, dest_stride, width, height,
+                             taps);
+      }
+    } else if (filter_index > 3) {  // 4 tap.
+      // Store taps in vget_low_s16(taps).
+      const int16x8_t taps =
+          vld1q_s16(kSubPixelFilters[filter_index][filter_id] + 2);
+      if (width == 4) {
+        FilterVertical<4, 4>(src, src_stride, dest, dest_stride, width, height,
+                             taps);
+      } else {
+        FilterVertical<8, 4>(src, src_stride, dest, dest_stride, width, height,
+                             taps);
+      }
+    } else {  // 2 tap.
+      // Store taps in vget_low_s16(taps).
+      const int16x8_t taps =
+          vld1q_s16(kSubPixelFilters[filter_index][filter_id] + 2);
+      if (width == 4) {
+        FilterVertical<4, 2>(src, src_stride, dest, dest_stride, width, height,
+                             taps);
+      } else {
+        FilterVertical<8, 2>(src, src_stride, dest, dest_stride, width, height,
+                             taps);
+      }
+    }
+  } else {
+    // TODO(johannkoenig): Determine if it is worth writing a 2xH
+    // implementation.
+    assert(width == 2);
+    const int max_pixel_value = 255;
+    int y = 0;
+    do {
+      for (int x = 0; x < 2; ++x) {
+        int sum = 0;
+        for (int k = 0; k < kSubPixelTaps; ++k) {
+          sum += kSubPixelFilters[filter_index][filter_id][k] *
+                 src[k * src_stride + x];
+        }
+        dest[x] = static_cast<uint8_t>(Clip3(
+            RightShiftWithRounding(sum, kFilterBits), 0, max_pixel_value));
+      }
+      src += src_stride;
+      dest += pred_stride;
+    } while (++y < height);
+  }
+}
+
+void ConvolveCompoundCopy_NEON(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+    const uint8_t /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
+    const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
+    const int width, const int height, void* prediction,
+    const ptrdiff_t pred_stride) {
+  const auto* src = static_cast<const uint8_t*>(reference);
+  const ptrdiff_t src_stride = reference_stride;
+  auto* dest = static_cast<uint16_t*>(prediction);
+  const int bitdepth = 8;
+  const int compound_round_offset =
+      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
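+  // i.e. 4096 + 2048 == 6144 for bitdepth 8.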
+  const uint16x8_t v_compound_round_offset = vdupq_n_u16(compound_round_offset);
+
+  if (width >= 16) {
+    int y = 0;
+    do {
+      int x = 0;
+      do {
+        const uint8x16_t v_src = vld1q_u8(&src[x]);
+        const uint16x8_t v_src_x16_lo = vshll_n_u8(vget_low_u8(v_src), 4);
+        const uint16x8_t v_src_x16_hi = vshll_n_u8(vget_high_u8(v_src), 4);
+        const uint16x8_t v_dest_lo =
+            vaddq_u16(v_src_x16_lo, v_compound_round_offset);
+        const uint16x8_t v_dest_hi =
+            vaddq_u16(v_src_x16_hi, v_compound_round_offset);
+        vst1q_u16(&dest[x], v_dest_lo);
+        x += 8;
+        vst1q_u16(&dest[x], v_dest_hi);
+        x += 8;
+      } while (x < width);
+      src += src_stride;
+      dest += pred_stride;
+    } while (++y < height);
+  } else if (width == 8) {
+    int y = 0;
+    do {
+      const uint8x8_t v_src = vld1_u8(&src[0]);
+      const uint16x8_t v_src_x16 = vshll_n_u8(v_src, 4);
+      vst1q_u16(&dest[0], vaddq_u16(v_src_x16, v_compound_round_offset));
+      src += src_stride;
+      dest += pred_stride;
+    } while (++y < height);
+  } else if (width == 4) {
+    const uint8x8_t zero = vdup_n_u8(0);
+    int y = 0;
+    do {
+      const uint8x8_t v_src = LoadLo4(&src[0], zero);
+      const uint16x8_t v_src_x16 = vshll_n_u8(v_src, 4);
+      const uint16x8_t v_dest = vaddq_u16(v_src_x16, v_compound_round_offset);
+      vst1_u16(&dest[0], vget_low_u16(v_dest));
+      src += src_stride;
+      dest += pred_stride;
+    } while (++y < height);
+  } else {  // width == 2
+    assert(width == 2);
+    int y = 0;
+    do {
+      dest[0] = (src[0] << 4) + compound_round_offset;
+      dest[1] = (src[1] << 4) + compound_round_offset;
+      src += src_stride;
+      dest += pred_stride;
+    } while (++y < height);
+  }
+}
+
+// Input 8 bits and output 16 bits.
+template <int min_width, int num_taps>
+void FilterCompoundVertical(const uint8_t* src, const ptrdiff_t src_stride,
+                            uint16_t* dst, const ptrdiff_t dst_stride,
+                            const int width, const int height,
+                            const int16x8_t taps) {
+  constexpr int next_row = num_taps - 1;
+  // |src| points to the outermost tap of the first value. When doing fewer than
+  // 8 taps it needs to be adjusted.
+  if (num_taps == 6) {
+    src += src_stride;
+  } else if (num_taps == 4) {
+    src += 2 * src_stride;
+  } else if (num_taps == 2) {
+    src += 3 * src_stride;
+  }
+
+  const uint16x8_t compound_round_offset = vdupq_n_u16(1 << 12);
+
+  int x = 0;
+  do {
+    int16x8_t srcs[8];
+    srcs[0] = ZeroExtend(vld1_u8(src + x));
+    if (num_taps >= 4) {
+      srcs[1] = ZeroExtend(vld1_u8(src + x + src_stride));
+      srcs[2] = ZeroExtend(vld1_u8(src + x + 2 * src_stride));
+      if (num_taps >= 6) {
+        srcs[3] = ZeroExtend(vld1_u8(src + x + 3 * src_stride));
+        srcs[4] = ZeroExtend(vld1_u8(src + x + 4 * src_stride));
+        if (num_taps == 8) {
+          srcs[5] = ZeroExtend(vld1_u8(src + x + 5 * src_stride));
+          srcs[6] = ZeroExtend(vld1_u8(src + x + 6 * src_stride));
+        }
+      }
+    }
+
+    int y = 0;
+    do {
+      srcs[next_row] =
+          ZeroExtend(vld1_u8(src + x + (y + next_row) * src_stride));
+
+      const uint16x8_t sums = SumTaps8To16<num_taps>(srcs, taps);
+      const uint16x8_t shifted = vrshrq_n_u16(sums, 3);
+      // In order to keep the sum in 16 bits we add an offset to the sum
+      // (1 << (bitdepth + kFilterBits - 1) == 1 << 14). This ensures that the
+      // results will never be negative.
+      // Normally ConvolveCompoundVertical would add |compound_round_offset| at
+      // the end. Instead we use that to compensate for the initial offset.
+      // (1 << (bitdepth + 4)) + (1 << (bitdepth + 3)) == (1 << 12) + (1 << 11)
+      // After taking into account the shift above:
+      // RightShiftWithRounding(LeftShift(sum, bits_shift), inter_round_bits[1])
+      // where bits_shift == kFilterBits - inter_round_bits[0] == 4
+      // and inter_round_bits[1] == 7
+      // and simplifying it to RightShiftWithRounding(sum, 3)
+      // we see that the initial offset of 1 << 14 >> 3 == 1 << 11 and
+      // |compound_round_offset| can be simplified to 1 << 12.
+      const uint16x8_t offset = vaddq_u16(shifted, compound_round_offset);
+
+      if (min_width == 4) {
+        vst1_u16(dst + x + y * dst_stride, vget_low_u16(offset));
+      } else {
+        vst1q_u16(dst + x + y * dst_stride, offset);
+      }
+
+      srcs[0] = srcs[1];
+      if (num_taps >= 4) {
+        srcs[1] = srcs[2];
+        srcs[2] = srcs[3];
+        if (num_taps >= 6) {
+          srcs[3] = srcs[4];
+          srcs[4] = srcs[5];
+          if (num_taps == 8) {
+            srcs[5] = srcs[6];
+            srcs[6] = srcs[7];
+          }
+        }
+      }
+    } while (++y < height);
+    x += 8;
+  } while (x < width);
+}
+
+void ConvolveCompoundVertical_NEON(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int /*horizontal_filter_index*/, const int vertical_filter_index,
+    const uint8_t /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
+    const int subpixel_y, const int /*step_x*/, const int /*step_y*/,
+    const int width, const int height, void* prediction,
+    const ptrdiff_t pred_stride) {
+  const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* src =
+      static_cast<const uint8_t*>(reference) - kVerticalOffset * src_stride;
+  auto* dest = static_cast<uint16_t*>(prediction);
+  const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
+  const int compound_round_offset = 1 << 12;  // Leave off + 1 << 11.
+
+  if (width >= 4) {
+    const int16x8_t taps = vld1q_s16(kSubPixelFilters[filter_index][filter_id]);
+
+    if (filter_index == 2) {  // 8 tap.
+      if (width == 4) {
+        FilterCompoundVertical<4, 8>(src, src_stride, dest, pred_stride, width,
+                                     height, taps);
+      } else {
+        FilterCompoundVertical<8, 8>(src, src_stride, dest, pred_stride, width,
+                                     height, taps);
+      }
+    } else if (filter_index < 2) {  // 6 tap.
+      if (width == 4) {
+        FilterCompoundVertical<4, 6>(src, src_stride, dest, pred_stride, width,
+                                     height, taps);
+      } else {
+        FilterCompoundVertical<8, 6>(src, src_stride, dest, pred_stride, width,
+                                     height, taps);
+      }
+    } else if (filter_index == 3) {  // 2 tap.
+      if (width == 4) {
+        FilterCompoundVertical<4, 2>(src, src_stride, dest, pred_stride, width,
+                                     height, taps);
+      } else {
+        FilterCompoundVertical<8, 2>(src, src_stride, dest, pred_stride, width,
+                                     height, taps);
+      }
+    } else if (filter_index > 3) {  // 4 tap.
+      if (width == 4) {
+        FilterCompoundVertical<4, 4>(src, src_stride, dest, pred_stride, width,
+                                     height, taps);
+      } else {
+        FilterCompoundVertical<8, 4>(src, src_stride, dest, pred_stride, width,
+                                     height, taps);
+      }
+    }
+  } else {
+    assert(width == 2);
+    for (int y = 0; y < height; ++y) {
+      for (int x = 0; x < 2; ++x) {
+        // Use an offset to avoid 32 bits.
+        int sum = 1 << 14;
+        for (int k = 0; k < kSubPixelTaps; ++k) {
+          sum += kSubPixelFilters[filter_index][filter_id][k] *
+                 src[k * src_stride + x];
+        }
+        // |compound_round_offset| has been modified to take into account the
+        // offset used above. The 1 << 11 term cancels out with 1 << 14 >> 3.
+        dest[x] = RightShiftWithRounding(sum, 3) + compound_round_offset;
+      }
+      src += src_stride;
+      dest += pred_stride;
+    }
+  }
+}
+
+template <int num_taps, int filter_index, bool negative_outside_taps = true>
+uint16x8_t SumCompoundHorizontalTaps(const uint8_t* const src,
+                                     uint8x8_t* v_tap) {
+  // Start with an offset to guarantee the sum is non negative.
+  uint16x8_t v_sum = vdupq_n_u16(1 << 14);
+  uint8x16_t v_src[8];
+  v_src[0] = vld1q_u8(&src[0]);
+  if (num_taps == 8) {
+    v_src[1] = vextq_u8(v_src[0], v_src[0], 1);
+    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
+    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
+    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
+    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
+    v_src[6] = vextq_u8(v_src[0], v_src[0], 6);
+    v_src[7] = vextq_u8(v_src[0], v_src[0], 7);
+
+    // tap signs : - + - + + - + -
+    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[0]), v_tap[0]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
+    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
+    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[7]), v_tap[7]);
+  } else if (num_taps == 6) {
+    v_src[1] = vextq_u8(v_src[0], v_src[0], 1);
+    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
+    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
+    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
+    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
+    v_src[6] = vextq_u8(v_src[0], v_src[0], 6);
+    if (filter_index == 0) {
+      // tap signs : + - + + - +
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
+      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
+    } else {
+      if (negative_outside_taps) {
+        // tap signs : - + + + + -
+        v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+        v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
+      } else {
+        // tap signs : + + + + + +
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
+      }
+    }
+  } else if (num_taps == 4) {
+    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
+    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
+    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
+    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
+    if (filter_index == 4) {
+      // tap signs : - + + -
+      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+    } else {
+      // tap signs : + + + +
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
+    }
+  } else {
+    assert(num_taps == 2);
+    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
+    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
+    // tap signs : + +
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
+    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
+  }
+
+  return v_sum;
+}
+
+template <int num_taps, int step, int filter_index,
+          bool negative_outside_taps = true>
+void ConvolveCompoundHorizontalBlock(const uint8_t* src, ptrdiff_t src_stride,
+                                     uint16_t* dest, ptrdiff_t pred_stride,
+                                     const int width, const int height,
+                                     uint8x8_t* v_tap,
+                                     int16x8_t v_inter_round_bits_0,
+                                     int16x8_t v_bits_shift,
+                                     uint16x8_t v_compound_round_offset) {
+  if (width > 4) {
+    int y = 0;
+    do {
+      int x = 0;
+      do {
+        uint16x8_t v_sum =
+            SumCompoundHorizontalTaps<num_taps, filter_index,
+                                      negative_outside_taps>(&src[x], v_tap);
+        v_sum = vrshlq_u16(v_sum, v_inter_round_bits_0);
+        v_sum = vshlq_u16(v_sum, v_bits_shift);
+        v_sum = vaddq_u16(v_sum, v_compound_round_offset);
+        vst1q_u16(&dest[x], v_sum);
+        x += step;
+      } while (x < width);
+      src += src_stride;
+      dest += pred_stride;
+    } while (++y < height);
+  } else {
+    int y = 0;
+    do {
+      uint16x8_t v_sum =
+          SumCompoundHorizontalTaps<num_taps, filter_index,
+                                    negative_outside_taps>(&src[0], v_tap);
+      v_sum = vrshlq_u16(v_sum, v_inter_round_bits_0);
+      v_sum = vshlq_u16(v_sum, v_bits_shift);
+      v_sum = vaddq_u16(v_sum, v_compound_round_offset);
+      vst1_u16(&dest[0], vget_low_u16(v_sum));
+      src += src_stride;
+      dest += pred_stride;
+    } while (++y < height);
+  }
+}
+
+void ConvolveCompoundHorizontal_NEON(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int horizontal_filter_index, const int /*vertical_filter_index*/,
+    const uint8_t inter_round_bits_vertical, const int subpixel_x,
+    const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
+    const int width, const int height, void* prediction,
+    const ptrdiff_t pred_stride) {
+  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
+  const ptrdiff_t src_stride = reference_stride;
+  auto* dest = static_cast<uint16_t*>(prediction);
+  const int filter_id = (subpixel_x >> 6) & kSubPixelMask;
+  const int bits_shift = kFilterBits - inter_round_bits_vertical;
+
+  const int compound_round_offset =
+      (1 << (kBitdepth8 + 4)) + (1 << (kBitdepth8 + 3));
+
+  if (width >= 4) {
+    // Duplicate the absolute value for each tap.  Negative taps are corrected
+    // by using the vmlsl_u8 instruction.  Positive taps use vmlal_u8.
+    uint8x8_t v_tap[kSubPixelTaps];
+    for (int k = 0; k < kSubPixelTaps; ++k) {
+      v_tap[k] = vreinterpret_u8_s8(
+          vabs_s8(vdup_n_s8(kSubPixelFilters[filter_index][filter_id][k])));
+    }
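+    // A negative tap is therefore stored here as its magnitude and applied
+    // with vmlsl_u8 rather than vmlal_u8; the sign pattern for each
+    // |filter_index| is hard-coded in SumCompoundHorizontalTaps(), and the
+    // 1 << 14 bias added there keeps the unsigned running sum non-negative.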
+
+    const int16x8_t v_inter_round_bits_0 =
+        vdupq_n_s16(-kInterRoundBitsHorizontal);
+    const int16x8_t v_bits_shift = vdupq_n_s16(bits_shift);
+
+    const uint16x8_t v_compound_round_offset =
+        vdupq_n_u16(compound_round_offset - (1 << (kBitdepth8 + 3)));
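+    // With kBitdepth8 == 8, |compound_round_offset| is
+    // (1 << 12) + (1 << 11) = 6144 and the vector constant drops
+    // 1 << 11 = 2048 of it, presumably to compensate for the 1 << 14 bias
+    // added in SumCompoundHorizontalTaps() once it has passed through the
+    // shifts in ConvolveCompoundHorizontalBlock().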
+
+    if (filter_index == 2) {  // 8 tap.
+      ConvolveCompoundHorizontalBlock<8, 8, 2>(
+          src, src_stride, dest, pred_stride, width, height, v_tap,
+          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
+    } else if (filter_index == 1) {  // 6 tap.
+      // Check if outside taps are positive.
+      if ((filter_id == 1) | (filter_id == 15)) {
+        ConvolveCompoundHorizontalBlock<6, 8, 1, false>(
+            src, src_stride, dest, pred_stride, width, height, v_tap,
+            v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
+      } else {
+        ConvolveCompoundHorizontalBlock<6, 8, 1>(
+            src, src_stride, dest, pred_stride, width, height, v_tap,
+            v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
+      }
+    } else if (filter_index == 0) {  // 6 tap.
+      ConvolveCompoundHorizontalBlock<6, 8, 0>(
+          src, src_stride, dest, pred_stride, width, height, v_tap,
+          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
+    } else if (filter_index == 4) {  // 4 tap.
+      ConvolveCompoundHorizontalBlock<4, 8, 4>(
+          src, src_stride, dest, pred_stride, width, height, v_tap,
+          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
+    } else if (filter_index == 5) {  // 4 tap.
+      ConvolveCompoundHorizontalBlock<4, 8, 5>(
+          src, src_stride, dest, pred_stride, width, height, v_tap,
+          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
+    } else {  // 2 tap.
+      ConvolveCompoundHorizontalBlock<2, 8, 3>(
+          src, src_stride, dest, pred_stride, width, height, v_tap,
+          v_inter_round_bits_0, v_bits_shift, v_compound_round_offset);
+    }
+  } else {
+    // 2xH
+    int y = 0;
+    do {
+      for (int x = 0; x < 2; ++x) {
+        int sum = 0;
+        for (int k = 0; k < kSubPixelTaps; ++k) {
+          sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
+        }
+        sum = RightShiftWithRounding(sum, kInterRoundBitsHorizontal)
+              << bits_shift;
+        dest[x] = sum + compound_round_offset;
+      }
+      src += src_stride;
+      dest += pred_stride;
+    } while (++y < height);
+  }
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON;
+  dsp->convolve[0][0][1][0] = ConvolveVertical_NEON;
+  // TODO(b/139707209): reenable after segfault on android is fixed.
+  // dsp->convolve[0][0][1][1] = Convolve2D_NEON;
+  static_cast<void>(Convolve2D_NEON);
+
+  dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON;
+  dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON;
+  dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON;
+
+  // dsp->convolve_scale[1] = ConvolveCompoundScale2D_NEON;
+  static_cast<void>(ConvolveCompoundScale2D_NEON);
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void ConvolveInit_NEON() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void ConvolveInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/convolve_neon.h b/libgav1/src/dsp/arm/convolve_neon.h
new file mode 100644
index 0000000..6b5873c
--- /dev/null
+++ b/libgav1/src/dsp/arm/convolve_neon.h
@@ -0,0 +1,29 @@
+#ifndef LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::convolve. This function is not thread-safe.
+void ConvolveInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_DSP_NEON
+// TODO(b/139707209): reenable after segfault on android is fixed.
+// #define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_DSP_NEON
+
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundCopy LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_DSP_NEON
+
+// #define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_DSP_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
diff --git a/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc b/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
new file mode 100644
index 0000000..d0a95a2
--- /dev/null
+++ b/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
@@ -0,0 +1,237 @@
+#include "src/dsp/distance_weighted_blend.h"
+#include "src/dsp/dsp.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+constexpr int kBitdepth8 = 8;
+constexpr int kInterPostRoundBit = 4;
+
+const int16x8_t kCompoundRoundOffset =
+    vdupq_n_s16((1 << kBitdepth8) + (1 << (kBitdepth8 - 1)));
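+// With kBitdepth8 == 8 this is 256 + 128 = 384.  Assuming the two distance
+// weights always sum to 16 and each input prediction carries a compound
+// offset of (1 << 12) + (1 << 11), that offset shrinks to exactly 384 after
+// the >> (kInterPostRoundBit + 4) in ComputeWeightedAverage8(), so
+// subtracting this constant recenters the result before the saturating
+// narrow to uint8_t.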
+
+inline int16x8_t ComputeWeightedAverage8(const uint16x8_t pred0,
+                                         const uint16x8_t pred1,
+                                         const uint16x4_t weights[2]) {
+  const uint32x4_t wpred0_lo = vmull_u16(weights[0], vget_low_u16(pred0));
+  const uint32x4_t wpred0_hi = vmull_u16(weights[0], vget_high_u16(pred0));
+  const uint32x4_t blended_lo =
+      vmlal_u16(wpred0_lo, weights[1], vget_low_u16(pred1));
+  const uint32x4_t blended_hi =
+      vmlal_u16(wpred0_hi, weights[1], vget_high_u16(pred1));
+
+  const uint16x4_t result_lo =
+      vqrshrn_n_u32(blended_lo, kInterPostRoundBit + 4);
+  const uint16x4_t result_hi =
+      vqrshrn_n_u32(blended_hi, kInterPostRoundBit + 4);
+  return vsubq_s16(vreinterpretq_s16_u16(vcombine_u16(result_lo, result_hi)),
+                   kCompoundRoundOffset);
+}
+
+template <int height>
+inline void DistanceWeightedBlend4xH_NEON(const uint16_t* prediction_0,
+                                          const ptrdiff_t prediction_stride_0,
+                                          const uint16_t* prediction_1,
+                                          const ptrdiff_t prediction_stride_1,
+                                          const uint16x4_t weights[2],
+                                          void* const dest,
+                                          const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const uint16_t* pred_0 = prediction_0;
+  const uint16_t* pred_1 = prediction_1;
+
+  for (int y = 0; y < height; y += 4) {
+    const uint16x4_t src_00 = vld1_u16(pred_0);
+    const uint16x4_t src_10 = vld1_u16(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const uint16x4_t src_01 = vld1_u16(pred_0);
+    const uint16x4_t src_11 = vld1_u16(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const int16x8_t res01 = ComputeWeightedAverage8(
+        vcombine_u16(src_00, src_01), vcombine_u16(src_10, src_11), weights);
+
+    const uint16x4_t src_02 = vld1_u16(pred_0);
+    const uint16x4_t src_12 = vld1_u16(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const uint16x4_t src_03 = vld1_u16(pred_0);
+    const uint16x4_t src_13 = vld1_u16(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const int16x8_t res23 = ComputeWeightedAverage8(
+        vcombine_u16(src_02, src_03), vcombine_u16(src_12, src_13), weights);
+
+    const uint8x8_t result_01 = vqmovun_s16(res01);
+    const uint8x8_t result_23 = vqmovun_s16(res23);
+    StoreLo4(dst, result_01);
+    dst += dest_stride;
+    StoreHi4(dst, result_01);
+    dst += dest_stride;
+    StoreLo4(dst, result_23);
+    dst += dest_stride;
+    StoreHi4(dst, result_23);
+    dst += dest_stride;
+  }
+}
+
+template <int height>
+inline void DistanceWeightedBlend8xH_NEON(const uint16_t* prediction_0,
+                                          const ptrdiff_t prediction_stride_0,
+                                          const uint16_t* prediction_1,
+                                          const ptrdiff_t prediction_stride_1,
+                                          const uint16x4_t weights[2],
+                                          void* const dest,
+                                          const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const uint16_t* pred_0 = prediction_0;
+  const uint16_t* pred_1 = prediction_1;
+
+  for (int y = 0; y < height; y += 2) {
+    const uint16x8_t src_00 = vld1q_u16(pred_0);
+    const uint16x8_t src_10 = vld1q_u16(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const int16x8_t res0 = ComputeWeightedAverage8(src_00, src_10, weights);
+
+    const uint16x8_t src_01 = vld1q_u16(pred_0);
+    const uint16x8_t src_11 = vld1q_u16(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const int16x8_t res1 = ComputeWeightedAverage8(src_01, src_11, weights);
+
+    const uint8x8_t result0 = vqmovun_s16(res0);
+    const uint8x8_t result1 = vqmovun_s16(res1);
+    vst1_u8(dst, result0);
+    dst += dest_stride;
+    vst1_u8(dst, result1);
+    dst += dest_stride;
+  }
+}
+
+inline void DistanceWeightedBlendLarge_NEON(
+    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
+    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
+    const uint16x4_t weights[2], const int width, const int height,
+    void* const dest, const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const uint16_t* pred_0 = prediction_0;
+  const uint16_t* pred_1 = prediction_1;
+
+  int y = height;
+  do {
+    int x = 0;
+    do {
+      const uint16x8_t src0_lo = vld1q_u16(pred_0 + x);
+      const uint16x8_t src1_lo = vld1q_u16(pred_1 + x);
+      const int16x8_t res_lo =
+          ComputeWeightedAverage8(src0_lo, src1_lo, weights);
+
+      const uint16x8_t src0_hi = vld1q_u16(pred_0 + x + 8);
+      const uint16x8_t src1_hi = vld1q_u16(pred_1 + x + 8);
+      const int16x8_t res_hi =
+          ComputeWeightedAverage8(src0_hi, src1_hi, weights);
+
+      const uint8x16_t result =
+          vcombine_u8(vqmovun_s16(res_lo), vqmovun_s16(res_hi));
+      vst1q_u8(dst + x, result);
+      x += 16;
+    } while (x < width);
+    dst += dest_stride;
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+  } while (--y != 0);
+}
+
+inline void DistanceWeightedBlend_NEON(
+    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
+    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
+    const uint8_t weight_0, const uint8_t weight_1, const int width,
+    const int height, void* const dest, const ptrdiff_t dest_stride) {
+  uint16x4_t weights[2] = {vdup_n_u16(weight_0), vdup_n_u16(weight_1)};
+  if (width == 4) {
+    if (height == 4) {
+      DistanceWeightedBlend4xH_NEON<4>(prediction_0, prediction_stride_0,
+                                       prediction_1, prediction_stride_1,
+                                       weights, dest, dest_stride);
+    } else if (height == 8) {
+      DistanceWeightedBlend4xH_NEON<8>(prediction_0, prediction_stride_0,
+                                       prediction_1, prediction_stride_1,
+                                       weights, dest, dest_stride);
+    } else {
+      assert(height == 16);
+      DistanceWeightedBlend4xH_NEON<16>(prediction_0, prediction_stride_0,
+                                        prediction_1, prediction_stride_1,
+                                        weights, dest, dest_stride);
+    }
+    return;
+  }
+
+  if (width == 8) {
+    switch (height) {
+      case 4:
+        DistanceWeightedBlend8xH_NEON<4>(prediction_0, prediction_stride_0,
+                                         prediction_1, prediction_stride_1,
+                                         weights, dest, dest_stride);
+        return;
+      case 8:
+        DistanceWeightedBlend8xH_NEON<8>(prediction_0, prediction_stride_0,
+                                         prediction_1, prediction_stride_1,
+                                         weights, dest, dest_stride);
+        return;
+      case 16:
+        DistanceWeightedBlend8xH_NEON<16>(prediction_0, prediction_stride_0,
+                                          prediction_1, prediction_stride_1,
+                                          weights, dest, dest_stride);
+        return;
+      default:
+        assert(height == 32);
+        DistanceWeightedBlend8xH_NEON<32>(prediction_0, prediction_stride_0,
+                                          prediction_1, prediction_stride_1,
+                                          weights, dest, dest_stride);
+
+        return;
+    }
+  }
+
+  DistanceWeightedBlendLarge_NEON(prediction_0, prediction_stride_0,
+                                  prediction_1, prediction_stride_1, weights,
+                                  width, height, dest, dest_stride);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
+}
+
+}  // namespace
+
+void DistanceWeightedBlendInit_NEON() { Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void DistanceWeightedBlendInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/distance_weighted_blend_neon.h b/libgav1/src/dsp/arm/distance_weighted_blend_neon.h
new file mode 100644
index 0000000..846b7e1
--- /dev/null
+++ b/libgav1/src/dsp/arm/distance_weighted_blend_neon.h
@@ -0,0 +1,23 @@
+#ifndef LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::distance_weighted_blend. This function is not thread-safe.
+void DistanceWeightedBlendInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+// If NEON is enabled, signal that the NEON implementation should be used
+// instead of the plain C implementation.
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_DSP_NEON
+
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_
diff --git a/libgav1/src/dsp/arm/intra_edge_neon.cc b/libgav1/src/dsp/arm/intra_edge_neon.cc
new file mode 100644
index 0000000..3cc70b3
--- /dev/null
+++ b/libgav1/src/dsp/arm/intra_edge_neon.cc
@@ -0,0 +1,285 @@
+#include "src/dsp/dsp.h"
+#include "src/dsp/intra_edge.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/utils/common.h"  // RightShiftWithRounding()
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+// Simplified version of intra_edge.cc:kKernels[][]. Only |strength| 1 and 2 are
+// required.
+constexpr int kKernelsNEON[3][2] = {{4, 8}, {5, 6}};
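+// These are the outer and center weights of the 3 tap smoothing kernels:
+// strength 1 is effectively {4, 8, 4} / 16 and strength 2 is {5, 6, 5} / 16.
+// The outer weight is stored once because it multiplies both neighbors, and
+// the division by 16 matches the rounding shift by 4 used below.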
+
+void IntraEdgeFilter_NEON(void* buffer, const int size, const int strength) {
+  assert(strength == 1 || strength == 2 || strength == 3);
+  const int kernel_index = strength - 1;
+  auto* const dst_buffer = static_cast<uint8_t*>(buffer);
+
+  // The first element is not written out (but it is input) so the number of
+  // elements written is |size| - 1.
+  if (size == 1) return;
+
+  // |strength| 1 and 2 use a 3 tap filter.
+  if (strength < 3) {
+    // The last value requires extending the buffer (duplicating
+    // |dst_buffer[size - 1]). Calculate it here to avoid extra processing in
+    // NEON.
+    const uint8_t last_val = RightShiftWithRounding(
+        kKernelsNEON[kernel_index][0] * dst_buffer[size - 2] +
+            kKernelsNEON[kernel_index][1] * dst_buffer[size - 1] +
+            kKernelsNEON[kernel_index][0] * dst_buffer[size - 1],
+        4);
+
+    const uint8x8_t krn1 = vdup_n_u8(kKernelsNEON[kernel_index][1]);
+
+    // The first value we need gets overwritten by the output from the
+    // previous iteration.
+    uint8x16_t src_0 = vld1q_u8(dst_buffer);
+    int i = 1;
+
+    // Process blocks until there are less than 16 values remaining.
+    for (; i < size - 15; i += 16) {
+      // Loading these at the end of the block with |src_0| will read past the
+      // end of |top_row_data[160]|, the source of |buffer|.
+      const uint8x16_t src_1 = vld1q_u8(dst_buffer + i);
+      const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1);
+      uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2));
+      sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]);
+      sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1);
+      uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2));
+      sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]);
+      sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1);
+
+      const uint8x16_t result =
+          vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
+
+      // Load the next block of input before overwriting it. This loads an
+      // extra 15 values past |size| on the trailing iteration.
+      src_0 = vld1q_u8(dst_buffer + i + 15);
+
+      vst1q_u8(dst_buffer + i, result);
+    }
+
+    // The last output value |last_val| was already calculated so if
+    // |remainder| == 1 then we don't have to do anything.
+    const int remainder = (size - 1) & 0xf;
+    if (remainder > 1) {
+      uint8_t temp[16];
+      const uint8x16_t src_1 = vld1q_u8(dst_buffer + i);
+      const uint8x16_t src_2 = vld1q_u8(dst_buffer + i + 1);
+
+      uint16x8_t sum_lo = vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_2));
+      sum_lo = vmulq_n_u16(sum_lo, kKernelsNEON[kernel_index][0]);
+      sum_lo = vmlal_u8(sum_lo, vget_low_u8(src_1), krn1);
+      uint16x8_t sum_hi = vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_2));
+      sum_hi = vmulq_n_u16(sum_hi, kKernelsNEON[kernel_index][0]);
+      sum_hi = vmlal_u8(sum_hi, vget_high_u8(src_1), krn1);
+
+      const uint8x16_t result =
+          vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
+
+      vst1q_u8(temp, result);
+      memcpy(dst_buffer + i, temp, remainder);
+    }
+
+    dst_buffer[size - 1] = last_val;
+    return;
+  }
+
+  assert(strength == 3);
+  // 5 tap filter. The first element requires duplicating |buffer[0]| and the
+  // last two elements require duplicating |buffer[size - 1]|.
+  uint8_t special_vals[3];
+  special_vals[0] = RightShiftWithRounding(
+      (dst_buffer[0] << 1) + (dst_buffer[0] << 2) + (dst_buffer[1] << 2) +
+          (dst_buffer[2] << 2) + (dst_buffer[3] << 1),
+      4);
+  // Clamp index for very small |size| values.
+  const int first_index_min = std::max(size - 4, 0);
+  const int second_index_min = std::max(size - 3, 0);
+  const int third_index_min = std::max(size - 2, 0);
+  special_vals[1] = RightShiftWithRounding(
+      (dst_buffer[first_index_min] << 1) + (dst_buffer[second_index_min] << 2) +
+          (dst_buffer[third_index_min] << 2) + (dst_buffer[size - 1] << 2) +
+          (dst_buffer[size - 1] << 1),
+      4);
+  special_vals[2] = RightShiftWithRounding(
+      (dst_buffer[second_index_min] << 1) + (dst_buffer[third_index_min] << 2) +
+          // (x << 2) + (x << 2) == x << 3
+          (dst_buffer[size - 1] << 3) + (dst_buffer[size - 1] << 1),
+      4);
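+  // The strength 3 kernel is expressed with shifts: the two outer taps are
+  // weighted by 2 (<< 1) and the three inner taps by 4 (<< 2), i.e.
+  // {2, 4, 4, 4, 2} / 16, again matching the rounding shift by 4.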
+
+  // The first two values we need get overwritten by the output from the
+  // previous iteration.
+  uint8x16_t src_0 = vld1q_u8(dst_buffer - 1);
+  uint8x16_t src_1 = vld1q_u8(dst_buffer);
+  int i = 1;
+
+  for (; i < size - 15; i += 16) {
+    // Loading these at the end of the block with |src_[01]| will read past
+    // the end of |top_row_data[160]|, the source of |buffer|.
+    const uint8x16_t src_2 = vld1q_u8(dst_buffer + i);
+    const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1);
+    const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2);
+
+    uint16x8_t sum_lo =
+        vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1);
+    const uint16x8_t sum_123_lo = vaddw_u8(
+        vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3));
+    sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2));
+
+    uint16x8_t sum_hi =
+        vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1);
+    const uint16x8_t sum_123_hi =
+        vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)),
+                 vget_high_u8(src_3));
+    sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2));
+
+    const uint8x16_t result =
+        vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
+
+    src_0 = vld1q_u8(dst_buffer + i + 14);
+    src_1 = vld1q_u8(dst_buffer + i + 15);
+
+    vst1q_u8(dst_buffer + i, result);
+  }
+
+  const int remainder = (size - 1) & 0xf;
+  // As in the 3 tap case, but if only two values remain they have already
+  // been calculated.
+  if (remainder > 2) {
+    uint8_t temp[16];
+    const uint8x16_t src_2 = vld1q_u8(dst_buffer + i);
+    const uint8x16_t src_3 = vld1q_u8(dst_buffer + i + 1);
+    const uint8x16_t src_4 = vld1q_u8(dst_buffer + i + 2);
+
+    uint16x8_t sum_lo =
+        vshlq_n_u16(vaddl_u8(vget_low_u8(src_0), vget_low_u8(src_4)), 1);
+    const uint16x8_t sum_123_lo = vaddw_u8(
+        vaddl_u8(vget_low_u8(src_1), vget_low_u8(src_2)), vget_low_u8(src_3));
+    sum_lo = vaddq_u16(sum_lo, vshlq_n_u16(sum_123_lo, 2));
+
+    uint16x8_t sum_hi =
+        vshlq_n_u16(vaddl_u8(vget_high_u8(src_0), vget_high_u8(src_4)), 1);
+    const uint16x8_t sum_123_hi =
+        vaddw_u8(vaddl_u8(vget_high_u8(src_1), vget_high_u8(src_2)),
+                 vget_high_u8(src_3));
+    sum_hi = vaddq_u16(sum_hi, vshlq_n_u16(sum_123_hi, 2));
+
+    const uint8x16_t result =
+        vcombine_u8(vrshrn_n_u16(sum_lo, 4), vrshrn_n_u16(sum_hi, 4));
+
+    vst1q_u8(temp, result);
+    memcpy(dst_buffer + i, temp, remainder);
+  }
+
+  dst_buffer[1] = special_vals[0];
+  // Avoid overwriting |dst_buffer[0]|.
+  if (size > 2) dst_buffer[size - 2] = special_vals[1];
+  dst_buffer[size - 1] = special_vals[2];
+}
+
+// (-|src0| + |src1| * 9 + |src2| * 9 - |src3|) >> 4
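+// As a sanity check on the kernel below: with four equal samples p the
+// numerator is -p + 9p + 9p - p = 16p, so the rounding shift by 4 returns p
+// unchanged, and vqrshrun_n_s16() clamps any overshoot on sharp edges into
+// [0, 255].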
+uint8x8_t Upsample(const uint8x8_t src0, const uint8x8_t src1,
+                   const uint8x8_t src2, const uint8x8_t src3) {
+  const uint16x8_t middle = vmulq_n_u16(vaddl_u8(src1, src2), 9);
+  const uint16x8_t ends = vaddl_u8(src0, src3);
+  const int16x8_t sum =
+      vsubq_s16(vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(ends));
+  return vqrshrun_n_s16(sum, 4);
+}
+
+void IntraEdgeUpsampler_NEON(void* buffer, const int size) {
+  assert(size % 4 == 0 && size <= 16);
+  auto* const pixel_buffer = static_cast<uint8_t*>(buffer);
+  // This is OK because we don't read this value for |size| 4 or 8 but if we
+  // write |pixel_buffer[size]| and then vld() it, that seems to introduce
+  // some latency.
+  pixel_buffer[-2] = pixel_buffer[-1];
+  if (size == 4) {
+    // This uses one load and two vtbl() which is better than 4x Load{Lo,Hi}4().
+    const uint8x8_t src = vld1_u8(pixel_buffer - 1);
+    // The outside values are negated so put those in the same vector.
+    const uint8x8_t src03 = vtbl1_u8(src, vcreate_u8(0x0404030202010000));
+    // Reverse |src1| and |src2| so we can use |src2| for the interleave at the
+    // end.
+    const uint8x8_t src21 = vtbl1_u8(src, vcreate_u8(0x0302010004030201));
+
+    const uint16x8_t middle = vmull_u8(src21, vdup_n_u8(9));
+    const int16x8_t half_sum = vsubq_s16(
+        vreinterpretq_s16_u16(middle), vreinterpretq_s16_u16(vmovl_u8(src03)));
+    const int16x4_t sum =
+        vadd_s16(vget_low_s16(half_sum), vget_high_s16(half_sum));
+    const uint8x8_t result = vqrshrun_n_s16(vcombine_s16(sum, sum), 4);
+
+    vst1_u8(pixel_buffer - 1, InterleaveLow8(result, src21));
+    return;
+  } else if (size == 8) {
+    // Likewise, one load + multiple vtbls seems preferred to multiple loads.
+    const uint8x16_t src = vld1q_u8(pixel_buffer - 1);
+    const uint8x8_t src0 = VQTbl1U8(src, vcreate_u8(0x0605040302010000));
+    const uint8x8_t src1 = vget_low_u8(src);
+    const uint8x8_t src2 = VQTbl1U8(src, vcreate_u8(0x0807060504030201));
+    const uint8x8_t src3 = VQTbl1U8(src, vcreate_u8(0x0808070605040302));
+
+    const uint8x8x2_t output = {Upsample(src0, src1, src2, src3), src2};
+    vst2_u8(pixel_buffer - 1, output);
+    return;
+  }
+  assert(size == 12 || size == 16);
+  // Extend the input borders to avoid branching later.
+  pixel_buffer[size] = pixel_buffer[size - 1];
+  const uint8x16_t src0 = vld1q_u8(pixel_buffer - 2);
+  const uint8x16_t src1 = vld1q_u8(pixel_buffer - 1);
+  const uint8x16_t src2 = vld1q_u8(pixel_buffer);
+  const uint8x16_t src3 = vld1q_u8(pixel_buffer + 1);
+
+  const uint8x8_t result_lo = Upsample(vget_low_u8(src0), vget_low_u8(src1),
+                                       vget_low_u8(src2), vget_low_u8(src3));
+
+  const uint8x8x2_t output_lo = {result_lo, vget_low_u8(src2)};
+  vst2_u8(pixel_buffer - 1, output_lo);
+
+  const uint8x8_t result_hi = Upsample(vget_high_u8(src0), vget_high_u8(src1),
+                                       vget_high_u8(src2), vget_high_u8(src3));
+
+  if (size == 12) {
+    vst1_u8(pixel_buffer + 15, InterleaveLow8(result_hi, vget_high_u8(src2)));
+  } else /* size == 16 */ {
+    const uint8x8x2_t output_hi = {result_hi, vget_high_u8(src2)};
+    vst2_u8(pixel_buffer + 15, output_hi);
+  }
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  dsp->intra_edge_filter = IntraEdgeFilter_NEON;
+  dsp->intra_edge_upsampler = IntraEdgeUpsampler_NEON;
+}
+
+}  // namespace
+
+void IntraEdgeInit_NEON() { Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void IntraEdgeInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/intra_edge_neon.h b/libgav1/src/dsp/arm/intra_edge_neon.h
new file mode 100644
index 0000000..5da3c24
--- /dev/null
+++ b/libgav1/src/dsp/arm/intra_edge_neon.h
@@ -0,0 +1,23 @@
+#ifndef LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::intra_edge_filter and Dsp::intra_edge_upsampler. This
+// function is not thread-safe.
+void IntraEdgeInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_IntraEdgeFilter LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_IntraEdgeUpsampler LIBGAV1_DSP_NEON
+
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_
diff --git a/libgav1/src/dsp/arm/intrapred_cfl_neon.cc b/libgav1/src/dsp/arm/intrapred_cfl_neon.cc
index cad9f7b..7947fef 100644
--- a/libgav1/src/dsp/arm/intrapred_cfl_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_cfl_neon.cc
@@ -1,5 +1,5 @@
-#include "src/dsp/arm/intrapred_neon.h"
 #include "src/dsp/dsp.h"
+#include "src/dsp/intrapred.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -10,12 +10,258 @@
 #include <cstdint>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/utils/common.h"
 
 namespace libgav1 {
 namespace dsp {
 namespace low_bitdepth {
 namespace {
 
+uint8x16_t Set2ValuesQ(const uint8_t* a) {
+  uint16_t combined_values = a[0] | a[1] << 8;
+  return vreinterpretq_u8_u16(vdupq_n_u16(combined_values));
+}
+
+int SumVector(uint32x2_t a) {
+#if defined(__aarch64__)
+  return vaddv_u32(a);
+#else
+  const uint64x1_t b = vpaddl_u32(a);
+  return vget_lane_u64(b, 0);
+#endif  // defined(__aarch64__)
+}
+
+int SumVector(uint32x4_t a) {
+#if defined(__aarch64__)
+  return vaddvq_u32(a);
+#else
+  const uint64x2_t b = vpaddlq_u32(a);
+  const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b));
+  return vget_lane_u64(c, 0);
+#endif  // defined(__aarch64__)
+}
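+// On AArch64 the horizontal reduction above is a single vaddv instruction;
+// the 32-bit fallback pairwise-adds the lanes and reads out the resulting
+// 64 bit value instead.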
+
+// Divide by the number of elements.
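+// For example, a 16x8 block uses FloorLog2(16) + FloorLog2(8) = 4 + 3 = 7,
+// i.e. a rounded division of the running sum by 128; this relies on the
+// block dimensions being powers of two.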
+int Average(const int sum, const int width, const int height) {
+  return RightShiftWithRounding(sum, FloorLog2(width) + FloorLog2(height));
+}
+
+// Subtract |val| from every element in |a|.
+void BlockSubtract(const int val,
+                   int16_t a[kCflLumaBufferStride][kCflLumaBufferStride],
+                   const int width, const int height) {
+  const int16x8_t val_v = vdupq_n_s16(val);
+
+  for (int y = 0; y < height; ++y) {
+    if (width == 4) {
+      const int16x4_t b = vld1_s16(a[y]);
+      vst1_s16(a[y], vsub_s16(b, vget_low_s16(val_v)));
+    } else if (width == 8) {
+      const int16x8_t b = vld1q_s16(a[y]);
+      vst1q_s16(a[y], vsubq_s16(b, val_v));
+    } else if (width == 16) {
+      const int16x8_t b = vld1q_s16(a[y]);
+      const int16x8_t c = vld1q_s16(a[y] + 8);
+      vst1q_s16(a[y], vsubq_s16(b, val_v));
+      vst1q_s16(a[y] + 8, vsubq_s16(c, val_v));
+    } else /* block_width == 32 */ {
+      const int16x8_t b = vld1q_s16(a[y]);
+      const int16x8_t c = vld1q_s16(a[y] + 8);
+      const int16x8_t d = vld1q_s16(a[y] + 16);
+      const int16x8_t e = vld1q_s16(a[y] + 24);
+      vst1q_s16(a[y], vsubq_s16(b, val_v));
+      vst1q_s16(a[y] + 8, vsubq_s16(c, val_v));
+      vst1q_s16(a[y] + 16, vsubq_s16(d, val_v));
+      vst1q_s16(a[y] + 24, vsubq_s16(e, val_v));
+    }
+  }
+}
+
+template <int block_width, int block_height>
+void CflSubsampler420_NEON(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int max_luma_width, const int max_luma_height,
+    const void* const source, const ptrdiff_t stride) {
+  const auto* src = static_cast<const uint8_t*>(source);
+  int sum;
+  if (block_width == 4) {
+    assert(max_luma_width >= 8);
+    uint32x2_t running_sum = vdup_n_u32(0);
+
+    for (int y = 0; y < block_height; ++y) {
+      const uint8x8_t row0 = vld1_u8(src);
+      const uint8x8_t row1 = vld1_u8(src + stride);
+
+      uint16x4_t sum_row = vpadal_u8(vpaddl_u8(row0), row1);
+      sum_row = vshl_n_u16(sum_row, 1);
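+      // The 2x2 luma sum is doubled here and the 4:4:4 path shifts single
+      // pixels left by 3, so both subsamplers appear to store the luma
+      // samples scaled by 8 before the average is subtracted.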
+      running_sum = vpadal_u16(running_sum, sum_row);
+      vst1_s16(luma[y], vreinterpret_s16_u16(sum_row));
+
+      if (y << 1 < max_luma_height - 2) {
+        // Once this threshold is reached the loop could be simplified.
+        src += stride << 1;
+      }
+    }
+
+    sum = SumVector(running_sum);
+  } else if (block_width == 8) {
+    const uint8x16_t x_index = {0, 0, 2,  2,  4,  4,  6,  6,
+                                8, 8, 10, 10, 12, 12, 14, 14};
+    const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 2);
+    const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index);
+
+    uint32x4_t running_sum = vdupq_n_u32(0);
+
+    for (int y = 0; y < block_height; ++y) {
+      const uint8x16_t x_max0 = Set2ValuesQ(src + max_luma_width - 2);
+      const uint8x16_t x_max1 = Set2ValuesQ(src + max_luma_width - 2 + stride);
+
+      uint8x16_t row0 = vld1q_u8(src);
+      row0 = vbslq_u8(x_mask, row0, x_max0);
+      uint8x16_t row1 = vld1q_u8(src + stride);
+      row1 = vbslq_u8(x_mask, row1, x_max1);
+
+      uint16x8_t sum_row = vpadalq_u8(vpaddlq_u8(row0), row1);
+      sum_row = vshlq_n_u16(sum_row, 1);
+      running_sum = vpadalq_u16(running_sum, sum_row);
+      vst1q_s16(luma[y], vreinterpretq_s16_u16(sum_row));
+
+      if (y << 1 < max_luma_height - 2) {
+        src += stride << 1;
+      }
+    }
+
+    sum = SumVector(running_sum);
+  } else /* block_width >= 16 */ {
+    const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 2);
+    uint32x4_t running_sum = vdupq_n_u32(0);
+
+    for (int y = 0; y < block_height; ++y) {
+      uint8x16_t x_index = {0,  2,  4,  6,  8,  10, 12, 14,
+                            16, 18, 20, 22, 24, 26, 28, 30};
+      const uint8x16_t x_max00 = vdupq_n_u8(src[max_luma_width - 2]);
+      const uint8x16_t x_max01 = vdupq_n_u8(src[max_luma_width - 2 + 1]);
+      const uint8x16_t x_max10 = vdupq_n_u8(src[stride + max_luma_width - 2]);
+      const uint8x16_t x_max11 =
+          vdupq_n_u8(src[stride + max_luma_width - 2 + 1]);
+      for (int x = 0; x < block_width; x += 16) {
+        const ptrdiff_t src_x_offset = x << 1;
+        const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index);
+        const uint8x16x2_t row0 = vld2q_u8(src + src_x_offset);
+        const uint8x16x2_t row1 = vld2q_u8(src + src_x_offset + stride);
+        const uint8x16_t row_masked_00 = vbslq_u8(x_mask, row0.val[0], x_max00);
+        const uint8x16_t row_masked_01 = vbslq_u8(x_mask, row0.val[1], x_max01);
+        const uint8x16_t row_masked_10 = vbslq_u8(x_mask, row1.val[0], x_max10);
+        const uint8x16_t row_masked_11 = vbslq_u8(x_mask, row1.val[1], x_max11);
+
+        uint16x8_t sum_row_lo =
+            vaddl_u8(vget_low_u8(row_masked_00), vget_low_u8(row_masked_01));
+        sum_row_lo = vaddw_u8(sum_row_lo, vget_low_u8(row_masked_10));
+        sum_row_lo = vaddw_u8(sum_row_lo, vget_low_u8(row_masked_11));
+        sum_row_lo = vshlq_n_u16(sum_row_lo, 1);
+        running_sum = vpadalq_u16(running_sum, sum_row_lo);
+        vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(sum_row_lo));
+
+        uint16x8_t sum_row_hi =
+            vaddl_u8(vget_high_u8(row_masked_00), vget_high_u8(row_masked_01));
+        sum_row_hi = vaddw_u8(sum_row_hi, vget_high_u8(row_masked_10));
+        sum_row_hi = vaddw_u8(sum_row_hi, vget_high_u8(row_masked_11));
+        sum_row_hi = vshlq_n_u16(sum_row_hi, 1);
+        running_sum = vpadalq_u16(running_sum, sum_row_hi);
+        vst1q_s16(luma[y] + x + 8, vreinterpretq_s16_u16(sum_row_hi));
+
+        x_index = vaddq_u8(x_index, vdupq_n_u8(32));
+      }
+      if (y << 1 < max_luma_height - 2) {
+        src += stride << 1;
+      }
+    }
+    sum = SumVector(running_sum);
+  }
+
+  const int average = Average(sum, block_width, block_height);
+  BlockSubtract(average, luma, block_width, block_height);
+}
+
+template <int block_width, int block_height>
+void CflSubsampler444_NEON(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int max_luma_width, const int max_luma_height,
+    const void* const source, const ptrdiff_t stride) {
+  const auto* src = static_cast<const uint8_t*>(source);
+  int sum;
+  if (block_width == 4) {
+    assert(max_luma_width >= 4);
+    uint32x4_t running_sum = vdupq_n_u32(0);
+
+    for (int y = 0; y < block_height; y += 2) {
+      uint8x8_t row = vdup_n_u8(0);
+      row = LoadLo4(src, row);
+      row = LoadHi4(src + stride, row);
+      if (y < (max_luma_height - 1)) {
+        src += stride << 1;
+      }
+
+      const uint16x8_t row_shifted = vshll_n_u8(row, 3);
+      running_sum = vpadalq_u16(running_sum, row_shifted);
+      vst1_s16(luma[y], vreinterpret_s16_u16(vget_low_u16(row_shifted)));
+      vst1_s16(luma[y + 1], vreinterpret_s16_u16(vget_high_u16(row_shifted)));
+    }
+
+    sum = SumVector(running_sum);
+  } else if (block_width == 8) {
+    const uint8x8_t x_index = {0, 1, 2, 3, 4, 5, 6, 7};
+    const uint8x8_t x_max_index = vdup_n_u8(max_luma_width - 1);
+    const uint8x8_t x_mask = vclt_u8(x_index, x_max_index);
+
+    uint32x4_t running_sum = vdupq_n_u32(0);
+
+    for (int y = 0; y < block_height; ++y) {
+      const uint8x8_t x_max = vdup_n_u8(src[max_luma_width - 1]);
+      const uint8x8_t row = vbsl_u8(x_mask, vld1_u8(src), x_max);
+
+      const uint16x8_t row_shifted = vshll_n_u8(row, 3);
+      running_sum = vpadalq_u16(running_sum, row_shifted);
+      vst1q_s16(luma[y], vreinterpretq_s16_u16(row_shifted));
+
+      if (y < max_luma_height - 1) {
+        src += stride;
+      }
+    }
+
+    sum = SumVector(running_sum);
+  } else /* block_width >= 16 */ {
+    const uint8x16_t x_max_index = vdupq_n_u8(max_luma_width - 1);
+    uint32x4_t running_sum = vdupq_n_u32(0);
+
+    for (int y = 0; y < block_height; ++y) {
+      uint8x16_t x_index = {0, 1, 2,  3,  4,  5,  6,  7,
+                            8, 9, 10, 11, 12, 13, 14, 15};
+      const uint8x16_t x_max = vdupq_n_u8(src[max_luma_width - 1]);
+      for (int x = 0; x < block_width; x += 16) {
+        const uint8x16_t x_mask = vcltq_u8(x_index, x_max_index);
+        const uint8x16_t row = vbslq_u8(x_mask, vld1q_u8(src + x), x_max);
+
+        const uint16x8_t row_shifted_low = vshll_n_u8(vget_low_u8(row), 3);
+        const uint16x8_t row_shifted_high = vshll_n_u8(vget_high_u8(row), 3);
+        running_sum = vpadalq_u16(running_sum, row_shifted_low);
+        running_sum = vpadalq_u16(running_sum, row_shifted_high);
+        vst1q_s16(luma[y] + x, vreinterpretq_s16_u16(row_shifted_low));
+        vst1q_s16(luma[y] + x + 8, vreinterpretq_s16_u16(row_shifted_high));
+
+        x_index = vaddq_u8(x_index, vdupq_n_u8(16));
+      }
+      if (y < max_luma_height - 1) {
+        src += stride;
+      }
+    }
+    sum = SumVector(running_sum);
+  }
+
+  const int average = Average(sum, block_width, block_height);
+  BlockSubtract(average, luma, block_width, block_height);
+}
+
 // Saturate |dc + ((alpha * luma) >> 6))| to uint8_t.
 inline uint8x8_t Combine8(const int16x8_t luma, const int alpha,
                           const int16x8_t dc) {
@@ -110,6 +356,70 @@
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
 
+  dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] =
+      CflSubsampler420_NEON<4, 4>;
+  dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] =
+      CflSubsampler420_NEON<4, 8>;
+  dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] =
+      CflSubsampler420_NEON<4, 16>;
+
+  dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] =
+      CflSubsampler420_NEON<8, 4>;
+  dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] =
+      CflSubsampler420_NEON<8, 8>;
+  dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] =
+      CflSubsampler420_NEON<8, 16>;
+  dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] =
+      CflSubsampler420_NEON<8, 32>;
+
+  dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] =
+      CflSubsampler420_NEON<16, 4>;
+  dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] =
+      CflSubsampler420_NEON<16, 8>;
+  dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] =
+      CflSubsampler420_NEON<16, 16>;
+  dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] =
+      CflSubsampler420_NEON<16, 32>;
+
+  dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] =
+      CflSubsampler420_NEON<32, 8>;
+  dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] =
+      CflSubsampler420_NEON<32, 16>;
+  dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] =
+      CflSubsampler420_NEON<32, 32>;
+
+  dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] =
+      CflSubsampler444_NEON<4, 4>;
+  dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType444] =
+      CflSubsampler444_NEON<4, 8>;
+  dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType444] =
+      CflSubsampler444_NEON<4, 16>;
+
+  dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType444] =
+      CflSubsampler444_NEON<8, 4>;
+  dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType444] =
+      CflSubsampler444_NEON<8, 8>;
+  dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType444] =
+      CflSubsampler444_NEON<8, 16>;
+  dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType444] =
+      CflSubsampler444_NEON<8, 32>;
+
+  dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType444] =
+      CflSubsampler444_NEON<16, 4>;
+  dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType444] =
+      CflSubsampler444_NEON<16, 8>;
+  dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType444] =
+      CflSubsampler444_NEON<16, 16>;
+  dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType444] =
+      CflSubsampler444_NEON<16, 32>;
+
+  dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType444] =
+      CflSubsampler444_NEON<32, 8>;
+  dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType444] =
+      CflSubsampler444_NEON<32, 16>;
+  dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] =
+      CflSubsampler444_NEON<32, 32>;
+
   dsp->cfl_intra_predictors[kTransformSize4x4] = CflIntraPredictor4xN_NEON<4>;
   dsp->cfl_intra_predictors[kTransformSize4x8] = CflIntraPredictor4xN_NEON<8>;
   dsp->cfl_intra_predictors[kTransformSize4x16] = CflIntraPredictor4xN_NEON<16>;
diff --git a/libgav1/src/dsp/arm/intrapred_directional_neon.cc b/libgav1/src/dsp/arm/intrapred_directional_neon.cc
index e59d8ed..751ba89 100644
--- a/libgav1/src/dsp/arm/intrapred_directional_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_directional_neon.cc
@@ -1,10 +1,11 @@
-#include "src/dsp/arm/intrapred_neon.h"
 #include "src/dsp/dsp.h"
+#include "src/dsp/intrapred.h"
 
 #if LIBGAV1_ENABLE_NEON
 
 #include <arm_neon.h>
 
+#include <algorithm>  // std::min
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -31,10 +32,159 @@
 // For vertical operations the weights are one constant value.
 inline uint8x8_t WeightedBlend(const uint8x8_t a, const uint8x8_t b,
                                const uint8_t weight) {
-  const uint16x8_t a_product = vmull_u8(a, vdup_n_u8(32 - weight));
-  const uint16x8_t b_product = vmull_u8(b, vdup_n_u8(weight));
+  return WeightedBlend(a, b, vdup_n_u8(32 - weight), vdup_n_u8(weight));
+}
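+// |weight| and |32 - weight| always sum to 32, so the rounding shift by 5 in
+// the vector-weight WeightedBlend() overload (formerly inlined here) brings
+// the blend back to pixel range.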
 
-  return vrshrn_n_u16(vaddq_u16(a_product, b_product), 5);
+// Fill |left| and |right| with the appropriate values for a given |base_step|.
+inline void LoadStepwise(const uint8_t* const source, const uint8x8_t left_step,
+                         const uint8x8_t right_step, uint8x8_t* left,
+                         uint8x8_t* right) {
+  const uint8x16_t mixed = vld1q_u8(source);
+  *left = VQTbl1U8(mixed, left_step);
+  *right = VQTbl1U8(mixed, right_step);
+}
+
+// Handle signed step arguments by ignoring the sign. Negative values are
+// considered out of range and overwritten later.
+inline void LoadStepwise(const uint8_t* const source, const int8x8_t left_step,
+                         const int8x8_t right_step, uint8x8_t* left,
+                         uint8x8_t* right) {
+  LoadStepwise(source, vreinterpret_u8_s8(left_step),
+               vreinterpret_u8_s8(right_step), left, right);
+}
+
+// Process 4 or 8 |width| by any |height|.
+template <int width>
+inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride,
+                                 const int height, const uint8_t* const top,
+                                 const int xstep, const bool upsampled) {
+  assert(width == 4 || width == 8);
+
+  const int upsample_shift = static_cast<int>(upsampled);
+  const int scale_bits = 6 - upsample_shift;
+
+  const int max_base_x = (width + height - 1) << upsample_shift;
+  const int8x8_t max_base = vdup_n_s8(max_base_x);
+  const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]);
+
+  const int8x8_t all = vcreate_s8(0x0706050403020100);
+  const int8x8_t even = vcreate_s8(0x0e0c0a0806040200);
+  const int8x8_t base_step = upsampled ? even : all;
+  const int8x8_t right_step = vadd_s8(base_step, vdup_n_s8(1));
+
+  int top_x = xstep;
+  int y = 0;
+  do {
+    const int top_base_x = top_x >> scale_bits;
+
+    if (top_base_x >= max_base_x) {
+      for (int i = y; i < height; ++i) {
+        memset(dst, top[max_base_x], 4 /* width */);
+        dst += stride;
+      }
+      return;
+    }
+
+    const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1;
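+    // After the |upsample_shift| correction the position has 6 fractional
+    // bits; dropping the low bit leaves a 5 bit weight in [0, 32) for the
+    // 32-based WeightedBlend().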
+
+    // Zone2 uses negative values for xstep. Use signed values to compare
+    // |top_base_x| to |max_base_x|.
+    const int8x8_t base_v = vadd_s8(vdup_n_s8(top_base_x), base_step);
+
+    const uint8x8_t max_base_mask = vclt_s8(base_v, max_base);
+
+    // 4 wide subsamples the output. 8 wide subsamples the input.
+    if (width == 4) {
+      const uint8x8_t left_values = vld1_u8(top + top_base_x);
+      const uint8x8_t right_values = RightShift<8>(left_values);
+      const uint8x8_t value = WeightedBlend(left_values, right_values, shift);
+
+      // If |upsampled| is true then extract every other value for output.
+      const uint8x8_t value_stepped =
+          vtbl1_u8(value, vreinterpret_u8_s8(base_step));
+      const uint8x8_t masked_value =
+          vbsl_u8(max_base_mask, value_stepped, top_max_base);
+
+      StoreLo4(dst, masked_value);
+    } else /* width == 8 */ {
+      uint8x8_t left_values, right_values;
+      // WeightedBlend() steps up to Q registers. Downsample the input to avoid
+      // doing extra calculations.
+      LoadStepwise(top + top_base_x, base_step, right_step, &left_values,
+                   &right_values);
+
+      const uint8x8_t value = WeightedBlend(left_values, right_values, shift);
+      const uint8x8_t masked_value =
+          vbsl_u8(max_base_mask, value, top_max_base);
+
+      vst1_u8(dst, masked_value);
+    }
+    dst += stride;
+    top_x += xstep;
+  } while (++y < height);
+}
+
+// Process a multiple of 8 |width| by any |height|. Processes horizontally
+// before vertically in the hopes of being a little more cache friendly.
+inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride,
+                                 const int width, const int height,
+                                 const uint8_t* const top, const int xstep,
+                                 const bool upsampled) {
+  assert(width % 8 == 0);
+  const int upsample_shift = static_cast<int>(upsampled);
+  const int scale_bits = 6 - upsample_shift;
+
+  const int max_base_x = (width + height - 1) << upsample_shift;
+  const int8x8_t max_base = vdup_n_s8(max_base_x);
+  const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]);
+
+  const int8x8_t all = vcreate_s8(0x0706050403020100);
+  const int8x8_t even = vcreate_s8(0x0e0c0a0806040200);
+  const int8x8_t base_step = upsampled ? even : all;
+  const int8x8_t right_step = vadd_s8(base_step, vdup_n_s8(1));
+  const int8x8_t block_step = vdup_n_s8(8 << upsample_shift);
+
+  int top_x = xstep;
+  int y = 0;
+  do {
+    const int top_base_x = top_x >> scale_bits;
+
+    if (top_base_x >= max_base_x) {
+      for (int i = y; i < height; ++i) {
+        memset(dst, top[max_base_x], 4 /* width */);
+        dst += stride;
+      }
+      return;
+    }
+
+    const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1;
+
+    // Zone2 uses negative values for xstep. Use signed values to compare
+    // |top_base_x| to |max_base_x|.
+    int8x8_t base_v = vadd_s8(vdup_n_s8(top_base_x), base_step);
+
+    int x = 0;
+    do {
+      const uint8x8_t max_base_mask = vclt_s8(base_v, max_base);
+
+      // Extract the input values based on |upsampled| here to avoid doing twice
+      // as many calculations.
+      uint8x8_t left_values, right_values;
+      LoadStepwise(top + top_base_x + x, base_step, right_step, &left_values,
+                   &right_values);
+
+      const uint8x8_t value = WeightedBlend(left_values, right_values, shift);
+      const uint8x8_t masked_value =
+          vbsl_u8(max_base_mask, value, top_max_base);
+
+      vst1_u8(dst + x, masked_value);
+
+      base_v = vadd_s8(base_v, block_step);
+      x += 8;
+    } while (x < width);
+    top_x += xstep;
+    dst += stride;
+  } while (++y < height);
 }
 
 void DirectionalIntraPredictorZone1_NEON(void* const dest,
@@ -49,64 +199,24 @@
   assert(xstep > 0);
 
   const int upsample_shift = static_cast<int>(upsampled_top);
-  const int scale_bits = 6 - upsample_shift;
 
   const uint8x8_t all = vcreate_u8(0x0706050403020100);
-  const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200);
-  const uint8x8_t base_step = upsampled_top ? even : all;
 
   if (xstep == 64) {
     assert(!upsampled_top);
     const uint8_t* top_ptr = top + 1;
-    for (int y = 0; y < height; y += 4) {
+    int y = 0;
+    do {
       memcpy(dst, top_ptr, width);
       memcpy(dst + stride, top_ptr + 1, width);
       memcpy(dst + 2 * stride, top_ptr + 2, width);
       memcpy(dst + 3 * stride, top_ptr + 3, width);
       dst += 4 * stride;
       top_ptr += 4;
-    }
+      y += 4;
+    } while (y < height);
   } else if (width == 4) {
-    const int max_base_x = ((width + height) - 1) << upsample_shift;
-    const uint8x8_t max_base = vdup_n_u8(max_base_x);
-    const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]);
-
-    for (int y = 0, top_x = xstep; y < height;
-         ++y, dst += stride, top_x += xstep) {
-      const int top_base_x = top_x >> scale_bits;
-
-      if (top_base_x >= max_base_x) {
-        for (int i = y; i < height; ++i) {
-          memset(dst, top[max_base_x], width);
-          dst += stride;
-        }
-        return;
-      }
-
-      const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1;
-      const uint8x8_t shift_mul = vdup_n_u8(shift);
-      const uint8x8_t inv_shift_mul = vdup_n_u8(32 - shift);
-
-      const uint8x8_t base_v = vadd_u8(vdup_n_u8(top_base_x), base_step);
-
-      const uint8x8_t max_base_mask = vclt_u8(base_v, max_base);
-
-      // Load 8 values because we will extract the output values based on
-      // |upsampled_top| at the end.
-      const uint8x8_t left_values = vld1_u8(top + top_base_x);
-      const uint8x8_t right_values = RightShift<8>(left_values);
-
-      const uint8x8_t value =
-          WeightedBlend(left_values, right_values, inv_shift_mul, shift_mul);
-
-      // If |upsampled_top| is true then extract every other value for output.
-      const uint8x8_t value_stepped = vtbl1_u8(value, base_step);
-
-      const uint8x8_t masked_value =
-          vbsl_u8(max_base_mask, value_stepped, top_max_base);
-
-      StoreLo4(dst, masked_value);
-    }
+    DirectionalZone1_WxH<4>(dst, stride, height, top, xstep, upsampled_top);
   } else if (xstep > 51) {
     // 7.11.2.10. Intra edge upsample selection process
     // if ( d <= 0 || d >= 40 ) useUpsample = 0
@@ -120,8 +230,9 @@
     const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]);
     const uint8x8_t block_step = vdup_n_u8(8);
 
-    for (int y = 0, top_x = xstep; y < height;
-         ++y, dst += stride, top_x += xstep) {
+    int top_x = xstep;
+    int y = 0;
+    do {
       const int top_base_x = top_x >> 6;
 
       if (top_base_x >= max_base_x) {
@@ -133,88 +244,545 @@
       }
 
       const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1;
-      const uint8x8_t shift_mul = vdup_n_u8(shift);
-      const uint8x8_t inv_shift_mul = vdup_n_u8(32 - shift);
 
-      uint8x8_t base_v = vadd_u8(vdup_n_u8(top_base_x), base_step);
+      uint8x8_t base_v = vadd_u8(vdup_n_u8(top_base_x), all);
 
-      for (int x = 0; x < width; x += 8) {
+      int x = 0;
+      do {
         const uint8x8_t max_base_mask = vclt_u8(base_v, max_base);
 
         // Since these |xstep| values can not be upsampled the load is
         // simplified.
         const uint8x8_t left_values = vld1_u8(top + top_base_x + x);
         const uint8x8_t right_values = vld1_u8(top + top_base_x + x + 1);
-
-        const uint8x8_t value =
-            WeightedBlend(left_values, right_values, inv_shift_mul, shift_mul);
-
+        const uint8x8_t value = WeightedBlend(left_values, right_values, shift);
         const uint8x8_t masked_value =
             vbsl_u8(max_base_mask, value, top_max_base);
 
         vst1_u8(dst + x, masked_value);
 
         base_v = vadd_u8(base_v, block_step);
-      }
-    }
+        x += 8;
+      } while (x < width);
+      dst += stride;
+      top_x += xstep;
+    } while (++y < height);
   } else {
-    const int max_base_x = ((width + height) - 1) << upsample_shift;
-    const uint8x8_t max_base = vdup_n_u8(max_base_x);
-    const uint8x8_t top_max_base = vdup_n_u8(top[max_base_x]);
-    const uint8x8_t right_step = vadd_u8(base_step, vdup_n_u8(1));
-    const uint8x8_t block_step = vdup_n_u8(8 << upsample_shift);
-
-    for (int y = 0, top_x = xstep; y < height;
-         ++y, dst += stride, top_x += xstep) {
-      const int top_base_x = top_x >> scale_bits;
-
-      if (top_base_x >= max_base_x) {
-        for (int i = y; i < height; ++i) {
-          memset(dst, top[max_base_x], width);
-          dst += stride;
-        }
-        return;
-      }
-
-      const uint8_t shift = ((top_x << upsample_shift) & 0x3F) >> 1;
-      const uint8x8_t shift_mul = vdup_n_u8(shift);
-      const uint8x8_t inv_shift_mul = vdup_n_u8(32 - shift);
-
-      uint8x8_t base_v = vadd_u8(vdup_n_u8(top_base_x), base_step);
-
-      for (int x = 0; x < width; x += 8) {
-        const uint8x8_t max_base_mask = vclt_u8(base_v, max_base);
-
-        // Extract the input values based on |upsampled_top| here to avoid doing
-        // twice as many calculations.
-        const uint8x16_t mixed_values = vld1q_u8(top + top_base_x + x);
-        const uint8x8_t left_values = vtbl2_u8(
-            {vget_low_u8(mixed_values), vget_high_u8(mixed_values)}, base_step);
-        const uint8x8_t right_values =
-            vtbl2_u8({vget_low_u8(mixed_values), vget_high_u8(mixed_values)},
-                     right_step);
-
-        const uint8x8_t value =
-            WeightedBlend(left_values, right_values, inv_shift_mul, shift_mul);
-
-        const uint8x8_t masked_value =
-            vbsl_u8(max_base_mask, value, top_max_base);
-
-        vst1_u8(dst + x, masked_value);
-
-        base_v = vadd_u8(base_v, block_step);
-      }
-    }
+    DirectionalZone1_WxH(dst, stride, width, height, top, xstep, upsampled_top);
   }
 }
 
-// Fill |left| and |right| with the appropriate values for a given |base_step|.
-inline void LoadStepwise(const uint8_t* const source, const uint8x8_t left_step,
-                         const uint8x8_t right_step, uint8x8_t* left,
-                         uint8x8_t* right) {
-  const uint8x16_t mixed = vld1q_u8(source);
-  *left = vtbl2_u8({vget_low_u8(mixed), vget_high_u8(mixed)}, left_step);
-  *right = vtbl2_u8({vget_low_u8(mixed), vget_high_u8(mixed)}, right_step);
+// Process 4 or 8 |width| by 4 or 8 |height|.
+template <int width>
+inline void DirectionalZone3_WxH(uint8_t* dest, const ptrdiff_t stride,
+                                 const int height,
+                                 const uint8_t* const left_column,
+                                 const int base_left_y, const int ystep,
+                                 const int upsample_shift) {
+  assert(width == 4 || width == 8);
+  assert(height == 4 || height == 8);
+  const int scale_bits = 6 - upsample_shift;
+
+  // Zone3 never runs out of left_column values.
+  assert((width + height - 1) << upsample_shift >  // max_base_y
+         ((ystep * width) >> scale_bits) +
+             (/* base_step */ 1 << upsample_shift) *
+                 (height - 1));  // left_base_y
+
+  // Limited improvement for 8x8. ~20% faster for 64x64.
+  const uint8x8_t all = vcreate_u8(0x0706050403020100);
+  const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200);
+  const uint8x8_t base_step = upsample_shift ? even : all;
+  const uint8x8_t right_step = vadd_u8(base_step, vdup_n_u8(1));
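+  // When |upsample_shift| is 1 the base position advances by 2 source entries
+  // per output pixel, so the |even| table gathers every other byte of a
+  // 16-byte load; |right_step| selects each value's right-hand neighbor for
+  // the blend.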
+
+  uint8_t* dst = dest;
+  uint8x8_t left_v[8], right_v[8], value_v[8];
+  const uint8_t* const left = left_column;
+
+  const int index_0 = base_left_y;
+  LoadStepwise(left + (index_0 >> scale_bits), base_step, right_step,
+               &left_v[0], &right_v[0]);
+  value_v[0] = WeightedBlend(left_v[0], right_v[0],
+                             ((index_0 << upsample_shift) & 0x3F) >> 1);
+
+  const int index_1 = base_left_y + ystep;
+  LoadStepwise(left + (index_1 >> scale_bits), base_step, right_step,
+               &left_v[1], &right_v[1]);
+  value_v[1] = WeightedBlend(left_v[1], right_v[1],
+                             ((index_1 << upsample_shift) & 0x3F) >> 1);
+
+  const int index_2 = base_left_y + ystep * 2;
+  LoadStepwise(left + (index_2 >> scale_bits), base_step, right_step,
+               &left_v[2], &right_v[2]);
+  value_v[2] = WeightedBlend(left_v[2], right_v[2],
+                             ((index_2 << upsample_shift) & 0x3F) >> 1);
+
+  const int index_3 = base_left_y + ystep * 3;
+  LoadStepwise(left + (index_3 >> scale_bits), base_step, right_step,
+               &left_v[3], &right_v[3]);
+  value_v[3] = WeightedBlend(left_v[3], right_v[3],
+                             ((index_3 << upsample_shift) & 0x3F) >> 1);
+
+  const int index_4 = base_left_y + ystep * 4;
+  LoadStepwise(left + (index_4 >> scale_bits), base_step, right_step,
+               &left_v[4], &right_v[4]);
+  value_v[4] = WeightedBlend(left_v[4], right_v[4],
+                             ((index_4 << upsample_shift) & 0x3F) >> 1);
+
+  const int index_5 = base_left_y + ystep * 5;
+  LoadStepwise(left + (index_5 >> scale_bits), base_step, right_step,
+               &left_v[5], &right_v[5]);
+  value_v[5] = WeightedBlend(left_v[5], right_v[5],
+                             ((index_5 << upsample_shift) & 0x3F) >> 1);
+
+  const int index_6 = base_left_y + ystep * 6;
+  LoadStepwise(left + (index_6 >> scale_bits), base_step, right_step,
+               &left_v[6], &right_v[6]);
+  value_v[6] = WeightedBlend(left_v[6], right_v[6],
+                             ((index_6 << upsample_shift) & 0x3F) >> 1);
+
+  const int index_7 = base_left_y + ystep * 7;
+  LoadStepwise(left + (index_7 >> scale_bits), base_step, right_step,
+               &left_v[7], &right_v[7]);
+  value_v[7] = WeightedBlend(left_v[7], right_v[7],
+                             ((index_7 << upsample_shift) & 0x3F) >> 1);
+
+  // 8x8 transpose.
+  const uint8x16x2_t b0 = vtrnq_u8(vcombine_u8(value_v[0], value_v[4]),
+                                   vcombine_u8(value_v[1], value_v[5]));
+  const uint8x16x2_t b1 = vtrnq_u8(vcombine_u8(value_v[2], value_v[6]),
+                                   vcombine_u8(value_v[3], value_v[7]));
+
+  const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
+                                    vreinterpretq_u16_u8(b1.val[0]));
+  const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
+                                    vreinterpretq_u16_u8(b1.val[1]));
+
+  const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]),
+                                    vreinterpretq_u32_u16(c1.val[0]));
+  const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]),
+                                    vreinterpretq_u32_u16(c1.val[1]));
+
+  if (width == 4) {
+    StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[0])));
+    dst += stride;
+    StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[0])));
+    dst += stride;
+    StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[0])));
+    dst += stride;
+    StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[0])));
+    if (height == 4) return;
+    dst += stride;
+    StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[1])));
+    dst += stride;
+    StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[1])));
+    dst += stride;
+    StoreLo4(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[1])));
+    dst += stride;
+    StoreLo4(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[1])));
+  } else {
+    vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[0])));
+    dst += stride;
+    vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[0])));
+    dst += stride;
+    vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[0])));
+    dst += stride;
+    vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[0])));
+    if (height == 4) return;
+    dst += stride;
+    vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[1])));
+    dst += stride;
+    vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[1])));
+    dst += stride;
+    vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[1])));
+    dst += stride;
+    vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[1])));
+  }
+}
+
+// Because the source values "move backwards" as the row index increases, the
+// indices derived from ystep are generally negative. This is accommodated by
+// making sure the relative indices are within [-15, 0] when the function is
+// called, and sliding them into the inclusive range [0, 15], relative to a
+// lower base address.
+constexpr int kPositiveIndexOffset = 15;
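+// For example, a relative index of -3 becomes shuffle index 12, while the
+// base position (relative index 0) maps to index 15.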
+
+// Process 4 or 8 |width| by any |height|.
+template <int width>
+inline void DirectionalZone2FromLeftCol_WxH(uint8_t* dst,
+                                            const ptrdiff_t stride,
+                                            const int height,
+                                            const uint8_t* const left_column,
+                                            const int16x8_t left_y,
+                                            const int upsample_shift) {
+  assert(width == 4 || width == 8);
+
+  // The shift argument must be a constant.
+  int16x8_t offset_y, shift_upsampled = left_y;
+  if (upsample_shift) {
+    offset_y = vshrq_n_s16(left_y, 5);
+    shift_upsampled = vshlq_n_s16(shift_upsampled, 1);
+  } else {
+    offset_y = vshrq_n_s16(left_y, 6);
+  }
+
+  // Select values to the left of the starting point.
+  // The 15th element (and 16th) will be all the way at the end, to the right.
+  // With a negative ystep everything else will be "left" of them.
+  // This supports cumulative steps up to 15. We could support up to 16 by doing
+  // separate loads for |left_values| and |right_values|. vtbl supports 2 Q
+  // registers as input which would allow for cumulative offsets of 32.
+  const int16x8_t sampler =
+      vaddq_s16(offset_y, vdupq_n_s16(kPositiveIndexOffset));
+  const uint8x8_t left_values = vqmovun_s16(sampler);
+  const uint8x8_t right_values = vadd_u8(left_values, vdup_n_u8(1));
+
+  const int16x8_t shift_masked = vandq_s16(shift_upsampled, vdupq_n_s16(0x3f));
+  const uint8x8_t shift_mul = vreinterpret_u8_s8(vshrn_n_s16(shift_masked, 1));
+  const uint8x8_t inv_shift_mul = vsub_u8(vdup_n_u8(32), shift_mul);
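+  // For example, if |offset_y| is {0, -1, -2, ...} then |left_values| holds
+  // shuffle indices {15, 14, 13, ...} and |right_values| points one byte to
+  // the right of each. WeightedBlend() then forms, per lane, roughly
+  // (left * (32 - shift) + right * shift + 16) >> 5.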
+
+  int y = 0;
+  do {
+    uint8x8_t src_left, src_right;
+    LoadStepwise(left_column - kPositiveIndexOffset + (y << upsample_shift),
+                 left_values, right_values, &src_left, &src_right);
+    const uint8x8_t val =
+        WeightedBlend(src_left, src_right, inv_shift_mul, shift_mul);
+
+    if (width == 4) {
+      StoreLo4(dst, val);
+    } else {
+      vst1_u8(dst, val);
+    }
+    dst += stride;
+  } while (++y < height);
+}
+
+// Process 4 or 8 |width| by any |height|.
+template <int width>
+inline void DirectionalZone1Blend_WxH(uint8_t* dest, const ptrdiff_t stride,
+                                      const int height,
+                                      const uint8_t* const top_row,
+                                      int zone_bounds, int top_x,
+                                      const int xstep,
+                                      const int upsample_shift) {
+  assert(width == 4 || width == 8);
+
+  const int scale_bits_x = 6 - upsample_shift;
+
+  const uint8x8_t all = vcreate_u8(0x0706050403020100);
+  const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200);
+  const uint8x8_t base_step = upsample_shift ? even : all;
+  const uint8x8_t right_step = vadd_u8(base_step, vdup_n_u8(1));
+
+  int y = 0;
+  do {
+    const uint8_t* const src = top_row + (top_x >> scale_bits_x);
+    uint8x8_t left, right;
+    LoadStepwise(src, base_step, right_step, &left, &right);
+
+    const uint8_t shift = ((top_x << upsample_shift) & 0x3f) >> 1;
+    const uint8x8_t val = WeightedBlend(left, right, shift);
+
+    uint8x8_t dst_blend = vld1_u8(dest);
+    // |zone_bounds| values can be negative.
+    uint8x8_t blend =
+        vcge_s8(vreinterpret_s8_u8(all), vdup_n_s8((zone_bounds >> 6)));
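+    // Lanes whose x index is at least (zone_bounds >> 6) take |val| (computed
+    // from the top row); the remaining lanes keep the prediction already
+    // written to |dest|.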
+    uint8x8_t output = vbsl_u8(blend, val, dst_blend);
+
+    if (width == 4) {
+      StoreLo4(dest, output);
+    } else {
+      vst1_u8(dest, output);
+    }
+    dest += stride;
+    zone_bounds += xstep;
+    top_x -= xstep;
+  } while (++y < height);
+}
+
+// The height at which a load of 16 bytes will not contain enough source pixels
+// from |left_column| to supply an accurate row when computing 8 pixels at a
+// time. The values are found by inspection. By coincidence, all angles that
+// satisfy (ystep >> 6) == 2 map to the same value, so it is enough to look up
+// by ystep >> 6. The largest index for this lookup is 1023 >> 6 == 15.
+constexpr int kDirectionalZone2ShuffleInvalidHeight[16] = {
+    1024, 1024, 16, 16, 16, 16, 0, 0, 18, 0, 0, 0, 0, 0, 0, 40};
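+// For example, entry 0 (1024) effectively places no limit on the shuffle
+// path, while entry 6 (0) disables it entirely for that range of |ystep|.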
+
+// 7.11.2.4 (8) 90 < angle < 180
+// The strategy for these functions (4xH and 8+xH) is to know how many blocks
+// can be processed with just pixels from |top_ptr|, then handle mixed blocks,
+// then handle only blocks that take from |left_ptr|. Additionally, a fast
+// index-shuffle approach is used for pred values from |left_column| in sections
+// that permit it.
+inline void DirectionalZone2_4xH(uint8_t* dst, const ptrdiff_t stride,
+                                 const uint8_t* const top_row,
+                                 const uint8_t* const left_column,
+                                 const int height, const int xstep,
+                                 const int ystep, const bool upsampled_top,
+                                 const bool upsampled_left) {
+  const int upsample_left_shift = static_cast<int>(upsampled_left);
+  const int upsample_top_shift = static_cast<int>(upsampled_top);
+
+  // Helper vector.
+  const int16x8_t zero_to_seven = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // Loop incrementers for moving by block (4xN). The vertical step is still 8;
+  // if the height is only 4, the block finishes in the first iteration.
+  const ptrdiff_t stride8 = stride << 3;
+  const int xstep8 = xstep << 3;
+
+  const int min_height = (height == 4) ? 4 : 8;
+
+  // All columns from |min_top_only_x| to the right will only need |top_row| to
+  // compute and can therefore call the Zone1 functions. This assumes |xstep| is
+  // at least 3.
+  assert(xstep >= 3);
+  const int min_top_only_x = std::min((height * xstep) >> 6, /* width */ 4);
+
+  // For steep angles, the source pixels from |left_column| may not fit in a
+  // 16-byte load for shuffling.
+  // TODO(petersonab): Find a more precise formula for this subject to x.
+  // TODO(johannkoenig): Revisit this for |width| == 4.
+  const int max_shuffle_height =
+      std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height);
+
+  // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1
+  int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1;
+
+  const int left_base_increment = ystep >> 6;
+  const int ystep_remainder = ystep & 0x3F;
+
+  // If the /64 scaling is regarded as a fractional position, the first lane of
+  // the left_y vector carries only the fractional remainder; the integer part
+  // is already folded into the left_column offset. The remaining lanes need the
+  // full ystep as a relative offset.
+  int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep);
+  left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder));
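+  // For example, ystep == 80 gives ystep_remainder == 16 and
+  // left_base_increment == 1, so left_y is
+  // {-16, -96, -176, -256, -336, -416, -496, -576}.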
+
+  // This loop treats each set of 4 columns in 3 stages with y-value boundaries.
+  // The first stage, before the first y-loop, covers blocks that are only
+  // computed from the top row. The second stage, comprising two y-loops, covers
+  // blocks that have a mixture of values computed from top or left. The final
+  // stage covers blocks that are only computed from the left.
+  if (min_top_only_x > 0) {
+    // Round down to the nearest multiple of 8.
+    // TODO(johannkoenig): This never hits for Wx4 blocks but maybe it should.
+    const int max_top_only_y = std::min((1 << 6) / xstep, height) & ~7;
+    DirectionalZone1_WxH<4>(dst, stride, max_top_only_y, top_row, -xstep,
+                            upsampled_top);
+
+    if (max_top_only_y == height) return;
+
+    int y = max_top_only_y;
+    dst += stride * y;
+    const int xstep_y = xstep * y;
+
+    // All rows from |min_left_only_y| down for this set of columns only need
+    // |left_column| to compute.
+    const int min_left_only_y = std::min((4 << 6) / xstep, height);
+    // At high angles such that min_left_only_y < 8, ystep is low and xstep is
+    // high. This means that max_shuffle_height is unbounded and xstep_bounds
+    // will overflow in 16 bits. This is prevented by stopping the first
+    // blending loop at min_left_only_y for such cases, which means we skip over
+    // the second blending loop as well.
+    const int left_shuffle_stop_y =
+        std::min(max_shuffle_height, min_left_only_y);
+    int xstep_bounds = xstep_bounds_base + xstep_y;
+    int top_x = -xstep - xstep_y;
+
+    // +8 increment is OK because if height is 4 this only goes once.
+    for (; y < left_shuffle_stop_y;
+         y += 8, dst += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+      DirectionalZone2FromLeftCol_WxH<4>(
+          dst, stride, min_height,
+          left_column + ((y - left_base_increment) << upsample_left_shift),
+          left_y, upsample_left_shift);
+
+      DirectionalZone1Blend_WxH<4>(dst, stride, min_height, top_row,
+                                   xstep_bounds, top_x, xstep,
+                                   upsample_top_shift);
+    }
+
+    // Pick up from the last y-value, using the slower but secure method for
+    // left prediction.
+    const int16_t base_left_y = vgetq_lane_s16(left_y, 0);
+    for (; y < min_left_only_y;
+         y += 8, dst += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+      DirectionalZone3_WxH<4>(
+          dst, stride, min_height,
+          left_column + ((y - left_base_increment) << upsample_left_shift),
+          base_left_y, -ystep, upsample_left_shift);
+
+      DirectionalZone1Blend_WxH<4>(dst, stride, min_height, top_row,
+                                   xstep_bounds, top_x, xstep,
+                                   upsample_top_shift);
+    }
+    // Loop over y for left_only rows.
+    for (; y < height; y += 8, dst += stride8) {
+      DirectionalZone3_WxH<4>(
+          dst, stride, min_height,
+          left_column + ((y - left_base_increment) << upsample_left_shift),
+          base_left_y, -ystep, upsample_left_shift);
+    }
+  } else {
+    DirectionalZone1_WxH<4>(dst, stride, height, top_row, -xstep,
+                            upsampled_top);
+  }
+}
+
+// Process a multiple of 8 |width|.
+inline void DirectionalZone2_8(uint8_t* const dst, const ptrdiff_t stride,
+                               const uint8_t* const top_row,
+                               const uint8_t* const left_column,
+                               const int width, const int height,
+                               const int xstep, const int ystep,
+                               const bool upsampled_top,
+                               const bool upsampled_left) {
+  const int upsample_left_shift = static_cast<int>(upsampled_left);
+  const int upsample_top_shift = static_cast<int>(upsampled_top);
+
+  // Helper vector.
+  const int16x8_t zero_to_seven = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // Loop incrementers for moving by block (8x8). This function also handles
+  // blocks with height 4; those are computed in one pass, so these increments
+  // go unused for them.
+  const ptrdiff_t stride8 = stride << 3;
+  const int xstep8 = xstep << 3;
+  const int ystep8 = ystep << 3;
+
+  // Process Wx4 blocks.
+  const int min_height = (height == 4) ? 4 : 8;
+
+  // All columns from |min_top_only_x| to the right will only need |top_row| to
+  // compute and can therefore call the Zone1 functions. This assumes |xstep| is
+  // at least 3.
+  assert(xstep >= 3);
+  const int min_top_only_x = std::min((height * xstep) >> 6, width);
+
+  // For steep angles, the source pixels from |left_column| may not fit in a
+  // 16-byte load for shuffling.
+  // TODO(petersonab): Find a more precise formula for this subject to x.
+  const int max_shuffle_height =
+      std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height);
+
+  // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1
+  int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1;
+
+  const int left_base_increment = ystep >> 6;
+  const int ystep_remainder = ystep & 0x3F;
+
+  const int left_base_increment8 = ystep8 >> 6;
+  const int ystep_remainder8 = ystep8 & 0x3F;
+  const int16x8_t increment_left8 = vdupq_n_s16(ystep_remainder8);
+
+  // If the /64 scaling is regarded as a fractional position, the first lane of
+  // the left_y vector carries only the fractional remainder; the integer part
+  // is already folded into the left_column offset. The remaining lanes need the
+  // full ystep as a relative offset.
+  int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep);
+  left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder));
+
+  // This loop treats each set of 8 columns in 3 stages with y-value boundaries.
+  // The first stage, before the first y-loop, covers blocks that are only
+  // computed from the top row. The second stage, comprising two y-loops, covers
+  // blocks that have a mixture of values computed from top or left. The final
+  // stage covers blocks that are only computed from the left.
+  int x = 0;
+  for (int left_offset = -left_base_increment; x < min_top_only_x; x += 8,
+           xstep_bounds_base -= (8 << 6),
+           left_y = vsubq_s16(left_y, increment_left8),
+           left_offset -= left_base_increment8) {
+    uint8_t* dst_x = dst + x;
+
+    // Round down to the nearest multiple of 8.
+    const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height) & ~7;
+    DirectionalZone1_WxH<8>(dst_x, stride, max_top_only_y,
+                            top_row + (x << upsample_top_shift), -xstep,
+                            upsampled_top);
+
+    if (max_top_only_y == height) continue;
+
+    int y = max_top_only_y;
+    dst_x += stride * y;
+    const int xstep_y = xstep * y;
+
+    // All rows from |min_left_only_y| down for this set of columns only need
+    // |left_column| to compute.
+    const int min_left_only_y = std::min(((x + 8) << 6) / xstep, height);
+    // At high angles such that min_left_only_y < 8, ystep is low and xstep is
+    // high. This means that max_shuffle_height is unbounded and xstep_bounds
+    // will overflow in 16 bits. This is prevented by stopping the first
+    // blending loop at min_left_only_y for such cases, which means we skip over
+    // the second blending loop as well.
+    const int left_shuffle_stop_y =
+        std::min(max_shuffle_height, min_left_only_y);
+    int xstep_bounds = xstep_bounds_base + xstep_y;
+    int top_x = -xstep - xstep_y;
+
+    for (; y < left_shuffle_stop_y;
+         y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+      DirectionalZone2FromLeftCol_WxH<8>(
+          dst_x, stride, min_height,
+          left_column + ((left_offset + y) << upsample_left_shift), left_y,
+          upsample_left_shift);
+
+      DirectionalZone1Blend_WxH<8>(
+          dst_x, stride, min_height, top_row + (x << upsample_top_shift),
+          xstep_bounds, top_x, xstep, upsample_top_shift);
+    }
+
+    // Pick up from the last y-value, using the slower but secure method for
+    // left prediction.
+    const int16_t base_left_y = vgetq_lane_s16(left_y, 0);
+    for (; y < min_left_only_y;
+         y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+      DirectionalZone3_WxH<8>(
+          dst_x, stride, min_height,
+          left_column + ((left_offset + y) << upsample_left_shift), base_left_y,
+          -ystep, upsample_left_shift);
+
+      DirectionalZone1Blend_WxH<8>(
+          dst_x, stride, min_height, top_row + (x << upsample_top_shift),
+          xstep_bounds, top_x, xstep, upsample_top_shift);
+    }
+    // Loop over y for left_only rows.
+    for (; y < height; y += 8, dst_x += stride8) {
+      DirectionalZone3_WxH<8>(
+          dst_x, stride, min_height,
+          left_column + ((left_offset + y) << upsample_left_shift), base_left_y,
+          -ystep, upsample_left_shift);
+    }
+  }
+  // TODO(johannkoenig): May be able to remove this branch.
+  if (x < width) {
+    DirectionalZone1_WxH(dst + x, stride, width - x, height,
+                         top_row + (x << upsample_top_shift), -xstep,
+                         upsampled_top);
+  }
+}
+
+void DirectionalIntraPredictorZone2_NEON(
+    void* const dest, const ptrdiff_t stride, const void* const top_row,
+    const void* const left_column, const int width, const int height,
+    const int xstep, const int ystep, const bool upsampled_top,
+    const bool upsampled_left) {
+  // Increasing the negative buffer for this function allows more rows to be
+  // processed at a time without branching in an inner loop to check the base.
+  uint8_t top_buffer[288];
+  uint8_t left_buffer[288];
+  memcpy(top_buffer + 128, static_cast<const uint8_t*>(top_row) - 16, 160);
+  memcpy(left_buffer + 128, static_cast<const uint8_t*>(left_column) - 16, 160);
+  const uint8_t* top_ptr = top_buffer + 144;
+  const uint8_t* left_ptr = left_buffer + 144;
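+  // The copies place the first original sample at buffer offset 144, so
+  // |top_ptr| and |left_ptr| may be indexed as low as -144 without leaving the
+  // local buffers; only offsets [-16, 143] actually contain copied data.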
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  if (width == 4) {
+    DirectionalZone2_4xH(dst, stride, top_ptr, left_ptr, height, xstep, ystep,
+                         upsampled_top, upsampled_left);
+  } else {
+    DirectionalZone2_8(dst, stride, top_ptr, left_ptr, width, height, xstep,
+                       ystep, upsampled_top, upsampled_left);
+  }
 }
 
 void DirectionalIntraPredictorZone3_NEON(void* const dest,
@@ -239,8 +807,10 @@
     const uint8x8_t base_step_v = upsampled_left ? even : all;
     const uint8x8_t right_step = vadd_u8(base_step_v, vdup_n_u8(1));
 
-    for (int y = 0; y < height; y += 8) {
-      for (int x = 0; x < width; x += 4) {
+    int y = 0;
+    do {
+      int x = 0;
+      do {
         uint8_t* dst = static_cast<uint8_t*>(dest);
         dst += y * stride + x;
         uint8x8_t left_v[4], right_v[4], value_v[4];
@@ -298,104 +868,26 @@
           dst += stride;
           StoreHi4(dst, vreinterpret_u8_u16(c1.val[1]));
         }
-      }
-    }
+        x += 4;
+      } while (x < width);
+      y += 8;
+    } while (y < height);
   } else {  // 8x8 at a time.
     // Limited improvement for 8x8. ~20% faster for 64x64.
-    const uint8x8_t all = vcreate_u8(0x0706050403020100);
-    const uint8x8_t even = vcreate_u8(0x0e0c0a0806040200);
-    const uint8x8_t base_step_v = upsampled_left ? even : all;
-    const uint8x8_t right_step = vadd_u8(base_step_v, vdup_n_u8(1));
-
-    for (int y = 0; y < height; y += 8) {
-      for (int x = 0; x < width; x += 8) {
+    int y = 0;
+    do {
+      int x = 0;
+      do {
         uint8_t* dst = static_cast<uint8_t*>(dest);
         dst += y * stride + x;
-        uint8x8_t left_v[8], right_v[8], value_v[8];
-        const int ystep_base = ystep * x;
-        const int offset = y * base_step;
+        const int ystep_base = ystep * (x + 1);
 
-        const int index_0 = ystep_base + ystep * 1;
-        LoadStepwise(left + offset + (index_0 >> scale_bits), base_step_v,
-                     right_step, &left_v[0], &right_v[0]);
-        value_v[0] = WeightedBlend(left_v[0], right_v[0],
-                                   ((index_0 << upsample_shift) & 0x3F) >> 1);
-
-        const int index_1 = ystep_base + ystep * 2;
-        LoadStepwise(left + offset + (index_1 >> scale_bits), base_step_v,
-                     right_step, &left_v[1], &right_v[1]);
-        value_v[1] = WeightedBlend(left_v[1], right_v[1],
-                                   ((index_1 << upsample_shift) & 0x3F) >> 1);
-
-        const int index_2 = ystep_base + ystep * 3;
-        LoadStepwise(left + offset + (index_2 >> scale_bits), base_step_v,
-                     right_step, &left_v[2], &right_v[2]);
-        value_v[2] = WeightedBlend(left_v[2], right_v[2],
-                                   ((index_2 << upsample_shift) & 0x3F) >> 1);
-
-        const int index_3 = ystep_base + ystep * 4;
-        LoadStepwise(left + offset + (index_3 >> scale_bits), base_step_v,
-                     right_step, &left_v[3], &right_v[3]);
-        value_v[3] = WeightedBlend(left_v[3], right_v[3],
-                                   ((index_3 << upsample_shift) & 0x3F) >> 1);
-
-        const int index_4 = ystep_base + ystep * 5;
-        LoadStepwise(left + offset + (index_4 >> scale_bits), base_step_v,
-                     right_step, &left_v[4], &right_v[4]);
-        value_v[4] = WeightedBlend(left_v[4], right_v[4],
-                                   ((index_4 << upsample_shift) & 0x3F) >> 1);
-
-        const int index_5 = ystep_base + ystep * 6;
-        LoadStepwise(left + offset + (index_5 >> scale_bits), base_step_v,
-                     right_step, &left_v[5], &right_v[5]);
-        value_v[5] = WeightedBlend(left_v[5], right_v[5],
-                                   ((index_5 << upsample_shift) & 0x3F) >> 1);
-
-        const int index_6 = ystep_base + ystep * 7;
-        LoadStepwise(left + offset + (index_6 >> scale_bits), base_step_v,
-                     right_step, &left_v[6], &right_v[6]);
-        value_v[6] = WeightedBlend(left_v[6], right_v[6],
-                                   ((index_6 << upsample_shift) & 0x3F) >> 1);
-
-        const int index_7 = ystep_base + ystep * 8;
-        LoadStepwise(left + offset + (index_7 >> scale_bits), base_step_v,
-                     right_step, &left_v[7], &right_v[7]);
-        value_v[7] = WeightedBlend(left_v[7], right_v[7],
-                                   ((index_7 << upsample_shift) & 0x3F) >> 1);
-
-        // 8x8 transpose.
-        const uint8x16x2_t b0 = vtrnq_u8(vcombine_u8(value_v[0], value_v[4]),
-                                         vcombine_u8(value_v[1], value_v[5]));
-        const uint8x16x2_t b1 = vtrnq_u8(vcombine_u8(value_v[2], value_v[6]),
-                                         vcombine_u8(value_v[3], value_v[7]));
-
-        const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
-                                          vreinterpretq_u16_u8(b1.val[0]));
-        const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
-                                          vreinterpretq_u16_u8(b1.val[1]));
-
-        const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]),
-                                          vreinterpretq_u32_u16(c1.val[0]));
-        const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]),
-                                          vreinterpretq_u32_u16(c1.val[1]));
-
-        vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[0])));
-        dst += stride;
-        vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[0])));
-        dst += stride;
-        vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[0])));
-        dst += stride;
-        vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[0])));
-        dst += stride;
-        vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d0.val[1])));
-        dst += stride;
-        vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d0.val[1])));
-        dst += stride;
-        vst1_u8(dst, vreinterpret_u8_u32(vget_low_u32(d1.val[1])));
-        dst += stride;
-        vst1_u8(dst, vreinterpret_u8_u32(vget_high_u32(d1.val[1])));
-      }
-    }
+        DirectionalZone3_WxH<8>(dst, stride, 8, left + (y << upsample_shift),
+                                ystep_base, ystep, upsample_shift);
+        x += 8;
+      } while (x < width);
+      y += 8;
+    } while (y < height);
   }
 }
 
@@ -403,6 +895,7 @@
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
   dsp->directional_intra_predictor_zone1 = DirectionalIntraPredictorZone1_NEON;
+  dsp->directional_intra_predictor_zone2 = DirectionalIntraPredictorZone2_NEON;
   dsp->directional_intra_predictor_zone3 = DirectionalIntraPredictorZone3_NEON;
 }
 
diff --git a/libgav1/src/dsp/arm/intrapred_filter_intra_neon.cc b/libgav1/src/dsp/arm/intrapred_filter_intra_neon.cc
index 9d023d3..ed7287f 100644
--- a/libgav1/src/dsp/arm/intrapred_filter_intra_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_filter_intra_neon.cc
@@ -1,5 +1,5 @@
-#include "src/dsp/arm/intrapred_neon.h"
 #include "src/dsp/dsp.h"
+#include "src/dsp/intrapred.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -90,9 +90,11 @@
   const uint8_t* relative_top = top;
   uint8_t relative_left[2] = {left[0], left[1]};
 
-  for (int y = 0; y < height; y += 2) {
+  int y = 0;
+  do {
     uint8_t* row_dst = dst;
-    for (int x = 0; x < width; x += 4) {
+    int x = 0;
+    do {
       uint16x8_t sum = vdupq_n_u16(0);
       const uint16x8_t subtrahend =
           vmull_u8(transposed_taps[0], vdup_n_u8(relative_top_left));
@@ -119,7 +121,8 @@
       relative_left[0] = row_dst[3];
       relative_left[1] = row_dst[3 + stride];
       row_dst += 4;
-    }
+      x += 4;
+    } while (x < width);
 
     // Progress down.
     relative_top_left = left[y + 1];
@@ -128,7 +131,8 @@
     relative_left[1] = left[y + 3];
 
     dst += 2 * stride;
-  }
+    y += 2;
+  } while (y < height);
 }
 
 void Init8bpp() {
diff --git a/libgav1/src/dsp/arm/intrapred_neon.cc b/libgav1/src/dsp/arm/intrapred_neon.cc
index d7cea3c..dad050b 100644
--- a/libgav1/src/dsp/arm/intrapred_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_neon.cc
@@ -1,6 +1,5 @@
-#include "src/dsp/arm/intrapred_neon.h"
-
 #include "src/dsp/dsp.h"
+#include "src/dsp/intrapred.h"
 
 #if LIBGAV1_ENABLE_NEON
 
diff --git a/libgav1/src/dsp/arm/intrapred_neon.h b/libgav1/src/dsp/arm/intrapred_neon.h
index 925587a..190f1a8 100644
--- a/libgav1/src/dsp/arm/intrapred_neon.h
+++ b/libgav1/src/dsp/arm/intrapred_neon.h
@@ -3,13 +3,14 @@
 
 #include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
-#include "src/dsp/intrapred.h"
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::intra_predictors with neon implementations. This function is
-// not thread-safe.
+// Initializes Dsp::intra_predictors, Dsp::directional_intra_predictor_zone*,
+// Dsp::cfl_intra_predictors, Dsp::cfl_subsamplers and
+// Dsp::filter_intra_predictor; see the defines below for specifics. These
+// functions are not thread-safe.
 void IntraPredCflInit_NEON();
 void IntraPredDirectionalInit_NEON();
 void IntraPredFilterIntraInit_NEON();
@@ -24,6 +25,7 @@
 #define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_DSP_NEON
 
 // 4x4
@@ -38,6 +40,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 4x8
 #define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -51,6 +55,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 4x16
 #define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -64,6 +70,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 8x4
 #define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -77,6 +85,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 8x8
 #define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -90,6 +100,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 8x16
 #define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -103,6 +115,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 8x32
 #define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -116,6 +130,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 16x4
 #define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -129,6 +145,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 16x8
 #define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -142,6 +160,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 16x16
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -155,6 +175,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 16x32
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -168,6 +190,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 16x64
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -192,6 +216,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 32x16
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -205,6 +231,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 32x32
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
@@ -218,6 +246,8 @@
   LIBGAV1_DSP_NEON
 
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_DSP_NEON
 
 // 32x64
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_DSP_NEON
diff --git a/libgav1/src/dsp/arm/intrapred_smooth_neon.cc b/libgav1/src/dsp/arm/intrapred_smooth_neon.cc
index e094174..b251fa5 100644
--- a/libgav1/src/dsp/arm/intrapred_smooth_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_smooth_neon.cc
@@ -1,5 +1,5 @@
-#include "src/dsp/arm/intrapred_neon.h"
 #include "src/dsp/dsp.h"
+#include "src/dsp/intrapred.h"
 
 #if LIBGAV1_ENABLE_NEON
 
diff --git a/libgav1/src/dsp/arm/inverse_transform_neon.cc b/libgav1/src/dsp/arm/inverse_transform_neon.cc
new file mode 100644
index 0000000..9e57365
--- /dev/null
+++ b/libgav1/src/dsp/arm/inverse_transform_neon.cc
@@ -0,0 +1,2627 @@
+#include "src/dsp/dsp.h"
+#include "src/dsp/inverse_transform.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/utils/array_2d.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// Include the constants and utility functions inside the anonymous namespace.
+#include "src/dsp/inverse_transform.inc"
+
+//------------------------------------------------------------------------------
+
+// TODO(slavarnway): Move transpose functions to transpose_neon.h or
+// common_neon.h.
+
+LIBGAV1_ALWAYS_INLINE void Transpose4x4(const int16x8_t in[4],
+                                        int16x8_t out[4]) {
+  // Swap 16 bit elements. Goes from:
+  // a0: 00 01 02 03
+  // a1: 10 11 12 13
+  // a2: 20 21 22 23
+  // a3: 30 31 32 33
+  // to:
+  // b0.val[0]: 00 10 02 12
+  // b0.val[1]: 01 11 03 13
+  // b1.val[0]: 20 30 22 32
+  // b1.val[1]: 21 31 23 33
+  const int16x4_t a0 = vget_low_s16(in[0]);
+  const int16x4_t a1 = vget_low_s16(in[1]);
+  const int16x4_t a2 = vget_low_s16(in[2]);
+  const int16x4_t a3 = vget_low_s16(in[3]);
+
+  const int16x4x2_t b0 = vtrn_s16(a0, a1);
+  const int16x4x2_t b1 = vtrn_s16(a2, a3);
+
+  // Swap 32 bit elements resulting in:
+  // c0.val[0]: 00 10 20 30 04 14 24 34
+  // c0.val[1]: 02 12 22 32 06 16 26 36
+  // c1.val[0]: 01 11 21 31 05 15 25 35
+  // c1.val[1]: 03 13 23 33 07 17 27 37
+  const int32x2x2_t c0 = vtrn_s32(vreinterpret_s32_s16(b0.val[0]),
+                                  vreinterpret_s32_s16(b1.val[0]));
+  const int32x2x2_t c1 = vtrn_s32(vreinterpret_s32_s16(b0.val[1]),
+                                  vreinterpret_s32_s16(b1.val[1]));
+
+  const int16x4_t d0 = vreinterpret_s16_s32(c0.val[0]);
+  const int16x4_t d1 = vreinterpret_s16_s32(c1.val[0]);
+  const int16x4_t d2 = vreinterpret_s16_s32(c0.val[1]);
+  const int16x4_t d3 = vreinterpret_s16_s32(c1.val[1]);
+
+  out[0] = vcombine_s16(d0, d0);
+  out[1] = vcombine_s16(d1, d1);
+  out[2] = vcombine_s16(d2, d2);
+  out[3] = vcombine_s16(d3, d3);
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose8x8(const int16x8_t in[8],
+                                        int16x8_t out[8]) {
+  // Swap 16 bit elements. Goes from:
+  // a0: 00 01 02 03 04 05 06 07
+  // a1: 10 11 12 13 14 15 16 17
+  // a2: 20 21 22 23 24 25 26 27
+  // a3: 30 31 32 33 34 35 36 37
+  // a4: 40 41 42 43 44 45 46 47
+  // a5: 50 51 52 53 54 55 56 57
+  // a6: 60 61 62 63 64 65 66 67
+  // a7: 70 71 72 73 74 75 76 77
+  // to:
+  // b0.val[0]: 00 10 02 12 04 14 06 16
+  // b0.val[1]: 01 11 03 13 05 15 07 17
+  // b1.val[0]: 20 30 22 32 24 34 26 36
+  // b1.val[1]: 21 31 23 33 25 35 27 37
+  // b2.val[0]: 40 50 42 52 44 54 46 56
+  // b2.val[1]: 41 51 43 53 45 55 47 57
+  // b3.val[0]: 60 70 62 72 64 74 66 76
+  // b3.val[1]: 61 71 63 73 65 75 67 77
+
+  const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]);
+  const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]);
+  const int16x8x2_t b2 = vtrnq_s16(in[4], in[5]);
+  const int16x8x2_t b3 = vtrnq_s16(in[6], in[7]);
+
+  // Swap 32 bit elements resulting in:
+  // c0.val[0]: 00 10 20 30 04 14 24 34
+  // c0.val[1]: 02 12 22 32 06 16 26 36
+  // c1.val[0]: 01 11 21 31 05 15 25 35
+  // c1.val[1]: 03 13 23 33 07 17 27 37
+  // c2.val[0]: 40 50 60 70 44 54 64 74
+  // c2.val[1]: 42 52 62 72 46 56 66 76
+  // c3.val[0]: 41 51 61 71 45 55 65 75
+  // c3.val[1]: 43 53 63 73 47 57 67 77
+
+  const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
+                                   vreinterpretq_s32_s16(b1.val[0]));
+  const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
+                                   vreinterpretq_s32_s16(b1.val[1]));
+  const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]),
+                                   vreinterpretq_s32_s16(b3.val[0]));
+  const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
+                                   vreinterpretq_s32_s16(b3.val[1]));
+
+  // Swap 64 bit elements resulting in:
+  // d0.val[0]: 00 10 20 30 40 50 60 70
+  // d0.val[1]: 04 14 24 34 44 54 64 74
+  // d1.val[0]: 01 11 21 31 41 51 61 71
+  // d1.val[1]: 05 15 25 35 45 55 65 75
+  // d2.val[0]: 02 12 22 32 42 52 62 72
+  // d2.val[1]: 06 16 26 36 46 56 66 76
+  // d3.val[0]: 03 13 23 33 43 53 63 73
+  // d3.val[1]: 07 17 27 37 47 57 67 77
+  const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]);
+  const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]);
+  const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]);
+  const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]);
+
+  out[0] = d0.val[0];
+  out[1] = d1.val[0];
+  out[2] = d2.val[0];
+  out[3] = d3.val[0];
+  out[4] = d0.val[1];
+  out[5] = d1.val[1];
+  out[6] = d2.val[1];
+  out[7] = d3.val[1];
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const uint16x8_t in[8],
+                                             uint16x8_t out[4]) {
+  // Swap 16 bit elements. Goes from:
+  // a0: 00 01 02 03
+  // a1: 10 11 12 13
+  // a2: 20 21 22 23
+  // a3: 30 31 32 33
+  // a4: 40 41 42 43
+  // a5: 50 51 52 53
+  // a6: 60 61 62 63
+  // a7: 70 71 72 73
+  // to:
+  // b0.val[0]: 00 10 02 12
+  // b0.val[1]: 01 11 03 13
+  // b1.val[0]: 20 30 22 32
+  // b1.val[1]: 21 31 23 33
+  // b2.val[0]: 40 50 42 52
+  // b2.val[1]: 41 51 43 53
+  // b3.val[0]: 60 70 62 72
+  // b3.val[1]: 61 71 63 73
+
+  uint16x4x2_t b0 = vtrn_u16(vget_low_u16(in[0]), vget_low_u16(in[1]));
+  uint16x4x2_t b1 = vtrn_u16(vget_low_u16(in[2]), vget_low_u16(in[3]));
+  uint16x4x2_t b2 = vtrn_u16(vget_low_u16(in[4]), vget_low_u16(in[5]));
+  uint16x4x2_t b3 = vtrn_u16(vget_low_u16(in[6]), vget_low_u16(in[7]));
+
+  // Swap 32 bit elements resulting in:
+  // c0.val[0]: 00 10 20 30
+  // c0.val[1]: 02 12 22 32
+  // c1.val[0]: 01 11 21 31
+  // c1.val[1]: 03 13 23 33
+  // c2.val[0]: 40 50 60 70
+  // c2.val[1]: 42 52 62 72
+  // c3.val[0]: 41 51 61 71
+  // c3.val[1]: 43 53 63 73
+
+  uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]),
+                             vreinterpret_u32_u16(b1.val[0]));
+  uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]),
+                             vreinterpret_u32_u16(b1.val[1]));
+  uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b2.val[0]),
+                             vreinterpret_u32_u16(b3.val[0]));
+  uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b2.val[1]),
+                             vreinterpret_u32_u16(b3.val[1]));
+
+  // Swap 64 bit elements resulting in:
+  // o0: 00 10 20 30 40 50 60 70
+  // o1: 01 11 21 31 41 51 61 71
+  // o2: 02 12 22 32 42 52 62 72
+  // o3: 03 13 23 33 43 53 63 73
+
+  out[0] = vcombine_u16(vreinterpret_u16_u32(c0.val[0]),
+                        vreinterpret_u16_u32(c2.val[0]));
+  out[1] = vcombine_u16(vreinterpret_u16_u32(c1.val[0]),
+                        vreinterpret_u16_u32(c3.val[0]));
+  out[2] = vcombine_u16(vreinterpret_u16_u32(c0.val[1]),
+                        vreinterpret_u16_u32(c2.val[1]));
+  out[3] = vcombine_u16(vreinterpret_u16_u32(c1.val[1]),
+                        vreinterpret_u16_u32(c3.val[1]));
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const int16x8_t in[8],
+                                             int16x8_t out[4]) {
+  Transpose4x8To8x4(reinterpret_cast<const uint16x8_t*>(in),
+                    reinterpret_cast<uint16x8_t*>(out));
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose8x4To4x8(const int16x8_t in[4],
+                                             int16x8_t out[8]) {
+  // Swap 16 bit elements. Goes from:
+  // a0: 00 01 02 03 04 05 06 07
+  // a1: 10 11 12 13 14 15 16 17
+  // a2: 20 21 22 23 24 25 26 27
+  // a3: 30 31 32 33 34 35 36 37
+  // to:
+  // b0.val[0]: 00 10 02 12 04 14 06 16
+  // b0.val[1]: 01 11 03 13 05 15 07 17
+  // b1.val[0]: 20 30 22 32 24 34 26 36
+  // b1.val[1]: 21 31 23 33 25 35 27 37
+  const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]);
+  const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]);
+
+  // Swap 32 bit elements resulting in:
+  // c0.val[0]: 00 10 20 30 04 14 24 34
+  // c0.val[1]: 02 12 22 32 06 16 26 36
+  // c1.val[0]: 01 11 21 31 05 15 25 35
+  // c1.val[1]: 03 13 23 33 07 17 27 37
+  const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
+                                   vreinterpretq_s32_s16(b1.val[0]));
+  const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
+                                   vreinterpretq_s32_s16(b1.val[1]));
+
+  // The upper 8 bytes are don't cares.
+  // out[0]: 00 10 20 30 04 14 24 34
+  // out[1]: 01 11 21 31 05 15 25 35
+  // out[2]: 02 12 22 32 06 16 26 36
+  // out[3]: 03 13 23 33 07 17 27 37
+  // out[4]: 04 14 24 34 04 14 24 34
+  // out[5]: 05 15 25 35 05 15 25 35
+  // out[6]: 06 16 26 36 06 16 26 36
+  // out[7]: 07 17 27 37 07 17 27 37
+  out[0] = vreinterpretq_s16_s32(c0.val[0]);
+  out[1] = vreinterpretq_s16_s32(c1.val[0]);
+  out[2] = vreinterpretq_s16_s32(c0.val[1]);
+  out[3] = vreinterpretq_s16_s32(c1.val[1]);
+  out[4] = vreinterpretq_s16_s32(
+      vcombine_s32(vget_high_s32(c0.val[0]), vget_high_s32(c0.val[0])));
+  out[5] = vreinterpretq_s16_s32(
+      vcombine_s32(vget_high_s32(c1.val[0]), vget_high_s32(c1.val[0])));
+  out[6] = vreinterpretq_s16_s32(
+      vcombine_s32(vget_high_s32(c0.val[1]), vget_high_s32(c0.val[1])));
+  out[7] = vreinterpretq_s16_s32(
+      vcombine_s32(vget_high_s32(c1.val[1]), vget_high_s32(c1.val[1])));
+}
+
+//------------------------------------------------------------------------------
+template <int store_width, int store_count>
+LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* dst, int32_t stride, int32_t idx,
+                                    const int16x8_t* const s) {
+  assert(store_count % 4 == 0);
+  assert(store_width == 8 || store_width == 16);
+  // NOTE: It is expected that the compiler will unroll these loops.
+  if (store_width == 16) {
+    for (int i = 0; i < store_count; i += 4) {
+      vst1q_s16(&dst[i * stride + idx], (s[i]));
+      vst1q_s16(&dst[(i + 1) * stride + idx], (s[i + 1]));
+      vst1q_s16(&dst[(i + 2) * stride + idx], (s[i + 2]));
+      vst1q_s16(&dst[(i + 3) * stride + idx], (s[i + 3]));
+    }
+  } else {
+    // store_width == 8
+    for (int i = 0; i < store_count; i += 4) {
+      vst1_s16(&dst[i * stride + idx], vget_low_s16(s[i]));
+      vst1_s16(&dst[(i + 1) * stride + idx], vget_low_s16(s[i + 1]));
+      vst1_s16(&dst[(i + 2) * stride + idx], vget_low_s16(s[i + 2]));
+      vst1_s16(&dst[(i + 3) * stride + idx], vget_low_s16(s[i + 3]));
+    }
+  }
+}
+
+template <int load_width, int load_count>
+LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* src, int32_t stride,
+                                   int32_t idx, int16x8_t* x) {
+  assert(load_count % 4 == 0);
+  assert(load_width == 8 || load_width == 16);
+  // NOTE: It is expected that the compiler will unroll these loops.
+  if (load_width == 16) {
+    for (int i = 0; i < load_count; i += 4) {
+      x[i] = vld1q_s16(&src[i * stride + idx]);
+      x[i + 1] = vld1q_s16(&src[(i + 1) * stride + idx]);
+      x[i + 2] = vld1q_s16(&src[(i + 2) * stride + idx]);
+      x[i + 3] = vld1q_s16(&src[(i + 3) * stride + idx]);
+    }
+  } else {
+    // load_width == 8
+    const int64x2_t zero = vdupq_n_s64(0);
+    for (int i = 0; i < load_count; i += 4) {
+      // The src buffer is aligned to 32 bytes, so each load is always
+      // 8-byte aligned.
+      x[i] = vreinterpretq_s16_s64(vld1q_lane_s64(
+          reinterpret_cast<const int64_t*>(&src[i * stride + idx]), zero, 0));
+      x[i + 1] = vreinterpretq_s16_s64(vld1q_lane_s64(
+          reinterpret_cast<const int64_t*>(&src[(i + 1) * stride + idx]), zero,
+          0));
+      x[i + 2] = vreinterpretq_s16_s64(vld1q_lane_s64(
+          reinterpret_cast<const int64_t*>(&src[(i + 2) * stride + idx]), zero,
+          0));
+      x[i + 3] = vreinterpretq_s16_s64(vld1q_lane_s64(
+          reinterpret_cast<const int64_t*>(&src[(i + 3) * stride + idx]), zero,
+          0));
+    }
+  }
+}
+
+// Butterfly rotate 4 values.
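+// Per lane, widened to 32 bits: x = (a * cos128 - b * sin128 + 2048) >> 12 and
+// y = (a * sin128 + b * cos128 + 2048) >> 12, saturated back to 16 bits.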
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(int16x8_t* a, int16x8_t* b,
+                                               const int angle,
+                                               const bool flip) {
+  const int16_t cos128 = Cos128(angle);
+  const int16_t sin128 = Sin128(angle);
+  const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128);
+  const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128);
+  const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128);
+  const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128);
+  const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
+  const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
+  const int16x8_t x = vcombine_s16(x1, x1);
+  const int16x8_t y = vcombine_s16(y1, y1);
+  if (flip) {
+    *a = y;
+    *b = x;
+  } else {
+    *a = x;
+    *b = y;
+  }
+}
+
+// Butterfly rotate 8 values.
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_8(int16x8_t* a, int16x8_t* b,
+                                               const int angle,
+                                               const bool flip) {
+  const int16_t cos128 = Cos128(angle);
+  const int16_t sin128 = Sin128(angle);
+  const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128);
+  const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128);
+  const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128);
+  const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128);
+  const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
+  const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
+
+  const int32x4_t acc_x_hi = vmull_n_s16(vget_high_s16(*a), cos128);
+  const int32x4_t acc_y_hi = vmull_n_s16(vget_high_s16(*a), sin128);
+  const int32x4_t x0_hi = vmlsl_n_s16(acc_x_hi, vget_high_s16(*b), sin128);
+  const int32x4_t y0_hi = vmlal_n_s16(acc_y_hi, vget_high_s16(*b), cos128);
+  const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12);
+  const int16x4_t y1_hi = vqrshrn_n_s32(y0_hi, 12);
+
+  const int16x8_t x = vcombine_s16(x1, x1_hi);
+  const int16x8_t y = vcombine_s16(y1, y1_hi);
+  if (flip) {
+    *a = y;
+    *b = x;
+  } else {
+    *a = x;
+    *b = y;
+  }
+}
+
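+// When one of the inputs is known to be zero the multiply can stay in 16 bits:
+// vqrdmulhq_s16 with (cos128 << 3) computes approximately
+// (x * cos128 + 2048) >> 12 per lane, matching the widened path above.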
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(int16x8_t* a,
+                                                         int16x8_t* b,
+                                                         const int angle,
+                                                         const bool flip) {
+  const int16_t cos128 = Cos128(angle);
+  const int16_t sin128 = Sin128(angle);
+  const int32x4_t x0 = vmlsl_n_s16(vdupq_n_s32(0), vget_low_s16(*b), sin128);
+  const int32x4_t x0_hi =
+      vmlsl_n_s16(vdupq_n_s32(0), vget_high_s16(*b), sin128);
+  const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
+  const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12);
+  const int16x8_t x = vcombine_s16(x1, x1_hi);
+  const int16x8_t y = vqrdmulhq_s16(*b, vdupq_n_s16(cos128 << 3));
+  if (flip) {
+    *a = y;
+    *b = x;
+  } else {
+    *a = x;
+    *b = y;
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(int16x8_t* a,
+                                                          int16x8_t* b,
+                                                          const int angle,
+                                                          const bool flip) {
+  const int16_t cos128 = Cos128(angle);
+  const int16_t sin128 = Sin128(angle);
+  const int16x8_t x = vqrdmulhq_s16(*a, vdupq_n_s16(cos128 << 3));
+  const int16x8_t y = vqrdmulhq_s16(*a, vdupq_n_s16(sin128 << 3));
+  if (flip) {
+    *a = y;
+    *b = x;
+  } else {
+    *a = x;
+    *b = y;
+  }
+}
+
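+// Saturating butterfly: without |flip|, (a, b) becomes (a + b, a - b); with
+// |flip|, it becomes (b - a, b + a).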
+LIBGAV1_ALWAYS_INLINE void HadamardRotation(int16x8_t* a, int16x8_t* b,
+                                            bool flip) {
+  int16x8_t x, y;
+  if (flip) {
+    y = vqaddq_s16(*b, *a);
+    x = vqsubq_s16(*b, *a);
+  } else {
+    x = vqaddq_s16(*a, *b);
+    y = vqsubq_s16(*a, *b);
+  }
+  *a = x;
+  *b = y;
+}
+
+using ButterflyRotationFunc = void (*)(int16x8_t* a, int16x8_t* b, int angle,
+                                       bool flip);
+
+//------------------------------------------------------------------------------
+// Discrete Cosine Transforms (DCT).
+
+template <ButterflyRotationFunc bufferfly_rotation,
+          bool is_fast_bufferfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct4Stages(int16x8_t* s) {
+  // stage 12.
+  if (is_fast_bufferfly) {
+    ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true);
+    ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false);
+  } else {
+    bufferfly_rotation(&s[0], &s[1], 32, true);
+    bufferfly_rotation(&s[2], &s[3], 48, false);
+  }
+
+  // stage 17.
+  HadamardRotation(&s[0], &s[3], false);
+  HadamardRotation(&s[1], &s[2], false);
+}
+
+template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Dct4_NEON(void* dest, const void* source,
+                                     int32_t step, bool transpose) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[4], x[4];
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t input[8];
+      LoadSrc<8, 8>(src, step, 0, input);
+      Transpose4x8To8x4(input, x);
+    } else {
+      LoadSrc<16, 4>(src, step, 0, x);
+    }
+  } else {
+    LoadSrc<8, 4>(src, step, 0, x);
+    if (transpose) {
+      Transpose4x4(x, x);
+    }
+  }
+
+  // stage 1.
+  // kBitReverseLookup 0, 2, 1, 3
+  s[0] = x[0];
+  s[1] = x[2];
+  s[2] = x[1];
+  s[3] = x[3];
+
+  Dct4Stages<bufferfly_rotation>(s);
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t output[8];
+      Transpose8x4To4x8(s, output);
+      StoreDst<8, 8>(dst, step, 0, output);
+    } else {
+      StoreDst<16, 4>(dst, step, 0, s);
+    }
+  } else {
+    if (transpose) {
+      Transpose4x4(s, s);
+    }
+    StoreDst<8, 4>(dst, step, 0, s);
+  }
+}
+
+template <ButterflyRotationFunc bufferfly_rotation,
+          bool is_fast_bufferfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct8Stages(int16x8_t* s) {
+  // stage 8.
+  if (is_fast_bufferfly) {
+    ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false);
+    ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false);
+  } else {
+    bufferfly_rotation(&s[4], &s[7], 56, false);
+    bufferfly_rotation(&s[5], &s[6], 24, false);
+  }
+
+  // stage 13.
+  HadamardRotation(&s[4], &s[5], false);
+  HadamardRotation(&s[6], &s[7], true);
+
+  // stage 18.
+  bufferfly_rotation(&s[6], &s[5], 32, true);
+
+  // stage 22.
+  HadamardRotation(&s[0], &s[7], false);
+  HadamardRotation(&s[1], &s[6], false);
+  HadamardRotation(&s[2], &s[5], false);
+  HadamardRotation(&s[3], &s[4], false);
+}
+
+// Process dct8 rows or columns, depending on the transpose flag.
+template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Dct8_NEON(void* dest, const void* source,
+                                     int32_t step, bool transpose) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[8], x[8];
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t input[4];
+      LoadSrc<16, 4>(src, step, 0, input);
+      Transpose8x4To4x8(input, x);
+    } else {
+      LoadSrc<8, 8>(src, step, 0, x);
+    }
+  } else if (transpose) {
+    int16x8_t input[8];
+    LoadSrc<16, 8>(src, step, 0, input);
+    Transpose8x8(input, x);
+  } else {
+    LoadSrc<16, 8>(src, step, 0, x);
+  }
+
+  // stage 1.
+  // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7,
+  s[0] = x[0];
+  s[1] = x[4];
+  s[2] = x[2];
+  s[3] = x[6];
+  s[4] = x[1];
+  s[5] = x[5];
+  s[6] = x[3];
+  s[7] = x[7];
+
+  Dct4Stages<bufferfly_rotation>(s);
+  Dct8Stages<bufferfly_rotation>(s);
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t output[4];
+      Transpose4x8To8x4(s, output);
+      StoreDst<16, 4>(dst, step, 0, output);
+    } else {
+      StoreDst<8, 8>(dst, step, 0, s);
+    }
+  } else if (transpose) {
+    int16x8_t output[8];
+    Transpose8x8(s, output);
+    StoreDst<16, 8>(dst, step, 0, output);
+  } else {
+    StoreDst<16, 8>(dst, step, 0, s);
+  }
+}
+
+template <ButterflyRotationFunc bufferfly_rotation,
+          bool is_fast_bufferfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct16Stages(int16x8_t* s) {
+  // stage 5.
+  if (is_fast_bufferfly) {
+    ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false);
+    ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false);
+    ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false);
+    ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false);
+  } else {
+    bufferfly_rotation(&s[8], &s[15], 60, false);
+    bufferfly_rotation(&s[9], &s[14], 28, false);
+    bufferfly_rotation(&s[10], &s[13], 44, false);
+    bufferfly_rotation(&s[11], &s[12], 12, false);
+  }
+
+  // stage 9.
+  HadamardRotation(&s[8], &s[9], false);
+  HadamardRotation(&s[10], &s[11], true);
+  HadamardRotation(&s[12], &s[13], false);
+  HadamardRotation(&s[14], &s[15], true);
+
+  // stage 14.
+  bufferfly_rotation(&s[14], &s[9], 48, true);
+  bufferfly_rotation(&s[13], &s[10], 112, true);
+
+  // stage 19.
+  HadamardRotation(&s[8], &s[11], false);
+  HadamardRotation(&s[9], &s[10], false);
+  HadamardRotation(&s[12], &s[15], true);
+  HadamardRotation(&s[13], &s[14], true);
+
+  // stage 23.
+  bufferfly_rotation(&s[13], &s[10], 32, true);
+  bufferfly_rotation(&s[12], &s[11], 32, true);
+
+  // stage 26.
+  HadamardRotation(&s[0], &s[15], false);
+  HadamardRotation(&s[1], &s[14], false);
+  HadamardRotation(&s[2], &s[13], false);
+  HadamardRotation(&s[3], &s[12], false);
+  HadamardRotation(&s[4], &s[11], false);
+  HadamardRotation(&s[5], &s[10], false);
+  HadamardRotation(&s[6], &s[9], false);
+  HadamardRotation(&s[7], &s[8], false);
+}
+
+// Process dct16 rows or columns, depending on the transpose flag.
+template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Dct16_NEON(void* dest, const void* source,
+                                      int32_t step, bool transpose) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[16], x[16];
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t input[4];
+      LoadSrc<16, 4>(src, step, 0, input);
+      Transpose8x4To4x8(input, x);
+      LoadSrc<16, 4>(src, step, 8, input);
+      Transpose8x4To4x8(input, &x[8]);
+    } else {
+      LoadSrc<8, 16>(src, step, 0, x);
+    }
+  } else if (transpose) {
+    for (int idx = 0; idx < 16; idx += 8) {
+      int16x8_t input[8];
+      LoadSrc<16, 8>(src, step, idx, input);
+      Transpose8x8(input, &x[idx]);
+    }
+  } else {
+    LoadSrc<16, 16>(src, step, 0, x);
+  }
+
+  // stage 1
+  // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15,
+  s[0] = x[0];
+  s[1] = x[8];
+  s[2] = x[4];
+  s[3] = x[12];
+  s[4] = x[2];
+  s[5] = x[10];
+  s[6] = x[6];
+  s[7] = x[14];
+  s[8] = x[1];
+  s[9] = x[9];
+  s[10] = x[5];
+  s[11] = x[13];
+  s[12] = x[3];
+  s[13] = x[11];
+  s[14] = x[7];
+  s[15] = x[15];
+
+  Dct4Stages<bufferfly_rotation>(s);
+  Dct8Stages<bufferfly_rotation>(s);
+  Dct16Stages<bufferfly_rotation>(s);
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t output[4];
+      Transpose4x8To8x4(s, output);
+      StoreDst<16, 4>(dst, step, 0, output);
+      Transpose4x8To8x4(&s[8], output);
+      StoreDst<16, 4>(dst, step, 8, output);
+    } else {
+      StoreDst<8, 16>(dst, step, 0, s);
+    }
+  } else if (transpose) {
+    for (int idx = 0; idx < 16; idx += 8) {
+      int16x8_t output[8];
+      Transpose8x8(&s[idx], output);
+      StoreDst<16, 8>(dst, step, idx, output);
+    }
+  } else {
+    StoreDst<16, 16>(dst, step, 0, s);
+  }
+}
+
+template <ButterflyRotationFunc bufferfly_rotation,
+          bool is_fast_butterfly = false>
+LIBGAV1_ALWAYS_INLINE void Dct32Stages(int16x8_t* s) {
+  // stage 3
+  if (is_fast_butterfly) {
+    ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false);
+    ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false);
+    ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false);
+    ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false);
+    ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false);
+    ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false);
+    ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false);
+    ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false);
+  } else {
+    bufferfly_rotation(&s[16], &s[31], 62, false);
+    bufferfly_rotation(&s[17], &s[30], 30, false);
+    bufferfly_rotation(&s[18], &s[29], 46, false);
+    bufferfly_rotation(&s[19], &s[28], 14, false);
+    bufferfly_rotation(&s[20], &s[27], 54, false);
+    bufferfly_rotation(&s[21], &s[26], 22, false);
+    bufferfly_rotation(&s[22], &s[25], 38, false);
+    bufferfly_rotation(&s[23], &s[24], 6, false);
+  }
+  // stage 6.
+  HadamardRotation(&s[16], &s[17], false);
+  HadamardRotation(&s[18], &s[19], true);
+  HadamardRotation(&s[20], &s[21], false);
+  HadamardRotation(&s[22], &s[23], true);
+  HadamardRotation(&s[24], &s[25], false);
+  HadamardRotation(&s[26], &s[27], true);
+  HadamardRotation(&s[28], &s[29], false);
+  HadamardRotation(&s[30], &s[31], true);
+
+  // stage 10.
+  bufferfly_rotation(&s[30], &s[17], 24 + 32, true);
+  bufferfly_rotation(&s[29], &s[18], 24 + 64 + 32, true);
+  bufferfly_rotation(&s[26], &s[21], 24, true);
+  bufferfly_rotation(&s[25], &s[22], 24 + 64, true);
+
+  // stage 15.
+  HadamardRotation(&s[16], &s[19], false);
+  HadamardRotation(&s[17], &s[18], false);
+  HadamardRotation(&s[20], &s[23], true);
+  HadamardRotation(&s[21], &s[22], true);
+  HadamardRotation(&s[24], &s[27], false);
+  HadamardRotation(&s[25], &s[26], false);
+  HadamardRotation(&s[28], &s[31], true);
+  HadamardRotation(&s[29], &s[30], true);
+
+  // stage 20.
+  bufferfly_rotation(&s[29], &s[18], 48, true);
+  bufferfly_rotation(&s[28], &s[19], 48, true);
+  bufferfly_rotation(&s[27], &s[20], 48 + 64, true);
+  bufferfly_rotation(&s[26], &s[21], 48 + 64, true);
+
+  // stage 24.
+  HadamardRotation(&s[16], &s[23], false);
+  HadamardRotation(&s[17], &s[22], false);
+  HadamardRotation(&s[18], &s[21], false);
+  HadamardRotation(&s[19], &s[20], false);
+  HadamardRotation(&s[24], &s[31], true);
+  HadamardRotation(&s[25], &s[30], true);
+  HadamardRotation(&s[26], &s[29], true);
+  HadamardRotation(&s[27], &s[28], true);
+
+  // stage 27.
+  bufferfly_rotation(&s[27], &s[20], 32, true);
+  bufferfly_rotation(&s[26], &s[21], 32, true);
+  bufferfly_rotation(&s[25], &s[22], 32, true);
+  bufferfly_rotation(&s[24], &s[23], 32, true);
+
+  // stage 29.
+  HadamardRotation(&s[0], &s[31], false);
+  HadamardRotation(&s[1], &s[30], false);
+  HadamardRotation(&s[2], &s[29], false);
+  HadamardRotation(&s[3], &s[28], false);
+  HadamardRotation(&s[4], &s[27], false);
+  HadamardRotation(&s[5], &s[26], false);
+  HadamardRotation(&s[6], &s[25], false);
+  HadamardRotation(&s[7], &s[24], false);
+  HadamardRotation(&s[8], &s[23], false);
+  HadamardRotation(&s[9], &s[22], false);
+  HadamardRotation(&s[10], &s[21], false);
+  HadamardRotation(&s[11], &s[20], false);
+  HadamardRotation(&s[12], &s[19], false);
+  HadamardRotation(&s[13], &s[18], false);
+  HadamardRotation(&s[14], &s[17], false);
+  HadamardRotation(&s[15], &s[16], false);
+}
+
+// Process dct32 rows or columns, depending on the transpose flag.
+LIBGAV1_ALWAYS_INLINE void Dct32_NEON(void* dest, const void* source,
+                                      const int32_t step,
+                                      const bool transpose) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[32], x[32];
+
+  if (transpose) {
+    for (int idx = 0; idx < 32; idx += 8) {
+      int16x8_t input[8];
+      LoadSrc<16, 8>(src, step, idx, input);
+      Transpose8x8(input, &x[idx]);
+    }
+  } else {
+    LoadSrc<16, 32>(src, step, 0, x);
+  }
+
+  // stage 1
+  // kBitReverseLookup
+  // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30,
+  s[0] = x[0];
+  s[1] = x[16];
+  s[2] = x[8];
+  s[3] = x[24];
+  s[4] = x[4];
+  s[5] = x[20];
+  s[6] = x[12];
+  s[7] = x[28];
+  s[8] = x[2];
+  s[9] = x[18];
+  s[10] = x[10];
+  s[11] = x[26];
+  s[12] = x[6];
+  s[13] = x[22];
+  s[14] = x[14];
+  s[15] = x[30];
+
+  // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31,
+  s[16] = x[1];
+  s[17] = x[17];
+  s[18] = x[9];
+  s[19] = x[25];
+  s[20] = x[5];
+  s[21] = x[21];
+  s[22] = x[13];
+  s[23] = x[29];
+  s[24] = x[3];
+  s[25] = x[19];
+  s[26] = x[11];
+  s[27] = x[27];
+  s[28] = x[7];
+  s[29] = x[23];
+  s[30] = x[15];
+  s[31] = x[31];
+
+  Dct4Stages<ButterflyRotation_8>(s);
+  Dct8Stages<ButterflyRotation_8>(s);
+  Dct16Stages<ButterflyRotation_8>(s);
+  Dct32Stages<ButterflyRotation_8>(s);
+
+  if (transpose) {
+    for (int idx = 0; idx < 32; idx += 8) {
+      int16x8_t output[8];
+      Transpose8x8(&s[idx], output);
+      StoreDst<16, 8>(dst, step, idx, output);
+    }
+  } else {
+    StoreDst<16, 32>(dst, step, 0, s);
+  }
+}
+
+// Allow the compiler to call this function instead of force inlining it.
+// Tests show this is slightly faster.
+void Dct64_NEON(void* dest, const void* source, int32_t step, bool transpose) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[64], x[32];
+
+  if (transpose) {
+    // The last 32 values of every row are always zero if the |tx_width| is
+    // 64.
+    for (int idx = 0; idx < 32; idx += 8) {
+      int16x8_t input[8];
+      LoadSrc<16, 8>(src, step, idx, input);
+      Transpose8x8(input, &x[idx]);
+    }
+  } else {
+    // The last 32 values of every column are always zero if the |tx_height| is
+    // 64.
+    LoadSrc<16, 32>(src, step, 0, x);
+  }
+
+  // stage 1
+  // kBitReverseLookup
+  // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60,
+  s[0] = x[0];
+  s[2] = x[16];
+  s[4] = x[8];
+  s[6] = x[24];
+  s[8] = x[4];
+  s[10] = x[20];
+  s[12] = x[12];
+  s[14] = x[28];
+
+  // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62,
+  s[16] = x[2];
+  s[18] = x[18];
+  s[20] = x[10];
+  s[22] = x[26];
+  s[24] = x[6];
+  s[26] = x[22];
+  s[28] = x[14];
+  s[30] = x[30];
+
+  // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61,
+  s[32] = x[1];
+  s[34] = x[17];
+  s[36] = x[9];
+  s[38] = x[25];
+  s[40] = x[5];
+  s[42] = x[21];
+  s[44] = x[13];
+  s[46] = x[29];
+
+  // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63
+  s[48] = x[3];
+  s[50] = x[19];
+  s[52] = x[11];
+  s[54] = x[27];
+  s[56] = x[7];
+  s[58] = x[23];
+  s[60] = x[15];
+  s[62] = x[31];
+
+  Dct4Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
+  Dct8Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
+  Dct16Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
+  Dct32Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
+
+  //-- start dct 64 stages
+  // stage 2.
+  ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false);
+  ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false);
+  ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false);
+  ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false);
+  ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false);
+  ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false);
+  ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false);
+  ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false);
+  ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false);
+  ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false);
+  ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false);
+  ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false);
+  ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false);
+  ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false);
+  ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false);
+  ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false);
+
+  // stage 4.
+  HadamardRotation(&s[32], &s[33], false);
+  HadamardRotation(&s[34], &s[35], true);
+  HadamardRotation(&s[36], &s[37], false);
+  HadamardRotation(&s[38], &s[39], true);
+  HadamardRotation(&s[40], &s[41], false);
+  HadamardRotation(&s[42], &s[43], true);
+  HadamardRotation(&s[44], &s[45], false);
+  HadamardRotation(&s[46], &s[47], true);
+  HadamardRotation(&s[48], &s[49], false);
+  HadamardRotation(&s[50], &s[51], true);
+  HadamardRotation(&s[52], &s[53], false);
+  HadamardRotation(&s[54], &s[55], true);
+  HadamardRotation(&s[56], &s[57], false);
+  HadamardRotation(&s[58], &s[59], true);
+  HadamardRotation(&s[60], &s[61], false);
+  HadamardRotation(&s[62], &s[63], true);
+
+  // stage 7.
+  ButterflyRotation_8(&s[62], &s[33], 60 - 0, true);
+  ButterflyRotation_8(&s[61], &s[34], 60 - 0 + 64, true);
+  ButterflyRotation_8(&s[58], &s[37], 60 - 32, true);
+  ButterflyRotation_8(&s[57], &s[38], 60 - 32 + 64, true);
+  ButterflyRotation_8(&s[54], &s[41], 60 - 16, true);
+  ButterflyRotation_8(&s[53], &s[42], 60 - 16 + 64, true);
+  ButterflyRotation_8(&s[50], &s[45], 60 - 48, true);
+  ButterflyRotation_8(&s[49], &s[46], 60 - 48 + 64, true);
+
+  // stage 11.
+  HadamardRotation(&s[32], &s[35], false);
+  HadamardRotation(&s[33], &s[34], false);
+  HadamardRotation(&s[36], &s[39], true);
+  HadamardRotation(&s[37], &s[38], true);
+  HadamardRotation(&s[40], &s[43], false);
+  HadamardRotation(&s[41], &s[42], false);
+  HadamardRotation(&s[44], &s[47], true);
+  HadamardRotation(&s[45], &s[46], true);
+  HadamardRotation(&s[48], &s[51], false);
+  HadamardRotation(&s[49], &s[50], false);
+  HadamardRotation(&s[52], &s[55], true);
+  HadamardRotation(&s[53], &s[54], true);
+  HadamardRotation(&s[56], &s[59], false);
+  HadamardRotation(&s[57], &s[58], false);
+  HadamardRotation(&s[60], &s[63], true);
+  HadamardRotation(&s[61], &s[62], true);
+
+  // stage 16.
+  ButterflyRotation_8(&s[61], &s[34], 56, true);
+  ButterflyRotation_8(&s[60], &s[35], 56, true);
+  ButterflyRotation_8(&s[59], &s[36], 56 + 64, true);
+  ButterflyRotation_8(&s[58], &s[37], 56 + 64, true);
+  ButterflyRotation_8(&s[53], &s[42], 56 - 32, true);
+  ButterflyRotation_8(&s[52], &s[43], 56 - 32, true);
+  ButterflyRotation_8(&s[51], &s[44], 56 - 32 + 64, true);
+  ButterflyRotation_8(&s[50], &s[45], 56 - 32 + 64, true);
+
+  // stage 21.
+  HadamardRotation(&s[32], &s[39], false);
+  HadamardRotation(&s[33], &s[38], false);
+  HadamardRotation(&s[34], &s[37], false);
+  HadamardRotation(&s[35], &s[36], false);
+  HadamardRotation(&s[40], &s[47], true);
+  HadamardRotation(&s[41], &s[46], true);
+  HadamardRotation(&s[42], &s[45], true);
+  HadamardRotation(&s[43], &s[44], true);
+  HadamardRotation(&s[48], &s[55], false);
+  HadamardRotation(&s[49], &s[54], false);
+  HadamardRotation(&s[50], &s[53], false);
+  HadamardRotation(&s[51], &s[52], false);
+  HadamardRotation(&s[56], &s[63], true);
+  HadamardRotation(&s[57], &s[62], true);
+  HadamardRotation(&s[58], &s[61], true);
+  HadamardRotation(&s[59], &s[60], true);
+
+  // stage 25.
+  ButterflyRotation_8(&s[59], &s[36], 48, true);
+  ButterflyRotation_8(&s[58], &s[37], 48, true);
+  ButterflyRotation_8(&s[57], &s[38], 48, true);
+  ButterflyRotation_8(&s[56], &s[39], 48, true);
+  ButterflyRotation_8(&s[55], &s[40], 112, true);
+  ButterflyRotation_8(&s[54], &s[41], 112, true);
+  ButterflyRotation_8(&s[53], &s[42], 112, true);
+  ButterflyRotation_8(&s[52], &s[43], 112, true);
+
+  // stage 28.
+  HadamardRotation(&s[32], &s[47], false);
+  HadamardRotation(&s[33], &s[46], false);
+  HadamardRotation(&s[34], &s[45], false);
+  HadamardRotation(&s[35], &s[44], false);
+  HadamardRotation(&s[36], &s[43], false);
+  HadamardRotation(&s[37], &s[42], false);
+  HadamardRotation(&s[38], &s[41], false);
+  HadamardRotation(&s[39], &s[40], false);
+  HadamardRotation(&s[48], &s[63], true);
+  HadamardRotation(&s[49], &s[62], true);
+  HadamardRotation(&s[50], &s[61], true);
+  HadamardRotation(&s[51], &s[60], true);
+  HadamardRotation(&s[52], &s[59], true);
+  HadamardRotation(&s[53], &s[58], true);
+  HadamardRotation(&s[54], &s[57], true);
+  HadamardRotation(&s[55], &s[56], true);
+
+  // stage 30.
+  ButterflyRotation_8(&s[55], &s[40], 32, true);
+  ButterflyRotation_8(&s[54], &s[41], 32, true);
+  ButterflyRotation_8(&s[53], &s[42], 32, true);
+  ButterflyRotation_8(&s[52], &s[43], 32, true);
+  ButterflyRotation_8(&s[51], &s[44], 32, true);
+  ButterflyRotation_8(&s[50], &s[45], 32, true);
+  ButterflyRotation_8(&s[49], &s[46], 32, true);
+  ButterflyRotation_8(&s[48], &s[47], 32, true);
+
+  // stage 31.
+  for (int i = 0; i < 32; i += 4) {
+    HadamardRotation(&s[i], &s[63 - i], false);
+    HadamardRotation(&s[i + 1], &s[63 - i - 1], false);
+    HadamardRotation(&s[i + 2], &s[63 - i - 2], false);
+    HadamardRotation(&s[i + 3], &s[63 - i - 3], false);
+  }
+  //-- end dct 64 stages
+
+  if (transpose) {
+    for (int idx = 0; idx < 64; idx += 8) {
+      int16x8_t output[8];
+      Transpose8x8(&s[idx], output);
+      StoreDst<16, 8>(dst, step, idx, output);
+    }
+  } else {
+    StoreDst<16, 64>(dst, step, 0, s);
+  }
+}
+
+//------------------------------------------------------------------------------
+// Asymmetric Discrete Sine Transforms (ADST).
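+
+// Adst4 uses 32-bit intermediates and only operates on the low half of each
+// vector (4 values per row); the narrowed results are duplicated into both
+// halves before the store.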
+template <bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Adst4_NEON(void* dest, const void* source,
+                                      int32_t step, bool transpose) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int32x4_t s[8];
+  int16x8_t x[4];
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t input[8];
+      LoadSrc<8, 8>(src, step, 0, input);
+      Transpose4x8To8x4(input, x);
+    } else {
+      LoadSrc<16, 4>(src, step, 0, x);
+    }
+  } else {
+    LoadSrc<8, 4>(src, step, 0, x);
+    if (transpose) {
+      Transpose4x4(x, x);
+    }
+  }
+
+  const int16x4_t kAdst4Multiplier_0 = vdup_n_s16(kAdst4Multiplier[0]);
+  const int16x4_t kAdst4Multiplier_1 = vdup_n_s16(kAdst4Multiplier[1]);
+  const int16x4_t kAdst4Multiplier_2 = vdup_n_s16(kAdst4Multiplier[2]);
+  const int16x4_t kAdst4Multiplier_3 = vdup_n_s16(kAdst4Multiplier[3]);
+
+  // stage 1.
+  s[5] = vmull_s16(kAdst4Multiplier_1, vget_low_s16(x[3]));
+  s[6] = vmull_s16(kAdst4Multiplier_3, vget_low_s16(x[3]));
+
+  // stage 2.
+  const int32x4_t a7 = vsubl_s16(vget_low_s16(x[0]), vget_low_s16(x[2]));
+  const int32x4_t b7 = vaddw_s16(a7, vget_low_s16(x[3]));
+
+  // stage 3.
+  s[0] = vmull_s16(kAdst4Multiplier_0, vget_low_s16(x[0]));
+  s[1] = vmull_s16(kAdst4Multiplier_1, vget_low_s16(x[0]));
+  // s[0] = s[0] + s[3]
+  s[0] = vmlal_s16(s[0], kAdst4Multiplier_3, vget_low_s16(x[2]));
+  // s[1] = s[1] - s[4]
+  s[1] = vmlsl_s16(s[1], kAdst4Multiplier_0, vget_low_s16(x[2]));
+
+  s[3] = vmull_s16(kAdst4Multiplier_2, vget_low_s16(x[1]));
+  s[2] = vmulq_s32(vmovl_s16(kAdst4Multiplier_2), b7);
+
+  // stage 4.
+  s[0] = vaddq_s32(s[0], s[5]);
+  s[1] = vsubq_s32(s[1], s[6]);
+
+  // stages 5 and 6.
+  const int32x4_t x0 = vaddq_s32(s[0], s[3]);
+  const int32x4_t x1 = vaddq_s32(s[1], s[3]);
+  const int32x4_t x3_a = vaddq_s32(s[0], s[1]);
+  const int32x4_t x3 = vsubq_s32(x3_a, s[3]);
+  const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12);
+  const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12);
+  const int16x4_t dst_2 = vqrshrn_n_s32(s[2], 12);
+  const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12);
+
+  x[0] = vcombine_s16(dst_0, dst_0);
+  x[1] = vcombine_s16(dst_1, dst_1);
+  x[2] = vcombine_s16(dst_2, dst_2);
+  x[3] = vcombine_s16(dst_3, dst_3);
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t output[8];
+      Transpose8x4To4x8(x, output);
+      StoreDst<8, 8>(dst, step, 0, output);
+    } else {
+      StoreDst<16, 4>(dst, step, 0, x);
+    }
+  } else {
+    if (transpose) {
+      Transpose4x4(x, x);
+    }
+    StoreDst<8, 4>(dst, step, 0, x);
+  }
+}
+
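+// Adst8 is built from the same butterfly and Hadamard helpers as the dct; the
+// final stage writes the outputs in permuted order with the odd-indexed
+// values negated (vqnegq_s16).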
+template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Adst8_NEON(void* dest, const void* source,
+                                      int32_t step, bool transpose) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[8], x[8];
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t input[4];
+      LoadSrc<16, 4>(src, step, 0, input);
+      Transpose8x4To4x8(input, x);
+    } else {
+      LoadSrc<8, 8>(src, step, 0, x);
+    }
+  } else {
+    if (transpose) {
+      int16x8_t input[8];
+      LoadSrc<16, 8>(src, step, 0, input);
+      Transpose8x8(input, x);
+    } else {
+      LoadSrc<16, 8>(src, step, 0, x);
+    }
+  }
+
+  // stage 1.
+  s[0] = x[7];
+  s[1] = x[0];
+  s[2] = x[5];
+  s[3] = x[2];
+  s[4] = x[3];
+  s[5] = x[4];
+  s[6] = x[1];
+  s[7] = x[6];
+
+  // stage 2.
+  bufferfly_rotation(&s[0], &s[1], 60 - 0, true);
+  bufferfly_rotation(&s[2], &s[3], 60 - 16, true);
+  bufferfly_rotation(&s[4], &s[5], 60 - 32, true);
+  bufferfly_rotation(&s[6], &s[7], 60 - 48, true);
+
+  // stage 3.
+  HadamardRotation(&s[0], &s[4], false);
+  HadamardRotation(&s[1], &s[5], false);
+  HadamardRotation(&s[2], &s[6], false);
+  HadamardRotation(&s[3], &s[7], false);
+
+  // stage 4.
+  bufferfly_rotation(&s[4], &s[5], 48 - 0, true);
+  bufferfly_rotation(&s[7], &s[6], 48 - 32, true);
+
+  // stage 5.
+  HadamardRotation(&s[0], &s[2], false);
+  HadamardRotation(&s[4], &s[6], false);
+  HadamardRotation(&s[1], &s[3], false);
+  HadamardRotation(&s[5], &s[7], false);
+
+  // stage 6.
+  bufferfly_rotation(&s[2], &s[3], 32, true);
+  bufferfly_rotation(&s[6], &s[7], 32, true);
+
+  // stage 7.
+  x[0] = s[0];
+  x[1] = vqnegq_s16(s[4]);
+  x[2] = s[6];
+  x[3] = vqnegq_s16(s[2]);
+  x[4] = s[3];
+  x[5] = vqnegq_s16(s[7]);
+  x[6] = s[5];
+  x[7] = vqnegq_s16(s[1]);
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t output[4];
+      Transpose4x8To8x4(x, output);
+      StoreDst<16, 4>(dst, step, 0, output);
+    } else {
+      StoreDst<8, 8>(dst, step, 0, x);
+    }
+  } else {
+    if (transpose) {
+      int16x8_t output[8];
+      Transpose8x8(x, output);
+      StoreDst<16, 8>(dst, step, 0, output);
+    } else {
+      StoreDst<16, 8>(dst, step, 0, x);
+    }
+  }
+}
+
+template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
+LIBGAV1_ALWAYS_INLINE void Adst16_NEON(void* dest, const void* source,
+                                       int32_t step, bool transpose) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[16], x[16];
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t input[4];
+      LoadSrc<16, 4>(src, step, 0, input);
+      Transpose8x4To4x8(input, x);
+      LoadSrc<16, 4>(src, step, 8, input);
+      Transpose8x4To4x8(input, &x[8]);
+    } else {
+      LoadSrc<8, 16>(src, step, 0, x);
+    }
+  } else {
+    if (transpose) {
+      for (int idx = 0; idx < 16; idx += 8) {
+        int16x8_t input[8];
+        LoadSrc<16, 8>(src, step, idx, input);
+        Transpose8x8(input, &x[idx]);
+      }
+    } else {
+      LoadSrc<16, 16>(src, step, 0, x);
+    }
+  }
+
+  // stage 1.
+  s[0] = x[15];
+  s[1] = x[0];
+  s[2] = x[13];
+  s[3] = x[2];
+  s[4] = x[11];
+  s[5] = x[4];
+  s[6] = x[9];
+  s[7] = x[6];
+  s[8] = x[7];
+  s[9] = x[8];
+  s[10] = x[5];
+  s[11] = x[10];
+  s[12] = x[3];
+  s[13] = x[12];
+  s[14] = x[1];
+  s[15] = x[14];
+
+  // stage 2.
+  bufferfly_rotation(&s[0], &s[1], 62 - 0, true);
+  bufferfly_rotation(&s[2], &s[3], 62 - 8, true);
+  bufferfly_rotation(&s[4], &s[5], 62 - 16, true);
+  bufferfly_rotation(&s[6], &s[7], 62 - 24, true);
+  bufferfly_rotation(&s[8], &s[9], 62 - 32, true);
+  bufferfly_rotation(&s[10], &s[11], 62 - 40, true);
+  bufferfly_rotation(&s[12], &s[13], 62 - 48, true);
+  bufferfly_rotation(&s[14], &s[15], 62 - 56, true);
+
+  // stage 3.
+  HadamardRotation(&s[0], &s[8], false);
+  HadamardRotation(&s[1], &s[9], false);
+  HadamardRotation(&s[2], &s[10], false);
+  HadamardRotation(&s[3], &s[11], false);
+  HadamardRotation(&s[4], &s[12], false);
+  HadamardRotation(&s[5], &s[13], false);
+  HadamardRotation(&s[6], &s[14], false);
+  HadamardRotation(&s[7], &s[15], false);
+
+  // stage 4.
+  bufferfly_rotation(&s[8], &s[9], 56 - 0, true);
+  bufferfly_rotation(&s[13], &s[12], 8 + 0, true);
+  bufferfly_rotation(&s[10], &s[11], 56 - 32, true);
+  bufferfly_rotation(&s[15], &s[14], 8 + 32, true);
+
+  // stage 5.
+  HadamardRotation(&s[0], &s[4], false);
+  HadamardRotation(&s[8], &s[12], false);
+  HadamardRotation(&s[1], &s[5], false);
+  HadamardRotation(&s[9], &s[13], false);
+  HadamardRotation(&s[2], &s[6], false);
+  HadamardRotation(&s[10], &s[14], false);
+  HadamardRotation(&s[3], &s[7], false);
+  HadamardRotation(&s[11], &s[15], false);
+
+  // stage 6.
+  bufferfly_rotation(&s[4], &s[5], 48 - 0, true);
+  bufferfly_rotation(&s[12], &s[13], 48 - 0, true);
+  bufferfly_rotation(&s[7], &s[6], 48 - 32, true);
+  bufferfly_rotation(&s[15], &s[14], 48 - 32, true);
+
+  // stage 7.
+  HadamardRotation(&s[0], &s[2], false);
+  HadamardRotation(&s[4], &s[6], false);
+  HadamardRotation(&s[8], &s[10], false);
+  HadamardRotation(&s[12], &s[14], false);
+  HadamardRotation(&s[1], &s[3], false);
+  HadamardRotation(&s[5], &s[7], false);
+  HadamardRotation(&s[9], &s[11], false);
+  HadamardRotation(&s[13], &s[15], false);
+
+  // stage 8.
+  bufferfly_rotation(&s[2], &s[3], 32, true);
+  bufferfly_rotation(&s[6], &s[7], 32, true);
+  bufferfly_rotation(&s[10], &s[11], 32, true);
+  bufferfly_rotation(&s[14], &s[15], 32, true);
+
+  // stage 9.
+  x[0] = s[0];
+  x[1] = vqnegq_s16(s[8]);
+  x[2] = s[12];
+  x[3] = vqnegq_s16(s[4]);
+  x[4] = s[6];
+  x[5] = vqnegq_s16(s[14]);
+  x[6] = s[10];
+  x[7] = vqnegq_s16(s[2]);
+  x[8] = s[3];
+  x[9] = vqnegq_s16(s[11]);
+  x[10] = s[15];
+  x[11] = vqnegq_s16(s[7]);
+  x[12] = s[5];
+  x[13] = vqnegq_s16(s[13]);
+  x[14] = s[9];
+  x[15] = vqnegq_s16(s[1]);
+
+  if (stage_is_rectangular) {
+    if (transpose) {
+      int16x8_t output[4];
+      Transpose4x8To8x4(x, output);
+      StoreDst<16, 4>(dst, step, 0, output);
+      Transpose4x8To8x4(&x[8], output);
+      StoreDst<16, 4>(dst, step, 8, output);
+    } else {
+      StoreDst<8, 16>(dst, step, 0, x);
+    }
+  } else {
+    if (transpose) {
+      for (int idx = 0; idx < 16; idx += 8) {
+        int16x8_t output[8];
+        Transpose8x8(&x[idx], output);
+        StoreDst<16, 8>(dst, step, idx, output);
+      }
+    } else {
+      StoreDst<16, 16>(dst, step, 0, x);
+    }
+  }
+}
+
+//------------------------------------------------------------------------------
+// Identity Transforms.
+
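+// Identity4 scales each coefficient by kIdentity4Multiplier / 4096. The
+// is_row_shift path folds the row shift into the multiply:
+// (src * kIdentity4Multiplier + rounding) >> (12 + shift). The other path
+// computes src + src * kIdentity4MultiplierFraction / 4096 via vqrdmulhq,
+// which (assuming the fraction is kIdentity4Multiplier - 4096) gives the same
+// scaling without widening to 32 bits.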
+template <bool is_row_shift>
+LIBGAV1_ALWAYS_INLINE void Identity4_NEON(void* dest, const void* source,
+                                          int32_t step) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  if (is_row_shift) {
+    const int shift = 1;
+    const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+    const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier);
+    const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+    for (int i = 0; i < 4; i += 2) {
+      const int16x8_t v_src = vld1q_s16(&src[i * step]);
+      const int32x4_t v_src_mult_lo =
+          vmlal_s16(v_dual_round, vget_low_s16(v_src), v_multiplier);
+      const int32x4_t v_src_mult_hi =
+          vmlal_s16(v_dual_round, vget_high_s16(v_src), v_multiplier);
+      const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
+      const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift);
+      vst1q_s16(&dst[i * step],
+                vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi)));
+    }
+  } else {
+    for (int i = 0; i < 4; i += 2) {
+      const int16x8_t v_src = vld1q_s16(&src[i * step]);
+      const int16x8_t a =
+          vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3);
+      const int16x8_t b = vqaddq_s16(v_src, a);
+      vst1q_s16(&dst[i * step], b);
+    }
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity4ColumnStoreToFrame(
+    Array2DView<uint8_t> frame, const int start_x, const int start_y,
+    const int tx_width, const int tx_height, const int16_t* source) {
+  const int stride = frame.columns();
+  uint8_t* dst = frame[start_y] + start_x;
+
+  if (tx_width == 4) {
+    const uint8x8_t zero = vdup_n_u8(0);
+    for (int i = 0; i < tx_height; ++i) {
+      const int16x4_t residual = vld1_s16(&source[i * tx_width]);
+      const int16x4_t residual_fraction =
+          vqrdmulh_n_s16(residual, kIdentity4MultiplierFraction << 3);
+      const int16x4_t v_dst_i = vqadd_s16(residual, residual_fraction);
+      const int16x8_t frame_data =
+          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
+      const int16x4_t a = vrshr_n_s16(v_dst_i, 4);
+      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
+      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
+      StoreLo4(dst, d);
+      dst += stride;
+    }
+  } else {
+    for (int i = 0; i < tx_height; ++i) {
+      const int row = i * tx_width;
+      int j = 0;
+      do {
+        const int16x8_t residual = vld1q_s16(&source[row + j]);
+        const int16x8_t residual_fraction =
+            vqrdmulhq_n_s16(residual, kIdentity4MultiplierFraction << 3);
+        const int16x8_t v_dst_i = vqaddq_s16(residual, residual_fraction);
+        const int16x8_t frame_data =
+            vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
+        const int16x8_t a = vrshrq_n_s16(v_dst_i, 4);
+        const int16x8_t b = vqaddq_s16(a, frame_data);
+        const uint8x8_t d = vqmovun_s16(b);
+        vst1_u8(dst + j, d);
+        j += 8;
+      } while (j < tx_width);
+      dst += stride;
+    }
+  }
+}
+
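+// Fuses the identity4 row and column scaling (and, in the tx_width >= 8 path,
+// the kTransformRowMultiplier rounding) with the final store to the frame,
+// presumably so an identity/identity transform does not need a separate row
+// pass.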
+LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
+    Array2DView<uint8_t> frame, const int start_x, const int start_y,
+    const int tx_width, const int tx_height, const int16_t* source) {
+  const int stride = frame.columns();
+  uint8_t* dst = frame[start_y] + start_x;
+
+  if (tx_width == 4) {
+    const uint8x8_t zero = vdup_n_u8(0);
+    for (int i = 0; i < tx_height; ++i) {
+      const int16x4_t v_src = vld1_s16(&source[i * tx_width]);
+      const int16x4_t v_src_mult =
+          vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3);
+      const int16x8_t frame_data =
+          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
+      const int16x4_t v_dst_row = vqadd_s16(v_src, v_src_mult);
+      const int16x4_t v_src_mult2 =
+          vqrdmulh_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3);
+      const int16x4_t v_dst_col = vqadd_s16(v_dst_row, v_src_mult2);
+      const int16x4_t a = vrshr_n_s16(v_dst_col, 4);
+      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
+      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
+      StoreLo4(dst, d);
+      dst += stride;
+    }
+  } else {
+    for (int i = 0; i < tx_height; ++i) {
+      const int row = i * tx_width;
+      int j = 0;
+      do {
+        const int16x8_t v_src = vld1q_s16(&source[row + j]);
+        const int16x8_t v_src_round =
+            vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
+        const int16x8_t v_dst_row = vqaddq_s16(v_src_round, v_src_round);
+        const int16x8_t v_src_mult2 =
+            vqrdmulhq_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3);
+        const int16x8_t frame_data =
+            vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
+        const int16x8_t v_dst_col = vqaddq_s16(v_dst_row, v_src_mult2);
+        const int16x8_t a = vrshrq_n_s16(v_dst_col, 4);
+        const int16x8_t b = vqaddq_s16(a, frame_data);
+        const uint8x8_t d = vqmovun_s16(b);
+        vst1_u8(dst + j, d);
+        j += 8;
+      } while (j < tx_width);
+      dst += stride;
+    }
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity8Row32_NEON(void* dest, const void* source,
+                                               int32_t step) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  // When combining the identity8 multiplier with the row shift, the
+  // calculations for tx_height equal to 32 can be simplified from
+  // (((A * 2) + 2) >> 2) to ((A + 1) >> 1).
+  for (int i = 0; i < 4; ++i) {
+    const int16x8_t v_src = vld1q_s16(&src[i * step]);
+    const int16x8_t a = vrshrq_n_s16(v_src, 1);
+    vst1q_s16(&dst[i * step], a);
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity8Row4_NEON(void* dest, const void* source,
+                                              int32_t step) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  for (int i = 0; i < 4; ++i) {
+    const int16x8_t v_src = vld1q_s16(&src[i * step]);
+    // For bitdepth == 8, the identity row clamps to a signed 16-bit value, so
+    // a saturating add here is ok.
+    const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src);
+    vst1q_s16(&dst[i * step], v_srcx2);
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity8ColumnStoreToFrame_NEON(
+    Array2DView<uint8_t> frame, const int start_x, const int start_y,
+    const int tx_width, const int tx_height, const int16_t* source) {
+  const int stride = frame.columns();
+  uint8_t* dst = frame[start_y] + start_x;
+
+  if (tx_width == 4) {
+    const uint8x8_t zero = vdup_n_u8(0);
+    for (int i = 0; i < tx_height; ++i) {
+      const int16x4_t residual = vld1_s16(&source[i * tx_width]);
+      const int16x4_t v_dst_i = vqadd_s16(residual, residual);
+      const int16x8_t frame_data =
+          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
+      const int16x4_t a = vrshr_n_s16(v_dst_i, 4);
+      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
+      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
+      StoreLo4(dst, d);
+      dst += stride;
+    }
+  } else {
+    for (int i = 0; i < tx_height; ++i) {
+      const int row = i * tx_width;
+      for (int j = 0; j < tx_width; j += 8) {
+        const int16x8_t residual = vld1q_s16(&source[row + j]);
+        const int16x8_t v_dst_i = vqaddq_s16(residual, residual);
+        const int16x8_t frame_data =
+            vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
+        const int16x8_t a = vrshrq_n_s16(v_dst_i, 4);
+        const int16x8_t b = vqaddq_s16(a, frame_data);
+        const uint8x8_t d = vqmovun_s16(b);
+        vst1_u8(dst + j, d);
+      }
+      dst += stride;
+    }
+  }
+}
+
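+// Identity16 row transform with the row shift folded into the multiply: each
+// value becomes (src * kIdentity16Multiplier + rounding) >> (12 + shift),
+// saturated to 16 bits.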
+LIBGAV1_ALWAYS_INLINE void Identity16Row_NEON(void* dest, const void* source,
+                                              int32_t step, int shift) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+  const int16x4_t v_multiplier = vdup_n_s16(kIdentity16Multiplier);
+  const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+
+  for (int i = 0; i < 4; ++i) {
+    for (int j = 0; j < 2; ++j) {
+      const int16x8_t v_src = vld1q_s16(&src[i * step + j * 8]);
+      const int32x4_t v_src_mult_lo =
+          vmlal_s16(v_dual_round, vget_low_s16(v_src), v_multiplier);
+      const int32x4_t v_src_mult_hi =
+          vmlal_s16(v_dual_round, vget_high_s16(v_src), v_multiplier);
+      const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
+      const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift);
+      vst1q_s16(&dst[i * step + j * 8],
+                vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi)));
+    }
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity16ColumnStoreToFrame_NEON(
+    Array2DView<uint8_t> frame, const int start_x, const int start_y,
+    const int tx_width, const int tx_height, const int16_t* source) {
+  const int stride = frame.columns();
+  uint8_t* dst = frame[start_y] + start_x;
+
+  if (tx_width == 4) {
+    const uint8x8_t zero = vdup_n_u8(0);
+    for (int i = 0; i < tx_height; ++i) {
+      const int16x4_t v_src = vld1_s16(&source[i * tx_width]);
+      const int16x4_t v_src_mult =
+          vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 4);
+      const int16x8_t frame_data =
+          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
+
+      const int16x4_t v_srcx2 = vqadd_s16(v_src, v_src);
+      const int16x4_t v_dst_i = vqadd_s16(v_srcx2, v_src_mult);
+
+      const int16x4_t a = vrshr_n_s16(v_dst_i, 4);
+      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
+      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
+      StoreLo4(dst, d);
+      dst += stride;
+    }
+  } else {
+    for (int i = 0; i < tx_height; ++i) {
+      const int row = i * tx_width;
+      for (int j = 0; j < tx_width; j += 8) {
+        const int16x8_t v_src = vld1q_s16(&source[row + j]);
+        const int16x8_t v_src_mult =
+            vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 4);
+        const int16x8_t frame_data =
+            vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
+        const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src);
+        const int16x8_t v_dst_i = vqaddq_s16(v_src_mult, v_srcx2);
+        const int16x8_t a = vrshrq_n_s16(v_dst_i, 4);
+        const int16x8_t b = vqaddq_s16(a, frame_data);
+        const uint8x8_t d = vqmovun_s16(b);
+        vst1_u8(dst + j, d);
+      }
+      dst += stride;
+    }
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity32Row16_NEON(void* dest, const void* source,
+                                                const int32_t step) {
+  auto* const dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  // When combining the identity32 multiplier with the row shift, the
+  // calculation for tx_height equal to 16 can be simplified from
+  // (((A * 4) + 1) >> 1) to (A * 2).
+  for (int i = 0; i < 4; ++i) {
+    for (int j = 0; j < 32; j += 8) {
+      const int16x8_t v_src = vld1q_s16(&src[i * step + j]);
+      // For bitdepth == 8, the identity row clamps to a signed 16-bit value,
+      // so a saturating add here is ok.
+      const int16x8_t v_dst_i = vqaddq_s16(v_src, v_src);
+      vst1q_s16(&dst[i * step + j], v_dst_i);
+    }
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE void Identity32ColumnStoreToFrame(
+    Array2DView<uint8_t> frame, const int start_x, const int start_y,
+    const int tx_width, const int tx_height, const int16_t* source) {
+  const int stride = frame.columns();
+  uint8_t* dst = frame[start_y] + start_x;
+
+  for (int i = 0; i < tx_height; ++i) {
+    const int row = i * tx_width;
+    int j = 0;
+    do {
+      const int16x8_t v_dst_i = vld1q_s16(&source[row + j]);
+      const int16x8_t frame_data =
+          vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
+      const int16x8_t a = vrshrq_n_s16(v_dst_i, 2);
+      const int16x8_t b = vqaddq_s16(a, frame_data);
+      const uint8x8_t d = vqmovun_s16(b);
+      vst1_u8(dst + j, d);
+      j += 8;
+    } while (j < tx_width);
+    dst += stride;
+  }
+}
+
+//------------------------------------------------------------------------------
+// Walsh Hadamard Transform.
+
+// Transposes a 4x4 matrix and then permutes the rows of the transposed matrix
+// for the WHT. The input matrix is in two "wide" int16x8_t variables. The
+// output matrix is in four int16x4_t variables.
+//
+// Input:
+// in[0]: 00 01 02 03  10 11 12 13
+// in[1]: 20 21 22 23  30 31 32 33
+// Output:
+// out[0]: 00 10 20 30
+// out[1]: 03 13 23 33
+// out[2]: 01 11 21 31
+// out[3]: 02 12 22 32
+LIBGAV1_ALWAYS_INLINE void TransposeAndPermute4x4WideInput(
+    const int16x8_t in[2], int16x4_t out[4]) {
+  // Swap 32 bit elements. Goes from:
+  // in[0]: 00 01 02 03  10 11 12 13
+  // in[1]: 20 21 22 23  30 31 32 33
+  // to:
+  // b0.val[0]: 00 01 20 21  10 11 30 31
+  // b0.val[1]: 02 03 22 23  12 13 32 33
+
+  const int32x4x2_t b0 =
+      vtrnq_s32(vreinterpretq_s32_s16(in[0]), vreinterpretq_s32_s16(in[1]));
+
+  // Swap 16 bit elements. Goes from:
+  // vget_low_s32(b0.val[0]):  00 01 20 21
+  // vget_high_s32(b0.val[0]): 10 11 30 31
+  // vget_low_s32(b0.val[1]):  02 03 22 23
+  // vget_high_s32(b0.val[1]): 12 13 32 33
+  // to:
+  // c0.val[0]: 00 10 20 30
+  // c0.val[1]: 01 11 21 31
+  // c1.val[0]: 02 12 22 32
+  // c1.val[1]: 03 13 23 33
+
+  const int16x4x2_t c0 =
+      vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[0])),
+               vreinterpret_s16_s32(vget_high_s32(b0.val[0])));
+  const int16x4x2_t c1 =
+      vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[1])),
+               vreinterpret_s16_s32(vget_high_s32(b0.val[1])));
+
+  out[0] = c0.val[0];
+  out[1] = c1.val[1];
+  out[2] = c0.val[1];
+  out[3] = c1.val[0];
+}
+
+// Process 4 wht4 rows and columns.
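+// The input is pre-scaled by >> 2 and the result is added to the frame with
+// saturating adds; there is no final rounding shift.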
+LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint8_t* dst, const int dst_stride,
+                                     const void* source,
+                                     const int non_zero_coeff_count) {
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x4_t s[4];
+
+  if (non_zero_coeff_count == 1) {
+    // Special case: only src[0] is nonzero.
+    //   src[0]  0   0   0
+    //       0   0   0   0
+    //       0   0   0   0
+    //       0   0   0   0
+    //
+    // After the row and column transforms are applied, we have:
+    //       f   h   h   h
+    //       g   i   i   i
+    //       g   i   i   i
+    //       g   i   i   i
+    // where f, g, h, i are computed as follows.
+    int16_t f = (src[0] >> 2) - (src[0] >> 3);
+    const int16_t g = f >> 1;
+    f = f - (f >> 1);
+    const int16_t h = (src[0] >> 3) - (src[0] >> 4);
+    const int16_t i = (src[0] >> 4);
+    s[0] = vdup_n_s16(h);
+    s[0] = vset_lane_s16(f, s[0], 0);
+    s[1] = vdup_n_s16(i);
+    s[1] = vset_lane_s16(g, s[1], 0);
+    s[2] = s[3] = s[1];
+  } else {
+    // Load the 4x4 source in transposed form.
+    int16x4x4_t columns = vld4_s16(src);
+    // Shift right and permute the columns for the WHT.
+    s[0] = vshr_n_s16(columns.val[0], 2);
+    s[2] = vshr_n_s16(columns.val[1], 2);
+    s[3] = vshr_n_s16(columns.val[2], 2);
+    s[1] = vshr_n_s16(columns.val[3], 2);
+
+    // Row transforms.
+    s[0] = vadd_s16(s[0], s[2]);
+    s[3] = vsub_s16(s[3], s[1]);
+    int16x4_t e = vhsub_s16(s[0], s[3]);  // e = (s[0] - s[3]) >> 1
+    s[1] = vsub_s16(e, s[1]);
+    s[2] = vsub_s16(e, s[2]);
+    s[0] = vsub_s16(s[0], s[1]);
+    s[3] = vadd_s16(s[3], s[2]);
+
+    int16x8_t x[2];
+    x[0] = vcombine_s16(s[0], s[1]);
+    x[1] = vcombine_s16(s[2], s[3]);
+    TransposeAndPermute4x4WideInput(x, s);
+
+    // Column transforms.
+    s[0] = vadd_s16(s[0], s[2]);
+    s[3] = vsub_s16(s[3], s[1]);
+    e = vhsub_s16(s[0], s[3]);  // e = (s[0] - s[3]) >> 1
+    s[1] = vsub_s16(e, s[1]);
+    s[2] = vsub_s16(e, s[2]);
+    s[0] = vsub_s16(s[0], s[1]);
+    s[3] = vadd_s16(s[3], s[2]);
+  }
+
+  // Store to frame.
+  uint8x8_t frame_data = vdup_n_u8(0);
+  for (int row = 0; row < 4; row += 2) {
+    frame_data = LoadLo4(dst, frame_data);
+    frame_data = LoadHi4(dst + dst_stride, frame_data);
+    const int16x8_t a = vreinterpretq_s16_u16(vmovl_u8(frame_data));
+    const int16x8_t residual = vcombine_s16(s[row], s[row + 1]);
+    // Saturate to prevent overflowing int16_t
+    const int16x8_t b = vqaddq_s16(a, residual);
+    frame_data = vqmovun_s16(b);
+    StoreLo4(dst, frame_data);
+    dst += dst_stride;
+    StoreHi4(dst, frame_data);
+    dst += dst_stride;
+  }
+}
+
+//------------------------------------------------------------------------------
+// row/column transform loops
+
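+// Reverses the values within each row of |source| in place. Called before the
+// column transforms when the transform type is in kTransformFlipColumnsMask.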
+template <int tx_height>
+LIBGAV1_ALWAYS_INLINE void FlipColumns(int16_t* source, int tx_width) {
+  if (tx_width >= 16) {
+    int i = 0;
+    do {
+      const int16x8_t a = vld1q_s16(&source[i]);
+      const int16x8_t b = vld1q_s16(&source[i + 8]);
+      const int16x8_t c = vrev64q_s16(a);
+      const int16x8_t d = vrev64q_s16(b);
+      vst1q_s16(&source[i], vcombine_s16(vget_high_s16(d), vget_low_s16(d)));
+      vst1q_s16(&source[i + 8],
+                vcombine_s16(vget_high_s16(c), vget_low_s16(c)));
+      i += 16;
+    } while (i < tx_width * tx_height);
+  } else if (tx_width == 8) {
+    for (int i = 0; i < 8 * tx_height; i += 8) {
+      const int16x8_t a = vld1q_s16(&source[i]);
+      const int16x8_t b = vrev64q_s16(a);
+      vst1q_s16(&source[i], vcombine_s16(vget_high_s16(b), vget_low_s16(b)));
+    }
+  } else {
+    // Process two rows per iteration.
+    for (int i = 0; i < 4 * tx_height; i += 8) {
+      const int16x8_t a = vld1q_s16(&source[i]);
+      vst1q_s16(&source[i], vrev64q_s16(a));
+    }
+  }
+}
+
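+// Scales every coefficient by kTransformRowMultiplier / 4096 before the row
+// transform. vqrdmulhq_n_s16 with (multiplier << 3) computes
+// (x * multiplier * 16 + (1 << 15)) >> 16, which equals
+// (x * multiplier + (1 << 11)) >> 12, so no widening to 32 bits is needed.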
+template <int tx_width>
+LIBGAV1_ALWAYS_INLINE void ApplyRounding(int16_t* source, int num_rows) {
+  if (tx_width == 4) {
+    // Process two rows per iteration.
+    int i = 0;
+    do {
+      const int16x8_t a = vld1q_s16(&source[i]);
+      const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3);
+      vst1q_s16(&source[i], b);
+      i += 8;
+    } while (i < tx_width * num_rows);
+  } else {
+    int i = 0;
+    do {
+      // The last 32 values of every row are always zero if the |tx_width| is
+      // 64.
+      const int non_zero_width = (tx_width < 64) ? tx_width : 32;
+      int j = 0;
+      do {
+        const int16x8_t a = vld1q_s16(&source[i * tx_width + j]);
+        const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3);
+        vst1q_s16(&source[i * tx_width + j], b);
+        j += 8;
+      } while (j < non_zero_width);
+    } while (++i < num_rows);
+  }
+}
+
+template <int tx_width>
+LIBGAV1_ALWAYS_INLINE void RowShift(int16_t* source, int num_rows,
+                                    int row_shift) {
+  // vqrshlq_s16 will shift right if the shift value is negative.
+  row_shift = -row_shift;
+
+  if (tx_width == 4) {
+    // Process two rows per iteration.
+    int i = 0;
+    do {
+      const int16x8_t residual = vld1q_s16(&source[i]);
+      vst1q_s16(&source[i], vqrshlq_s16(residual, vdupq_n_s16(row_shift)));
+      i += 8;
+    } while (i < tx_width * num_rows);
+  } else {
+    int i = 0;
+    do {
+      for (int j = 0; j < tx_width; j += 8) {
+        const int16x8_t residual = vld1q_s16(&source[i * tx_width + j]);
+        const int16x8_t residual_shifted =
+            vqrshlq_s16(residual, vdupq_n_s16(row_shift));
+        vst1q_s16(&source[i * tx_width + j], residual_shifted);
+      }
+    } while (++i < num_rows);
+  }
+}
+
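+// Adds the residual to the frame: each value is rounded with a final >> 4,
+// added to the widened 8-bit frame data with saturating adds, and narrowed
+// back to uint8_t. When enable_flip_rows is set, rows are written in reverse
+// order for transform types in kTransformFlipRowsMask.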
+template <bool enable_flip_rows = false>
+LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
+    Array2DView<uint8_t> frame, const int start_x, const int start_y,
+    const int tx_width, const int tx_height, const int16_t* source,
+    TransformType tx_type) {
+  const bool flip_rows =
+      enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
+  const int stride = frame.columns();
+  uint8_t* dst = frame[start_y] + start_x;
+
+  if (tx_width == 4) {
+    const uint8x8_t zero = vdup_n_u8(0);
+    for (int i = 0; i < tx_height; ++i) {
+      const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4;
+      const int16x4_t residual = vld1_s16(&source[row]);
+      const int16x8_t frame_data =
+          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
+      const int16x4_t a = vrshr_n_s16(residual, 4);
+      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
+      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
+      StoreLo4(dst, d);
+      dst += stride;
+    }
+  } else if (tx_width == 8) {
+    for (int i = 0; i < tx_height; ++i) {
+      const int row = flip_rows ? (tx_height - i - 1) * 8 : i * 8;
+      const int16x8_t residual = vld1q_s16(&source[row]);
+      const int16x8_t frame_data =
+          vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst)));
+      const int16x8_t a = vrshrq_n_s16(residual, 4);
+      const int16x8_t b = vqaddq_s16(a, frame_data);
+      const uint8x8_t d = vqmovun_s16(b);
+      vst1_u8(dst, d);
+      dst += stride;
+    }
+  } else {
+    for (int i = 0; i < tx_height; ++i) {
+      const int y = start_y + i;
+      const int row = flip_rows ? (tx_height - i - 1) * tx_width : i * tx_width;
+      int j = 0;
+      do {
+        const int x = start_x + j;
+        const int16x8_t residual = vld1q_s16(&source[row + j]);
+        const int16x8_t residual_hi = vld1q_s16(&source[row + j + 8]);
+        const uint8x16_t frame_data = vld1q_u8(frame[y] + x);
+        const int16x8_t a = vrshrq_n_s16(residual, 4);
+        const int16x8_t a_hi = vrshrq_n_s16(residual_hi, 4);
+        const int16x8_t d =
+            vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(frame_data)));
+        const int16x8_t d_hi =
+            vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(frame_data)));
+        const int16x8_t e = vqaddq_s16(a, d);
+        const int16x8_t e_hi = vqaddq_s16(a_hi, d_hi);
+        vst1q_u8(frame[y] + x, vcombine_u8(vqmovun_s16(e), vqmovun_s16(e_hi)));
+        j += 16;
+      } while (j < tx_width);
+    }
+  }
+}
+
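+// The transform loop functions below run either the row pass (is_row == true:
+// optional pre-rounding, the 1d transform with transpose=true, then the row
+// shift) or the column pass (optional column flip, the 1d transform with
+// transpose=false, then a store that adds the residual into the frame).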
+void Dct4TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                            void* src_buffer, int start_x, int start_y,
+                            void* dst_frame, bool is_row,
+                            int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
+    const bool should_round = (tx_height == 8);
+    if (should_round) {
+      ApplyRounding<4>(src, num_rows);
+    }
+
+    if (num_rows <= 4) {
+      // Process 4 1d dct4 rows in parallel.
+      Dct4_NEON<ButterflyRotation_4, false>(&src[0], &src[0], /*step=*/4,
+                                            /*transpose=*/true);
+    } else {
+      // Process 8 1d dct4 rows in parallel per iteration.
+      int i = 0;
+      do {
+        Dct4_NEON<ButterflyRotation_8, true>(&src[i * 4], &src[i * 4],
+                                             /*step=*/4, /*transpose=*/true);
+        i += 8;
+      } while (i < num_rows);
+    }
+    if (tx_height == 16) {
+      RowShift<4>(src, num_rows, 1);
+    }
+    return;
+  }
+
+  assert(!is_row);
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<4>(src, tx_width);
+  }
+
+  if (tx_width == 4) {
+    // Process 4 1d dct4 columns in parallel.
+    Dct4_NEON<ButterflyRotation_4, false>(&src[0], &src[0], tx_width,
+                                          /*transpose=*/false);
+  } else {
+    // Process 8 1d dct4 columns in parallel per iteration.
+    int i = 0;
+    do {
+      Dct4_NEON<ButterflyRotation_8, true>(&src[i], &src[i], tx_width,
+                                           /*transpose=*/false);
+      i += 8;
+    } while (i < tx_width);
+  }
+  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 4, src, tx_type);
+}
+
+void Dct8TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                            void* src_buffer, int start_x, int start_y,
+                            void* dst_frame, bool is_row,
+                            int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
+    if (kShouldRound[tx_size]) {
+      ApplyRounding<8>(src, num_rows);
+    }
+
+    if (num_rows <= 4) {
+      // Process 4 1d dct8 rows in parallel.
+      Dct8_NEON<ButterflyRotation_4, true>(&src[0], &src[0], /*step=*/8,
+                                           /*transpose=*/true);
+    } else {
+      // Process 8 1d dct8 rows in parallel per iteration.
+      int i = 0;
+      do {
+        Dct8_NEON<ButterflyRotation_8, false>(&src[i * 8], &src[i * 8],
+                                              /*step=*/8, /*transpose=*/true);
+        i += 8;
+      } while (i < num_rows);
+    }
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+    if (row_shift > 0) {
+      RowShift<8>(src, num_rows, row_shift);
+    }
+    return;
+  }
+
+  assert(!is_row);
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<8>(src, tx_width);
+  }
+
+  if (tx_width == 4) {
+    // Process 4 1d dct8 columns in parallel.
+    Dct8_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                         /*transpose=*/false);
+  } else {
+    // Process 8 1d dct8 columns in parallel per iteration.
+    int i = 0;
+    do {
+      Dct8_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                            /*transpose=*/false);
+      i += 8;
+    } while (i < tx_width);
+  }
+  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 8, src, tx_type);
+}
+
+void Dct16TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                             void* src_buffer, int start_x, int start_y,
+                             void* dst_frame, bool is_row,
+                             int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows =
+        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
+    if (kShouldRound[tx_size]) {
+      ApplyRounding<16>(src, num_rows);
+    }
+
+    if (num_rows <= 4) {
+      // Process 4 1d dct16 rows in parallel.
+      Dct16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 16,
+                                            /*transpose=*/true);
+    } else {
+      int i = 0;
+      do {
+        // Process 8 1d dct16 rows in parallel per iteration.
+        Dct16_NEON<ButterflyRotation_8, false>(&src[i * 16], &src[i * 16], 16,
+                                               /*transpose=*/true);
+        i += 8;
+      } while (i < num_rows);
+    }
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+    // row_shift is always non-zero here.
+    RowShift<16>(src, num_rows, row_shift);
+
+    return;
+  }
+
+  assert(!is_row);
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<16>(src, tx_width);
+  }
+
+  if (tx_width == 4) {
+    // Process 4 1d dct16 columns in parallel.
+    Dct16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                          /*transpose=*/false);
+  } else {
+    int i = 0;
+    do {
+      // Process 8 1d dct16 columns in parallel per iteration.
+      Dct16_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                             /*transpose=*/false);
+      i += 8;
+    } while (i < tx_width);
+  }
+  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 16, src, tx_type);
+}
+
+void Dct32TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                             void* src_buffer, int start_x, int start_y,
+                             void* dst_frame, bool is_row,
+                             int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows =
+        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
+    if (kShouldRound[tx_size]) {
+      ApplyRounding<32>(src, num_rows);
+    }
+    // Process 8 1d dct32 rows in parallel per iteration.
+    int i = 0;
+    do {
+      Dct32_NEON(&src[i * 32], &src[i * 32], 32, /*transpose=*/true);
+      i += 8;
+    } while (i < num_rows);
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+    // row_shift is always non-zero here.
+    RowShift<32>(src, num_rows, row_shift);
+
+    return;
+  }
+
+  assert(!is_row);
+  // Process 8 1d dct32 columns in parallel per iteration.
+  int i = 0;
+  do {
+    Dct32_NEON(&src[i], &src[i], tx_width, /*transpose=*/false);
+    i += 8;
+  } while (i < tx_width);
+  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 32, src, tx_type);
+}
+
+void Dct64TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                             void* src_buffer, int start_x, int start_y,
+                             void* dst_frame, bool is_row,
+                             int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows =
+        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
+    if (kShouldRound[tx_size]) {
+      ApplyRounding<64>(src, num_rows);
+    }
+    // Process 8 1d dct64 rows in parallel per iteration.
+    int i = 0;
+    do {
+      Dct64_NEON(&src[i * 64], &src[i * 64], 64, /*transpose=*/true);
+      i += 8;
+    } while (i < num_rows);
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+    // row_shift is always non-zero here.
+    RowShift<64>(src, num_rows, row_shift);
+
+    return;
+  }
+
+  assert(!is_row);
+  // Process 8 1d dct64 columns in parallel per iteration.
+  int i = 0;
+  do {
+    Dct64_NEON(&src[i], &src[i], tx_width, /*transpose=*/false);
+    i += 8;
+  } while (i < tx_width);
+  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 64, src, tx_type);
+}
+
+void Adst4TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                             void* src_buffer, int start_x, int start_y,
+                             void* dst_frame, bool is_row,
+                             int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
+    const bool should_round = (tx_height == 8);
+    if (should_round) {
+      ApplyRounding<4>(src, num_rows);
+    }
+
+    // Process 4 1d adst4 rows in parallel per iteration.
+    int i = 0;
+    do {
+      Adst4_NEON<false>(&src[i * 4], &src[i * 4], /*step=*/4,
+                        /*transpose=*/true);
+      i += 4;
+    } while (i < num_rows);
+
+    if (tx_height == 16) {
+      RowShift<4>(src, num_rows, 1);
+    }
+    return;
+  }
+
+  assert(!is_row);
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<4>(src, tx_width);
+  }
+
+  // Process 4 1d adst4 columns in parallel per iteration.
+  int i = 0;
+  do {
+    Adst4_NEON<false>(&src[i], &src[i], tx_width, /*transpose=*/false);
+    i += 4;
+  } while (i < tx_width);
+
+  StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
+                                                   tx_width, 4, src, tx_type);
+}
+
+void Adst8TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                             void* src_buffer, int start_x, int start_y,
+                             void* dst_frame, bool is_row,
+                             int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
+    if (kShouldRound[tx_size]) {
+      ApplyRounding<8>(src, num_rows);
+    }
+
+    if (num_rows <= 4) {
+      // Process 4 1d adst8 rows in parallel.
+      Adst8_NEON<ButterflyRotation_4, true>(&src[0], &src[0], /*step=*/8,
+                                            /*transpose=*/true);
+    } else {
+      // Process 8 1d adst8 rows in parallel per iteration.
+      int i = 0;
+      do {
+        Adst8_NEON<ButterflyRotation_8, false>(&src[i * 8], &src[i * 8],
+                                               /*step=*/8,
+                                               /*transpose=*/true);
+        i += 8;
+      } while (i < num_rows);
+    }
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+    if (row_shift > 0) {
+      RowShift<8>(src, num_rows, row_shift);
+    }
+    return;
+  }
+
+  assert(!is_row);
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<8>(src, tx_width);
+  }
+
+  if (tx_width == 4) {
+    // Process 4 1d adst8 columns in parallel.
+    Adst8_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                          /*transpose=*/false);
+  } else {
+    // Process 8 1d adst8 columns in parallel per iteration.
+    int i = 0;
+    do {
+      Adst8_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                             /*transpose=*/false);
+      i += 8;
+    } while (i < tx_width);
+  }
+  StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
+                                                   tx_width, 8, src, tx_type);
+}
+
+void Adst16TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                              void* src_buffer, int start_x, int start_y,
+                              void* dst_frame, bool is_row,
+                              int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows =
+        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
+    if (kShouldRound[tx_size]) {
+      ApplyRounding<16>(src, num_rows);
+    }
+
+    if (num_rows <= 4) {
+      // Process 4 1d adst16 rows in parallel.
+      Adst16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 16,
+                                             /*transpose=*/true);
+    } else {
+      int i = 0;
+      do {
+        // Process 8 1d adst16 rows in parallel per iteration.
+        Adst16_NEON<ButterflyRotation_8, false>(&src[i * 16], &src[i * 16], 16,
+                                                /*transpose=*/true);
+        i += 8;
+      } while (i < num_rows);
+    }
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+    // row_shift is always non-zero here.
+    RowShift<16>(src, num_rows, row_shift);
+
+    return;
+  }
+
+  assert(!is_row);
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<16>(src, tx_width);
+  }
+
+  if (tx_width == 4) {
+    // Process 4 1d adst16 columns in parallel.
+    Adst16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                           /*transpose=*/false);
+  } else {
+    int i = 0;
+    do {
+      // Process 8 1d adst16 columns in parallel per iteration.
+      Adst16_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                              /*transpose=*/false);
+      i += 8;
+    } while (i < tx_width);
+  }
+  StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
+                                                   tx_width, 16, src, tx_type);
+}
+
+void Identity4TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                                 void* src_buffer, int start_x, int start_y,
+                                 void* dst_frame, bool is_row,
+                                 int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    // Special case: Process row calculations during column transform call.
+    // Improves performance.
+    if (tx_type == kTransformTypeIdentityIdentity &&
+        tx_size == kTransformSize4x4) {
+      return;
+    }
+
+    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
+    const bool should_round = (tx_height == 8);
+    if (should_round) {
+      ApplyRounding<4>(src, num_rows);
+    }
+    if (tx_height < 16) {
+      int i = 0;
+      do {
+        Identity4_NEON<false>(&src[i * 4], &src[i * 4], /*step=*/4);
+        i += 4;
+      } while (i < num_rows);
+    } else {
+      int i = 0;
+      do {
+        Identity4_NEON<true>(&src[i * 4], &src[i * 4], /*step=*/4);
+        i += 4;
+      } while (i < num_rows);
+    }
+    return;
+  }
+  assert(!is_row);
+  // Special case: Process row calculations during column transform call.
+  if (tx_type == kTransformTypeIdentityIdentity &&
+      (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) {
+    Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width,
+                                   /*tx_height=*/4, src);
+    return;
+  }
+
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<4>(src, tx_width);
+  }
+
+  Identity4ColumnStoreToFrame(frame, start_x, start_y, tx_width,
+                              /*tx_height=*/4, src);
+}
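
The identity/identity 4x4 special case above skips the row call entirely and
folds both 1-D passes into the frame store; Identity4RowColumnStoreToFrame
fills that role. A rough scalar sketch of the idea, using a placeholder
multiplier and shift rather than the library's real constants:

#include <algorithm>
#include <cstdint>

// Sketch only: for identity/identity both 1-D passes reduce to a single scale
// per coefficient, so they can be applied while adding the residual to the
// frame.
inline void FusedIdentity4x4Store(const int16_t* residual, int multiplier,
                                  int shift, uint8_t* dst, int stride) {
  for (int y = 0; y < 4; ++y) {
    for (int x = 0; x < 4; ++x) {
      const int value = (residual[y * 4 + x] * multiplier) >> shift;
      const int pixel = dst[y * stride + x] + value;
      dst[y * stride + x] =
          static_cast<uint8_t>(std::min(255, std::max(0, pixel)));
    }
  }
}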
+
+void Identity8TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                                 void* src_buffer, int start_x, int start_y,
+                                 void* dst_frame, bool is_row,
+                                 int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    // Special case: Process row calculations during column transform call.
+    // Improves performance.
+    if (tx_type == kTransformTypeIdentityIdentity &&
+        tx_size == kTransformSize8x4) {
+      return;
+    }
+    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
+    if (kShouldRound[tx_size]) {
+      ApplyRounding<8>(src, num_rows);
+    }
+
+    // When combining the identity8 multiplier with the row shift, the
+    // calculations for tx_height == 8 and tx_height == 16 can be simplified
+    // from (((A * 2) + 1) >> 1) to A.
+    if ((tx_height & 0x18) != 0) {
+      return;
+    }
+    if (tx_height == 32) {
+      for (int i = 0; i < num_rows; i += 4) {
+        Identity8Row32_NEON(&src[i * 8], &src[i * 8], /*step=*/8);
+      }
+      return;
+    }
+
+    // Process kTransformSize8x4
+    assert(tx_size == kTransformSize8x4);
+    for (int i = 0; i < num_rows; i += 4) {
+      Identity8Row4_NEON(&src[i * 8], &src[i * 8], /*step=*/8);
+    }
+    return;
+  }
+
+  assert(!is_row);
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<8>(src, tx_width);
+  }
+
+  Identity8ColumnStoreToFrame_NEON(frame, start_x, start_y, tx_width,
+                                   /*tx_height=*/8, src);
+}
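
The early return above follows from the arithmetic in the comment: with an
identity8 multiplier of 2 and a row shift of 1, rounding and shifting gives
back the input, so the row pass is a no-op for tx_height 8 and 16. A minimal
check, assuming the usual arithmetic right shift for negative values:

#include <cstdint>

inline int16_t Identity8RowShift1(int16_t a) {
  // (((A * 2) + 1) >> 1) == A for every int16_t A, so nothing needs doing.
  return static_cast<int16_t>(((a * 2) + 1) >> 1);
}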
+
+void Identity16TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                                  void* src_buffer, int start_x, int start_y,
+                                  void* dst_frame, bool is_row,
+                                  int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows =
+        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
+    if (kShouldRound[tx_size]) {
+      ApplyRounding<16>(src, num_rows);
+    }
+    for (int i = 0; i < num_rows; i += 4) {
+      Identity16Row_NEON(&src[i * 16], &src[i * 16], /*step=*/16,
+                         kTransformRowShift[tx_size]);
+    }
+    return;
+  }
+
+  assert(!is_row);
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
+    FlipColumns<16>(src, tx_width);
+  }
+  Identity16ColumnStoreToFrame_NEON(frame, start_x, start_y, tx_width,
+                                    /*tx_height=*/16, src);
+}
+
+void Identity32TransformLoop_NEON(TransformType /*tx_type*/,
+                                  TransformSize tx_size, void* src_buffer,
+                                  int start_x, int start_y, void* dst_frame,
+                                  bool is_row, int non_zero_coeff_count) {
+  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto* src = static_cast<int16_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  const int tx_height = kTransformHeight[tx_size];
+
+  if (is_row) {
+    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
+
+    // When combining the identity32 multiplier with the row shift, the
+    // calculations for tx_height == 8 and tx_height == 32 can be simplified
+    // from (((A * 4) + 2) >> 2) to A.
+    if ((tx_height & 0x28) != 0) {
+      return;
+    }
+
+    // Process kTransformSize32x16
+    assert(tx_size == kTransformSize32x16);
+    ApplyRounding<32>(src, num_rows);
+    for (int i = 0; i < num_rows; i += 4) {
+      Identity32Row16_NEON(&src[i * 32], &src[i * 32], /*step=*/32);
+    }
+    return;
+  }
+
+  assert(!is_row);
+  Identity32ColumnStoreToFrame(frame, start_x, start_y, tx_width,
+                               /*tx_height=*/32, src);
+}
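
The identity32 variant relies on the same cancellation, this time with a
multiplier of 4 and a row shift of 2, which leaves only kTransformSize32x16
doing real row work:

inline int Identity32RowShift2(int a) {
  // (((A * 4) + 2) >> 2) == A, so rows of height 8 and 32 need no processing.
  return ((a * 4) + 2) >> 2;
}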
+
+void Wht4TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                            void* src_buffer, int start_x, int start_y,
+                            void* dst_frame, bool is_row,
+                            int non_zero_coeff_count) {
+  assert(tx_type == kTransformTypeDctDct);
+  assert(tx_size == kTransformSize4x4);
+  static_cast<void>(tx_type);
+  static_cast<void>(tx_size);
+  if (is_row) {
+    // Do both row and column transforms in the column-transform pass.
+    return;
+  }
+
+  assert(!is_row);
+  // Process 4 1d wht4 rows and columns in parallel.
+  auto* src = static_cast<int16_t*>(src_buffer);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+  uint8_t* dst = frame[start_y] + start_x;
+  const int dst_stride = frame.columns();
+  Wht4_NEON(dst, dst_stride, src, non_zero_coeff_count);
+}
+
+//------------------------------------------------------------------------------
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  // Maximum transform size for Dct is 64.
+  dsp->inverse_transforms[k1DTransformSize4][k1DTransformDct] =
+      Dct4TransformLoop_NEON;
+  dsp->inverse_transforms[k1DTransformSize8][k1DTransformDct] =
+      Dct8TransformLoop_NEON;
+  dsp->inverse_transforms[k1DTransformSize16][k1DTransformDct] =
+      Dct16TransformLoop_NEON;
+  dsp->inverse_transforms[k1DTransformSize32][k1DTransformDct] =
+      Dct32TransformLoop_NEON;
+  dsp->inverse_transforms[k1DTransformSize64][k1DTransformDct] =
+      Dct64TransformLoop_NEON;
+
+  // Maximum transform size for Adst is 16.
+  dsp->inverse_transforms[k1DTransformSize4][k1DTransformAdst] =
+      Adst4TransformLoop_NEON;
+  dsp->inverse_transforms[k1DTransformSize8][k1DTransformAdst] =
+      Adst8TransformLoop_NEON;
+  dsp->inverse_transforms[k1DTransformSize16][k1DTransformAdst] =
+      Adst16TransformLoop_NEON;
+
+  // Maximum transform size for Identity transform is 32.
+  dsp->inverse_transforms[k1DTransformSize4][k1DTransformIdentity] =
+      Identity4TransformLoop_NEON;
+  dsp->inverse_transforms[k1DTransformSize8][k1DTransformIdentity] =
+      Identity8TransformLoop_NEON;
+  dsp->inverse_transforms[k1DTransformSize16][k1DTransformIdentity] =
+      Identity16TransformLoop_NEON;
+  dsp->inverse_transforms[k1DTransformSize32][k1DTransformIdentity] =
+      Identity32TransformLoop_NEON;
+
+  // Maximum transform size for Wht is 4.
+  dsp->inverse_transforms[k1DTransformSize4][k1DTransformWht] =
+      Wht4TransformLoop_NEON;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void InverseTransformInit_NEON() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+#else   // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void InverseTransformInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/inverse_transform_neon.h b/libgav1/src/dsp/arm/inverse_transform_neon.h
new file mode 100644
index 0000000..eefaada
--- /dev/null
+++ b/libgav1/src/dsp/arm/inverse_transform_neon.h
@@ -0,0 +1,36 @@
+#ifndef LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::inverse_transforms, see the defines below for specifics.
+// This function is not thread-safe.
+void InverseTransformInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_DSP_NEON
+
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_DSP_NEON
+
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_DSP_NEON
+
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_DSP_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_
diff --git a/libgav1/src/dsp/arm/loop_filter_neon.cc b/libgav1/src/dsp/arm/loop_filter_neon.cc
index 8994aaf..4a1b6a7 100644
--- a/libgav1/src/dsp/arm/loop_filter_neon.cc
+++ b/libgav1/src/dsp/arm/loop_filter_neon.cc
@@ -1,4 +1,5 @@
-#include "src/dsp/arm/loop_filter_neon.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/loop_filter.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -15,70 +16,6 @@
 namespace low_bitdepth {
 namespace {
 
-// vzipN is exclusive to A64.
-inline uint8x8_t InterleaveLow32(const uint8x8_t a, const uint8x8_t b) {
-#if defined(__aarch64__)
-  return vreinterpret_u8_u32(
-      vzip1_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)));
-#else
-  // Discard |.val[1]|
-  return vreinterpret_u8_u32(
-      vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[0]);
-#endif
-}
-
-inline int8x8_t InterleaveLow32(const int8x8_t a, const int8x8_t b) {
-#if defined(__aarch64__)
-  return vreinterpret_s8_u32(
-      vzip1_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)));
-#else
-  // Discard |.val[1]|
-  return vreinterpret_s8_u32(
-      vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[0]);
-#endif
-}
-
-inline uint8x8_t InterleaveHigh32(const uint8x8_t a, const uint8x8_t b) {
-#if defined(__aarch64__)
-  return vreinterpret_u8_u32(
-      vzip2_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)));
-#else
-  return vreinterpret_u8_u32(
-      vzip_u32(vreinterpret_u32_u8(a), vreinterpret_u32_u8(b)).val[1]);
-#endif
-}
-
-inline int8x8_t InterleaveHigh32(const int8x8_t a, const int8x8_t b) {
-#if defined(__aarch64__)
-  return vreinterpret_s8_u32(
-      vzip2_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)));
-#else
-  return vreinterpret_s8_u32(
-      vzip_u32(vreinterpret_u32_s8(a), vreinterpret_u32_s8(b)).val[1]);
-#endif
-}
-
-// Transpose 32 bit elements such that:
-// a: 00 01
-// b: 02 03
-// returns
-// val[0]: 00 02
-// val[1]: 01 03
-inline uint8x8x2_t Interleave32(const uint8x8_t a, const uint8x8_t b) {
-  const uint32x2_t a_32 = vreinterpret_u32_u8(a);
-  const uint32x2_t b_32 = vreinterpret_u32_u8(b);
-  const uint32x2x2_t c = vtrn_u32(a_32, b_32);
-  const uint8x8x2_t d = {vreinterpret_u8_u32(c.val[0]),
-                         vreinterpret_u8_u32(c.val[1])};
-  return d;
-}
-
-// Swap high and low 32 bit elements.
-inline uint8x8_t Transpose32(const uint8x8_t a) {
-  const uint32x2_t b = vrev64_u32(vreinterpret_u32_u8(a));
-  return vreinterpret_u8_u32(b);
-}
-
 // (abs(p1 - p0) > thresh) || (abs(q1 - q0) > thresh)
 inline uint8x8_t Hev(const uint8x8_t abd_p0p1_q0q1, const uint8_t thresh) {
   const uint8x8_t a = vcgt_u8(abd_p0p1_q0q1, vdup_n_u8(thresh));
diff --git a/libgav1/src/dsp/arm/loop_filter_neon.h b/libgav1/src/dsp/arm/loop_filter_neon.h
index 50ef294..ca80746 100644
--- a/libgav1/src/dsp/arm/loop_filter_neon.h
+++ b/libgav1/src/dsp/arm/loop_filter_neon.h
@@ -3,13 +3,12 @@
 
 #include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
-#include "src/dsp/loop_filter.h"
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::loop_filters with neon implementations. This function
-// is not thread-safe.
+// Initializes Dsp::loop_filters, see the defines below for specifics. This
+// function is not thread-safe.
 void LoopFilterInit_NEON();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/arm/loop_restoration_neon.cc b/libgav1/src/dsp/arm/loop_restoration_neon.cc
index 367826b..fd0eda5 100644
--- a/libgav1/src/dsp/arm/loop_restoration_neon.cc
+++ b/libgav1/src/dsp/arm/loop_restoration_neon.cc
@@ -1,4 +1,5 @@
-#include "src/dsp/arm/loop_restoration_neon.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/loop_restoration.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -46,14 +47,13 @@
   sum = vaddq_s16(
       sum, vmulq_n_s16(vreinterpretq_s16_u16(vmovl_u8(a[6])), filter[6]));
 
-  // |inter_round_bits[0]| == 3 for 8 bit inputs.
-  sum = vrshrq_n_s16(sum, 3);
+  sum = vrshrq_n_s16(sum, kInterRoundBitsHorizontal);
 
   // Delaying |horizontal_rounding| until after downshifting allows the sum to
   // stay in 16 bits.
   // |horizontal_rounding| = 1 << (bitdepth + kWienerFilterBits - 1)
   //                         1 << (       8 +                 7 - 1)
-  // Plus |inter_round_bits[0]| and it works out to 1 << 11.
+  // Plus |kInterRoundBitsHorizontal| and it works out to 1 << 11.
   sum = vaddq_s16(sum, vdupq_n_s16(1 << 11));
 
   // Just like |horizontal_rounding|, adding |filter[3]| at this point allows
@@ -64,13 +64,124 @@
   sum = vaddq_s16(sum, vreinterpretq_s16_u16(vshll_n_u8(a[3], 4)));
 
   // Saturate to
-  // [0, (1 << (bitdepth + 1 + kWienerFilterBits - inter_round_bits[0])) - 1)]
+  // [0, (1 << (bitdepth + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) -
+  //  1]
   //     (1 << (       8 + 1 +                 7 -                   3)) - 1]
   sum = vminq_s16(sum, vdupq_n_s16((1 << 13) - 1));
   sum = vmaxq_s16(sum, vdupq_n_s16(0));
   return sum;
 }
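
The rounding constants in HorizontalSum fall out of the comments above; a
small compile-time check for bitdepth 8, with *Example names standing in for
the library's constants:

constexpr int kBitdepthExample = 8;
constexpr int kWienerFilterBitsExample = 7;
constexpr int kInterRoundBitsHorizontalExample = 3;
// |horizontal_rounding| applied after the downshift by 3:
constexpr int kDelayedHorizontalRounding =
    1 << (kBitdepthExample + kWienerFilterBitsExample - 1 -
          kInterRoundBitsHorizontalExample);
static_assert(kDelayedHorizontalRounding == 1 << 11, "");
// Saturation bound for the horizontal sums:
constexpr int kHorizontalMax =
    (1 << (kBitdepthExample + 1 + kWienerFilterBitsExample -
           kInterRoundBitsHorizontalExample)) - 1;
static_assert(kHorizontalMax == (1 << 13) - 1, "");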
 
+template <int min_width>
+inline void VerticalSum(const int16_t* src_base, const int src_stride,
+                        uint8_t* dst_base, const int dst_stride,
+                        const int16x4_t filter[7], const int width,
+                        const int height) {
+  static_assert(min_width == 4 || min_width == 8, "");
+  // -(1 << (bitdepth + kInterRoundBitsVertical - 1))
+  // -(1 << (       8 +                      11 - 1))
+  constexpr int vertical_rounding = -(1 << 18);
+  if (min_width == 8) {
+    int x = 0;
+    do {
+      const int16_t* src = src_base + x;
+      uint8_t* dst = dst_base + x;
+      int16x8_t a[7];
+      a[0] = vld1q_s16(src);
+      src += src_stride;
+      a[1] = vld1q_s16(src);
+      src += src_stride;
+      a[2] = vld1q_s16(src);
+      src += src_stride;
+      a[3] = vld1q_s16(src);
+      src += src_stride;
+      a[4] = vld1q_s16(src);
+      src += src_stride;
+      a[5] = vld1q_s16(src);
+      src += src_stride;
+
+      int y = 0;
+      do {
+        a[6] = vld1q_s16(src);
+        src += src_stride;
+
+        int32x4_t sum_lo = vdupq_n_s32(vertical_rounding);
+        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[0]), filter[0]);
+        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[1]), filter[1]);
+        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[2]), filter[2]);
+        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[3]), filter[3]);
+        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[4]), filter[4]);
+        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[5]), filter[5]);
+        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[6]), filter[6]);
+        uint16x4_t sum_lo_16 = vqrshrun_n_s32(sum_lo, 11);
+
+        int32x4_t sum_hi = vdupq_n_s32(vertical_rounding);
+        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[0]), filter[0]);
+        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[1]), filter[1]);
+        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[2]), filter[2]);
+        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[3]), filter[3]);
+        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[4]), filter[4]);
+        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[5]), filter[5]);
+        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[6]), filter[6]);
+        uint16x4_t sum_hi_16 = vqrshrun_n_s32(sum_hi, 11);
+
+        vst1_u8(dst, vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16)));
+        dst += dst_stride;
+
+        a[0] = a[1];
+        a[1] = a[2];
+        a[2] = a[3];
+        a[3] = a[4];
+        a[4] = a[5];
+        a[5] = a[6];
+      } while (++y < height);
+      x += 8;
+    } while (x < width);
+  } else if (min_width == 4) {
+    const int16_t* src = src_base;
+    uint8_t* dst = dst_base;
+    int16x4_t a[7];
+    a[0] = vld1_s16(src);
+    src += src_stride;
+    a[1] = vld1_s16(src);
+    src += src_stride;
+    a[2] = vld1_s16(src);
+    src += src_stride;
+    a[3] = vld1_s16(src);
+    src += src_stride;
+    a[4] = vld1_s16(src);
+    src += src_stride;
+    a[5] = vld1_s16(src);
+    src += src_stride;
+
+    int y = 0;
+    do {
+      a[6] = vld1_s16(src);
+      src += src_stride;
+
+      int32x4_t sum = vdupq_n_s32(vertical_rounding);
+      sum = vmlal_s16(sum, a[0], filter[0]);
+      sum = vmlal_s16(sum, a[1], filter[1]);
+      sum = vmlal_s16(sum, a[2], filter[2]);
+      sum = vmlal_s16(sum, a[3], filter[3]);
+      sum = vmlal_s16(sum, a[4], filter[4]);
+      sum = vmlal_s16(sum, a[5], filter[5]);
+      sum = vmlal_s16(sum, a[6], filter[6]);
+      uint16x4_t sum_16 = vqrshrun_n_s32(sum, 11);
+
+      StoreLo4(dst, vqmovn_u16(vcombine_u16(sum_16, sum_16)));
+      dst += dst_stride;
+
+      a[0] = a[1];
+      a[1] = a[2];
+      a[2] = a[3];
+      a[3] = a[4];
+      a[4] = a[5];
+      a[5] = a[6];
+    } while (++y < height);
+  }
+}
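
The |vertical_rounding| constant in VerticalSum follows the same bookkeeping,
with kInterRoundBitsVertical equal to 11 for 8-bit input (per the comment
above):

constexpr int kBitdepthExample = 8;
constexpr int kInterRoundBitsVerticalExample = 11;
constexpr int kVerticalRoundingExample =
    -(1 << (kBitdepthExample + kInterRoundBitsVerticalExample - 1));
static_assert(kVerticalRoundingExample == -(1 << 18), "");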
+
 void WienerFilter_NEON(const void* const source, void* const dest,
                        const RestorationUnitInfo& restoration_info,
                        const ptrdiff_t source_stride,
@@ -91,9 +202,14 @@
   // left value.
   const int center_tap = 3;
   src -= center_tap * source_stride + center_tap;
-  // This writes out 2 more rows than we need.
-  for (int y = 0; y < height + kSubPixelTaps - 2; y += 8) {
-    for (int x = 0; x < width; x += 8) {
+  // This writes out 2 more rows than we need. It always writes out at least
+  // width 8 for the intermediate buffer.
+  // TODO(johannkoenig): Investigate a 4x specific first pass. May be possible
+  // to do it in half the passes.
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       const uint8_t* src_v = src + x;
       const uint8x16_t a0 = vld1q_u8(src_v);
       src_v += source_stride;
@@ -116,6 +232,7 @@
       // This could load and transpose one 8x8 block to prime the loop, then
       // load and transpose a second block in the loop. The second block could
       // be passed to subsequent iterations.
+      // TODO(johannkoenig): convert these to arrays.
       Transpose16x8(a0, a1, a2, a3, a4, a5, a6, a7, b, b + 1, b + 2, b + 3,
                     b + 4, b + 5, b + 6, b + 7, b + 8, b + 9, b + 10, b + 11,
                     b + 12, b + 13, b + 14, b + 15);
@@ -129,6 +246,7 @@
       int16x8_t sum_6 = HorizontalSum(b + 6, filter);
       int16x8_t sum_7 = HorizontalSum(b + 7, filter);
 
+      // TODO(johannkoenig): convert this to an array.
       Transpose8x8(&sum_0, &sum_1, &sum_2, &sum_3, &sum_4, &sum_5, &sum_6,
                    &sum_7);
 
@@ -148,10 +266,12 @@
       vst1q_s16(wiener_buffer_v, sum_6);
       wiener_buffer_v += buffer_stride;
       vst1q_s16(wiener_buffer_v, sum_7);
-    }
+      x += 8;
+    } while (x < width);
     src += 8 * source_stride;
     wiener_buffer += 8 * buffer_stride;
-  }
+    y += 8;
+  } while (y < height + kSubPixelTaps - 2);
 
   // Vertical filtering.
   wiener_buffer = reinterpret_cast<int16_t*>(buffer->wiener_buffer);
@@ -162,61 +282,13 @@
       vdup_n_s16(filter[0]),       vdup_n_s16(filter[1]), vdup_n_s16(filter[2]),
       vdup_n_s16(filter[3] + 128), vdup_n_s16(filter[4]), vdup_n_s16(filter[5]),
       vdup_n_s16(filter[6])};
-  // |inter_round_bits[1]| == 11 for 8 bit inputs.
-  // -(1 << (bitdepth + inter_round_bits[1] - 1))
-  // -(1 << (       8 +                  11 - 1))
-  const int vertical_rounding = -(1 << 18);
-  for (int x = 0; x < width; x += 8) {
-    int16_t* wiener_v = wiener_buffer + x;
-    uint8_t* dst_v = dst + x;
-    int16x8_t a[7];
-    a[0] = vld1q_s16(wiener_v);
-    wiener_v += buffer_stride;
-    a[1] = vld1q_s16(wiener_v);
-    wiener_v += buffer_stride;
-    a[2] = vld1q_s16(wiener_v);
-    wiener_v += buffer_stride;
-    a[3] = vld1q_s16(wiener_v);
-    wiener_v += buffer_stride;
-    a[4] = vld1q_s16(wiener_v);
-    wiener_v += buffer_stride;
-    a[5] = vld1q_s16(wiener_v);
-    wiener_v += buffer_stride;
 
-    for (int y = 0; y < height; ++y) {
-      a[6] = vld1q_s16(wiener_v);
-      wiener_v += buffer_stride;
-
-      int32x4_t sum_lo = vdupq_n_s32(vertical_rounding);
-      sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[0]), filter_v[0]);
-      sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[1]), filter_v[1]);
-      sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[2]), filter_v[2]);
-      sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[3]), filter_v[3]);
-      sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[4]), filter_v[4]);
-      sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[5]), filter_v[5]);
-      sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[6]), filter_v[6]);
-      uint16x4_t sum_lo_16 = vqrshrun_n_s32(sum_lo, 11);
-
-      int32x4_t sum_hi = vdupq_n_s32(vertical_rounding);
-      sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[0]), filter_v[0]);
-      sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[1]), filter_v[1]);
-      sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[2]), filter_v[2]);
-      sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[3]), filter_v[3]);
-      sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[4]), filter_v[4]);
-      sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[5]), filter_v[5]);
-      sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[6]), filter_v[6]);
-      uint16x4_t sum_hi_16 = vqrshrun_n_s32(sum_hi, 11);
-
-      vst1_u8(dst_v, vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16)));
-      dst_v += dest_stride;
-
-      a[0] = a[1];
-      a[1] = a[2];
-      a[2] = a[3];
-      a[3] = a[4];
-      a[4] = a[5];
-      a[5] = a[6];
-    }
+  if (width == 4) {
+    VerticalSum<4>(wiener_buffer, buffer_stride, dst, dest_stride, filter_v,
+                   width, height);
+  } else {
+    VerticalSum<8>(wiener_buffer, buffer_stride, dst, dest_stride, filter_v,
+                   width, height);
   }
 }
 
@@ -463,7 +535,8 @@
     row_sq[1] = vmull_u8(row[1], row[1]);
     row_sq[2] = vmull_u8(row[2], row[2]);
 
-    for (int y = -1; y < height + 1; y += 2) {
+    int y = -1;
+    do {
       row[3] = vld1_u8(column);
       column += stride;
       row[4] = vld1_u8(column);
@@ -488,10 +561,12 @@
       row_sq[0] = row_sq[2];
       row_sq[1] = row_sq[3];
       row_sq[2] = row_sq[4];
-    }
+      y += 2;
+    } while (y < height + 1);
   }
 
-  for (int x = 0; x < width; x += 4) {
+  int x = 0;
+  do {
     // |src_pre_process| is X but we already processed the first column of 4
     // values so we want to start at Y and increment from there.
     // X s s s Y s s
@@ -564,7 +639,8 @@
     // Calculate one output line. Add in the line from the previous pass and
     // output one even row. Sum the new line and output the odd row. Carry the
     // new row into the next pass.
-    for (int y = 0; y < height; y += 2) {
+    int y = 0;
+    do {
       row[3] = vld1_u8(column);
       column += stride;
       row[4] = vld1_u8(column);
@@ -616,8 +692,10 @@
 
       sum565_a0 = sum565_a1;
       sum565_b0 = sum565_b1;
-    }
-  }
+      y += 2;
+    } while (y < height);
+    x += 4;
+  } while (x < width);
 }
 
 inline void BoxFilterPreProcess_SecondPass(const uint8_t* const src,
@@ -637,7 +715,8 @@
   // get 68 values. This doesn't appear to be causing problems yet but it
   // might.
   const uint8_t* const src_top_left_corner = src - 1 - 2 * stride;
-  for (int x = -1; x < width + 1; x += 4) {
+  int x = -1;
+  do {
     const uint8_t* column = src_top_left_corner + x;
     uint16_t* a2_column = a2 + (x + 1);
     uint8x8_t row[3];
@@ -650,7 +729,8 @@
     row_sq[0] = vmull_u8(row[0], row[0]);
     row_sq[1] = vmull_u8(row[1], row[1]);
 
-    for (int y = -1; y < height + 1; ++y) {
+    int y = -1;
+    do {
       row[2] = vld1_u8(column);
       column += stride;
 
@@ -672,8 +752,9 @@
 
       row_sq[0] = row_sq[1];
       row_sq[1] = row_sq[2];
-    }
-  }
+    } while (++y < height + 1);
+    x += 4;
+  } while (x < width + 1);
 }
 
 inline uint16x4_t Sum444(const uint16x8_t a) {
@@ -707,7 +788,8 @@
 
   BoxFilterPreProcess_SecondPass(src, stride, width, height, s, a2);
 
-  for (int x = 0; x < width; x += 4) {
+  int x = 0;
+  do {
     uint16_t* a2_ptr = a2 + x;
     const uint8_t* src_ptr = src + x;
     // |filtered_output| must match how |a2| values are read since they are
@@ -733,7 +815,8 @@
     sum343_b[1] = Sum343W(b_1);
     sum444_b = Sum444W(b_1);
 
-    for (int y = 0; y < height; ++y) {
+    int y = 0;
+    do {
       const uint16x8_t a_2 = vld1q_u16(a2_ptr);
       a2_ptr += kIntermediateStride;
 
@@ -762,8 +845,137 @@
 
       src_ptr += stride;
       filtered_output += kIntermediateStride;
+    } while (++y < height);
+    x += 4;
+  } while (x < width);
+}
+
+template <int min_width>
+inline void SelfGuidedSingleMultiplier(const uint8_t* src,
+                                       const ptrdiff_t src_stride,
+                                       uint16_t* box_filter_process_output,
+                                       uint8_t* dst, const ptrdiff_t dst_stride,
+                                       const int width, const int height,
+                                       const int16_t w_combo,
+                                       const int16x4_t w_single) {
+  static_assert(min_width == 4 || min_width == 8, "");
+
+  int y = 0;
+  do {
+    if (min_width == 8) {
+      int x = 0;
+      do {
+        const int16x8_t u = vreinterpretq_s16_u16(
+            vshll_n_u8(vld1_u8(src + x), kSgrProjRestoreBits));
+        const int16x8_t p =
+            vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output + x));
+
+        // u * w1 + u * wN == u * (w1 + wN)
+        int32x4_t v_lo = vmull_n_s16(vget_low_s16(u), w_combo);
+        v_lo = vmlal_s16(v_lo, vget_low_s16(p), w_single);
+
+        int32x4_t v_hi = vmull_n_s16(vget_high_s16(u), w_combo);
+        v_hi = vmlal_s16(v_hi, vget_high_s16(p), w_single);
+
+        const int16x4_t s_lo =
+            vrshrn_n_s32(v_lo, kSgrProjRestoreBits + kSgrProjPrecisionBits);
+        const int16x4_t s_hi =
+            vrshrn_n_s32(v_hi, kSgrProjRestoreBits + kSgrProjPrecisionBits);
+        vst1_u8(dst + x, vqmovun_s16(vcombine_s16(s_lo, s_hi)));
+        x += 8;
+      } while (x < width);
+    } else if (min_width == 4) {
+      const int16x8_t u =
+          vreinterpretq_s16_u16(vshll_n_u8(vld1_u8(src), kSgrProjRestoreBits));
+      const int16x8_t p =
+          vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output));
+
+      // u * w1 + u * wN == u * (w1 + wN)
+      int32x4_t v_lo = vmull_n_s16(vget_low_s16(u), w_combo);
+      v_lo = vmlal_s16(v_lo, vget_low_s16(p), w_single);
+
+      int32x4_t v_hi = vmull_n_s16(vget_high_s16(u), w_combo);
+      v_hi = vmlal_s16(v_hi, vget_high_s16(p), w_single);
+
+      const int16x4_t s_lo =
+          vrshrn_n_s32(v_lo, kSgrProjRestoreBits + kSgrProjPrecisionBits);
+      const int16x4_t s_hi =
+          vrshrn_n_s32(v_hi, kSgrProjRestoreBits + kSgrProjPrecisionBits);
+      StoreLo4(dst, vqmovun_s16(vcombine_s16(s_lo, s_hi)));
     }
-  }
+    src += src_stride;
+    dst += dst_stride;
+    box_filter_process_output += kIntermediateStride;
+  } while (++y < height);
+}
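
The |w_combo| path above is plain distributivity: when only one box-filter
pass produced output, the skipped pass's weight is applied to the source pixel
itself, so u * w1 + u * wN collapses to u * (w1 + wN). A scalar sketch with
illustrative names:

inline int SelfGuidedSingleScalar(int u, int p, int w_combo, int w_single,
                                  int shift) {
  // u * w1 + u * wN + p * w_single == u * (w1 + wN) + p * w_single
  const int v = u * w_combo + p * w_single;
  return (v + (1 << (shift - 1))) >> shift;  // rounded right shift, as vrshrn
}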
+
+template <int min_width>
+inline void SelfGuidedDoubleMultiplier(const uint8_t* src,
+                                       const ptrdiff_t src_stride,
+                                       uint16_t* box_filter_process_output[2],
+                                       uint8_t* dst, const ptrdiff_t dst_stride,
+                                       const int width, const int height,
+                                       const int16x4_t w0, const int w1,
+                                       const int16x4_t w2) {
+  static_assert(min_width == 4 || min_width == 8, "");
+
+  int y = 0;
+  do {
+    if (min_width == 8) {
+      int x = 0;
+      do {
+        // |wN| values are signed. |src| values can be treated as int16_t.
+        const int16x8_t u = vreinterpretq_s16_u16(
+            vshll_n_u8(vld1_u8(src + x), kSgrProjRestoreBits));
+        // |box_filter_process_output| is 14 bits, also safe to treat as
+        // int16_t.
+        const int16x8_t p0 =
+            vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output[0] + x));
+        const int16x8_t p1 =
+            vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output[1] + x));
+
+        int32x4_t v_lo = vmull_n_s16(vget_low_s16(u), w1);
+        v_lo = vmlal_s16(v_lo, vget_low_s16(p0), w0);
+        v_lo = vmlal_s16(v_lo, vget_low_s16(p1), w2);
+
+        int32x4_t v_hi = vmull_n_s16(vget_high_s16(u), w1);
+        v_hi = vmlal_s16(v_hi, vget_high_s16(p0), w0);
+        v_hi = vmlal_s16(v_hi, vget_high_s16(p1), w2);
+
+        // |s| is saturated to uint8_t.
+        const int16x4_t s_lo =
+            vrshrn_n_s32(v_lo, kSgrProjRestoreBits + kSgrProjPrecisionBits);
+        const int16x4_t s_hi =
+            vrshrn_n_s32(v_hi, kSgrProjRestoreBits + kSgrProjPrecisionBits);
+        vst1_u8(dst + x, vqmovun_s16(vcombine_s16(s_lo, s_hi)));
+        x += 8;
+      } while (x < width);
+    } else if (min_width == 4) {
+      // |wN| values are signed. |src| values can be treated as int16_t.
+      // Load 8 values but ignore 4.
+      const int16x4_t u = vget_low_s16(
+          vreinterpretq_s16_u16(vshll_n_u8(vld1_u8(src), kSgrProjRestoreBits)));
+      // |box_filter_process_output| is 14 bits, also safe to treat as
+      // int16_t.
+      const int16x4_t p0 =
+          vreinterpret_s16_u16(vld1_u16(box_filter_process_output[0]));
+      const int16x4_t p1 =
+          vreinterpret_s16_u16(vld1_u16(box_filter_process_output[1]));
+
+      int32x4_t v = vmull_n_s16(u, w1);
+      v = vmlal_s16(v, p0, w0);
+      v = vmlal_s16(v, p1, w2);
+
+      // |s| is saturated to uint8_t.
+      const int16x4_t s =
+          vrshrn_n_s32(v, kSgrProjRestoreBits + kSgrProjPrecisionBits);
+      StoreLo4(dst, vqmovun_s16(vcombine_s16(s, s)));
+    }
+    src += src_stride;
+    dst += dst_stride;
+    box_filter_process_output[0] += kIntermediateStride;
+    box_filter_process_output[1] += kIntermediateStride;
+  } while (++y < height);
 }
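
The uint16-to-int16 reinterpret casts above depend on the value ranges spelled
out in the comments: 8-bit source pixels shifted left by kSgrProjRestoreBits
and 14-bit box filter output both stay below 2^15. A compile-time sketch,
assuming kSgrProjRestoreBits is 4 (the real constant lives in the library's
headers):

#include <cstdint>

constexpr int kSgrProjRestoreBitsExample = 4;
constexpr int kMaxShiftedSource = 255 << kSgrProjRestoreBitsExample;  // 4080
constexpr int kMaxBoxFilterOutput = (1 << 14) - 1;                    // 16383
static_assert(kMaxShiftedSource <= INT16_MAX, "source fits in int16_t");
static_assert(kMaxBoxFilterOutput <= INT16_MAX, "output fits in int16_t");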
 
 // Assume box_filter_process_output[2] are allocated before calling
@@ -771,8 +983,17 @@
 void SelfGuidedFilter_NEON(const void* const source, void* const dest,
                            const RestorationUnitInfo& restoration_info,
                            ptrdiff_t source_stride, ptrdiff_t dest_stride,
-                           int width, int height,
+                           const int width, const int height,
                            RestorationBuffer* const /*buffer*/) {
+  // The output frame is broken into blocks of 64x64 (32x32 if U/V are
+  // subsampled). If either dimension is less than 32/64, the block is at the
+  // right or bottom edge of the frame. It is safe to overwrite the output
+  // as it will not be part of the visible frame. This saves us from having to
+  // handle non-multiple-of-8 widths.
+  // We could round here, but the for loop with += 8 does the same thing.
+
+  // width = (width + 7) & ~0x7;
+
   // -96 to 96 (Sgrproj_Xqd_Min/Max)
   const int8_t w0 = restoration_info.sgr_proj_info.multiplier[0];
   const int8_t w1 = restoration_info.sgr_proj_info.multiplier[1];
@@ -815,38 +1036,15 @@
   // is no vmlal_n_s16().
   const int16x4_t w0_v = vdup_n_s16(w0);
   const int16x4_t w2_v = vdup_n_s16(w2);
-  assert(width % 8 == 0);
   if (radius_pass_0 != 0 && radius_pass_1 != 0) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; x += 8) {
-        // |wN| values are signed. |src| values can be treated as int16_t.
-        const int16x8_t u_v = vreinterpretq_s16_u16(
-            vshll_n_u8(vld1_u8(src + x), kSgrProjRestoreBits));
-        // |filtered_output| is 14 bits, also safe to treat as int16_t.
-        const int16x8_t p0_v =
-            vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output[0] + x));
-        const int16x8_t p1_v =
-            vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output[1] + x));
-
-        int32x4_t v_lo = vmull_n_s16(vget_low_s16(u_v), w1);
-        v_lo = vmlal_s16(v_lo, vget_low_s16(p0_v), w0_v);
-        v_lo = vmlal_s16(v_lo, vget_low_s16(p1_v), w2_v);
-
-        int32x4_t v_hi = vmull_n_s16(vget_high_s16(u_v), w1);
-        v_hi = vmlal_s16(v_hi, vget_high_s16(p0_v), w0_v);
-        v_hi = vmlal_s16(v_hi, vget_high_s16(p1_v), w2_v);
-
-        // |s| is saturated to uint8_t.
-        const int16x4_t s_lo =
-            vrshrn_n_s32(v_lo, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-        const int16x4_t s_hi =
-            vrshrn_n_s32(v_hi, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-        vst1_u8(dst + x, vqmovun_s16(vcombine_s16(s_lo, s_hi)));
-      }
-      src += source_stride;
-      dst += dest_stride;
-      box_filter_process_output[0] += kIntermediateStride;
-      box_filter_process_output[1] += kIntermediateStride;
+    if (width > 4) {
+      SelfGuidedDoubleMultiplier<8>(src, source_stride,
+                                    box_filter_process_output, dst, dest_stride,
+                                    width, height, w0_v, w1, w2_v);
+    } else /* if (width == 4) */ {
+      SelfGuidedDoubleMultiplier<4>(src, source_stride,
+                                    box_filter_process_output, dst, dest_stride,
+                                    width, height, w0_v, w1, w2_v);
     }
   } else {
     int16_t w_combo;
@@ -862,29 +1060,14 @@
       box_filter_process_output_n = box_filter_process_output[1];
     }
 
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; x += 8) {
-        const int16x8_t u_v = vreinterpretq_s16_u16(
-            vshll_n_u8(vld1_u8(src + x), kSgrProjRestoreBits));
-        const int16x8_t pN_v =
-            vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output_n + x));
-
-        // u_v * w1 + u_v * wN == u_v * (w1 + wN)
-        int32x4_t v_lo = vmull_n_s16(vget_low_s16(u_v), w_combo);
-        v_lo = vmlal_s16(v_lo, vget_low_s16(pN_v), w_single);
-
-        int32x4_t v_hi = vmull_n_s16(vget_high_s16(u_v), w_combo);
-        v_hi = vmlal_s16(v_hi, vget_high_s16(pN_v), w_single);
-
-        const int16x4_t s_lo =
-            vrshrn_n_s32(v_lo, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-        const int16x4_t s_hi =
-            vrshrn_n_s32(v_hi, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-        vst1_u8(dst + x, vqmovun_s16(vcombine_s16(s_lo, s_hi)));
-      }
-      src += source_stride;
-      dst += dest_stride;
-      box_filter_process_output_n += kIntermediateStride;
+    if (width > 4) {
+      SelfGuidedSingleMultiplier<8>(
+          src, source_stride, box_filter_process_output_n, dst, dest_stride,
+          width, height, w_combo, w_single);
+    } else /* if (width == 4) */ {
+      SelfGuidedSingleMultiplier<4>(
+          src, source_stride, box_filter_process_output_n, dst, dest_stride,
+          width, height, w_combo, w_single);
     }
   }
 }
diff --git a/libgav1/src/dsp/arm/loop_restoration_neon.h b/libgav1/src/dsp/arm/loop_restoration_neon.h
index 85a06c8..1723c50 100644
--- a/libgav1/src/dsp/arm/loop_restoration_neon.h
+++ b/libgav1/src/dsp/arm/loop_restoration_neon.h
@@ -3,13 +3,12 @@
 
 #include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
-#include "src/dsp/loop_filter.h"
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::loop_restorations with neon implementations. This function
-// is not thread-safe.
+// Initializes Dsp::loop_restorations, see the defines below for specifics.
+// This function is not thread-safe.
 void LoopRestorationInit_NEON();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/arm/mask_blend_neon.cc b/libgav1/src/dsp/arm/mask_blend_neon.cc
new file mode 100644
index 0000000..bc69c55
--- /dev/null
+++ b/libgav1/src/dsp/arm/mask_blend_neon.cc
@@ -0,0 +1,319 @@
+#include "src/dsp/dsp.h"
+#include "src/dsp/mask_blend.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+constexpr int kBitdepth8 = 8;
+
+template <int subsampling_x, int subsampling_y>
+inline uint16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    const uint16x4_t mask_val0 = vpaddl_u8(vld1_u8(mask));
+    const uint16x4_t mask_val1 =
+        vpaddl_u8(vld1_u8(mask + (mask_stride << subsampling_y)));
+    uint16x8_t final_val;
+    if (subsampling_y == 1) {
+      const uint16x4_t next_mask_val0 = vpaddl_u8(vld1_u8(mask + mask_stride));
+      const uint16x4_t next_mask_val1 =
+          vpaddl_u8(vld1_u8(mask + mask_stride * 3));
+      final_val = vaddq_u16(vcombine_u16(mask_val0, mask_val1),
+                            vcombine_u16(next_mask_val0, next_mask_val1));
+    } else {
+      final_val = vcombine_u16(mask_val0, mask_val1);
+    }
+    return vrshrq_n_u16(final_val, subsampling_y + 1);
+  }
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  const uint8x8_t mask_val0 = LoadLo4(mask, vdup_n_u8(0));
+  const uint8x8_t mask_val =
+      LoadHi4(mask + (mask_stride << subsampling_y), mask_val0);
+  return vmovl_u8(mask_val);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline uint16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    uint16x8_t mask_val = vpaddlq_u8(vld1q_u8(mask));
+    if (subsampling_y == 1) {
+      const uint16x8_t next_mask_val = vpaddlq_u8(vld1q_u8(mask + mask_stride));
+      mask_val = vaddq_u16(mask_val, next_mask_val);
+    }
+    return vrshrq_n_u16(mask_val, 1 + subsampling_y);
+  }
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  const uint8x8_t mask_val = vld1_u8(mask);
+  return vmovl_u8(mask_val);
+}
+
+template <bool is_inter_intra>
+inline void WriteMaskBlendLine4x2(const uint16_t* const pred_0,
+                                  const ptrdiff_t pred_stride_0,
+                                  const uint16_t* const pred_1,
+                                  const ptrdiff_t pred_stride_1,
+                                  const uint16x8_t pred_mask_0,
+                                  const uint16x8_t pred_mask_1, uint8_t* dst,
+                                  const ptrdiff_t dst_stride) {
+  const uint16x4_t pred_val_0_lo = vld1_u16(pred_0);
+  const uint16x4_t pred_val_0_hi = vld1_u16(pred_0 + pred_stride_0);
+  uint16x4_t pred_val_1_lo = vld1_u16(pred_1);
+  uint16x4_t pred_val_1_hi = vld1_u16(pred_1 + pred_stride_1);
+  uint8x8_t result;
+  if (is_inter_intra) {
+    // An offset to cancel offsets used in compound predictor generation
+    // that make intermediate computations non-negative.
+    const uint16x8_t single_round_offset =
+        vdupq_n_u16((1 << kBitdepth8) + (1 << (kBitdepth8 - 1)));
+    // pred_0 and pred_1 are switched at the beginning with is_inter_intra.
+    // Clip3(prediction_0[x] - single_round_offset, 0, (1 << kBitdepth8) - 1)
+    const uint16x8_t pred_val_1 = vmovl_u8(vqmovn_u16(vqsubq_u16(
+        vcombine_u16(pred_val_1_lo, pred_val_1_hi), single_round_offset)));
+
+    const uint16x8_t pred_val_0 = vcombine_u16(pred_val_0_lo, pred_val_0_hi);
+    const uint16x8_t weighted_pred_0 = vmulq_u16(pred_val_0, pred_mask_0);
+    const uint16x8_t weighted_combo =
+        vmlaq_u16(weighted_pred_0, pred_mask_1, pred_val_1);
+    result = vrshrn_n_u16(weighted_combo, 6);
+  } else {
+    // int res = (mask_value * prediction_0[x] +
+    //      (64 - mask_value) * prediction_1[x]) >> 6;
+    const uint32x4_t weighted_pred_0_lo =
+        vmull_u16(vget_low_u16(pred_mask_0), pred_val_0_lo);
+    const uint32x4_t weighted_pred_0_hi =
+        vmull_u16(vget_high_u16(pred_mask_0), pred_val_0_hi);
+    const uint32x4_t weighted_combo_lo =
+        vmlal_u16(weighted_pred_0_lo, vget_low_u16(pred_mask_1), pred_val_1_lo);
+    const uint32x4_t weighted_combo_hi = vmlal_u16(
+        weighted_pred_0_hi, vget_high_u16(pred_mask_1), pred_val_1_hi);
+    // res -= compound_round_offset;
+    // dst[x] = static_cast<Pixel>(
+    //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+    //         (1 << kBitdepth8) - 1));
+    const int16x8_t compound_round_offset =
+        vdupq_n_s16((1 << (kBitdepth8 + 4)) + (1 << (kBitdepth8 + 3)));
+    result = vqrshrun_n_s16(vsubq_s16(vreinterpretq_s16_u16(vcombine_u16(
+                                          vshrn_n_u32(weighted_combo_lo, 6),
+                                          vshrn_n_u32(weighted_combo_hi, 6))),
+                                      compound_round_offset),
+                            4);
+  }
+  StoreLo4(dst, result);
+  StoreHi4(dst + dst_stride, result);
+}
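
WriteMaskBlendLine4x2 mirrors the scalar formula quoted in its comments. A
hedged scalar reference for the compound branch, with the final shift of 4 and
both offsets read off the vector code above; the inter-intra branch instead
subtracts a single_round_offset of (1 << 8) + (1 << 7) = 384 and uses a plain
6-bit rounded shift:

#include <algorithm>
#include <cstdint>

inline uint8_t MaskBlendScalar(uint16_t pred_0, uint16_t pred_1, int mask) {
  // Blend, remove the offset baked into the compound predictors, then round,
  // shift and clip to 8 bits.
  const int compound_round_offset = (1 << (8 + 4)) + (1 << (8 + 3));
  int res = (mask * pred_0 + (64 - mask) * pred_1) >> 6;
  res -= compound_round_offset;
  res = (res + (1 << 3)) >> 4;  // RightShiftWithRounding(res, 4)
  return static_cast<uint8_t>(std::min(255, std::max(0, res)));
}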
+
+template <bool is_inter_intra, int subsampling_x, int subsampling_y>
+inline void MaskBlending4x4_NEON(const uint16_t* pred_0,
+                                 const ptrdiff_t prediction_stride_0,
+                                 const uint16_t* pred_1,
+                                 const ptrdiff_t prediction_stride_1,
+                                 const uint8_t* mask,
+                                 const ptrdiff_t mask_stride, uint8_t* dst,
+                                 const ptrdiff_t dst_stride) {
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  uint16x8_t pred_mask_0 =
+      GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+  WriteMaskBlendLine4x2<is_inter_intra>(pred_0, prediction_stride_0, pred_1,
+                                        prediction_stride_1, pred_mask_0,
+                                        pred_mask_1, dst, dst_stride);
+  pred_0 += prediction_stride_0 << 1;
+  pred_1 += prediction_stride_1 << 1;
+  mask += mask_stride << (1 + subsampling_y);
+  dst += dst_stride << 1;
+
+  pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+  WriteMaskBlendLine4x2<is_inter_intra>(pred_0, prediction_stride_0, pred_1,
+                                        prediction_stride_1, pred_mask_0,
+                                        pred_mask_1, dst, dst_stride);
+}
+
+template <bool is_inter_intra, int subsampling_x, int subsampling_y>
+inline void MaskBlending4xH_NEON(const uint16_t* pred_0,
+                                 const ptrdiff_t pred_stride_0,
+                                 const int height, const uint16_t* pred_1,
+                                 const ptrdiff_t pred_stride_1,
+                                 const uint8_t* const mask_ptr,
+                                 const ptrdiff_t mask_stride, uint8_t* dst,
+                                 const ptrdiff_t dst_stride) {
+  const uint8_t* mask = mask_ptr;
+  if (height == 4) {
+    MaskBlending4x4_NEON<is_inter_intra, subsampling_x, subsampling_y>(
+        pred_0, pred_stride_0, pred_1, pred_stride_1, mask, mask_stride, dst,
+        dst_stride);
+    return;
+  }
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  int y = 0;
+  do {
+    uint16x8_t pred_mask_0 =
+        GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+
+    WriteMaskBlendLine4x2<is_inter_intra>(pred_0, pred_stride_0, pred_1,
+                                          pred_stride_1, pred_mask_0,
+                                          pred_mask_1, dst, dst_stride);
+    pred_0 += pred_stride_0 << 1;
+    pred_1 += pred_stride_1 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+
+    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2<is_inter_intra>(pred_0, pred_stride_0, pred_1,
+                                          pred_stride_1, pred_mask_0,
+                                          pred_mask_1, dst, dst_stride);
+    pred_0 += pred_stride_0 << 1;
+    pred_1 += pred_stride_1 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+
+    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2<is_inter_intra>(pred_0, pred_stride_0, pred_1,
+                                          pred_stride_1, pred_mask_0,
+                                          pred_mask_1, dst, dst_stride);
+    pred_0 += pred_stride_0 << 1;
+    pred_1 += pred_stride_1 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+
+    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2<is_inter_intra>(pred_0, pred_stride_0, pred_1,
+                                          pred_stride_1, pred_mask_0,
+                                          pred_mask_1, dst, dst_stride);
+    pred_0 += pred_stride_0 << 1;
+    pred_1 += pred_stride_1 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+    y += 8;
+  } while (y < height);
+}
+
+template <bool is_inter_intra, int subsampling_x, int subsampling_y>
+inline void MaskBlend_NEON(
+    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
+    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
+    const uint8_t* const mask_ptr, const ptrdiff_t mask_stride, const int width,
+    const int height, void* dest, const ptrdiff_t dst_stride) {
+  uint8_t* dst = reinterpret_cast<uint8_t*>(dest);
+  const uint16_t* pred_0 = is_inter_intra ? prediction_1 : prediction_0;
+  const uint16_t* pred_1 = is_inter_intra ? prediction_0 : prediction_1;
+  const ptrdiff_t pred_stride_0 =
+      is_inter_intra ? prediction_stride_1 : prediction_stride_0;
+  const ptrdiff_t pred_stride_1 =
+      is_inter_intra ? prediction_stride_0 : prediction_stride_1;
+  if (width == 4) {
+    MaskBlending4xH_NEON<is_inter_intra, subsampling_x, subsampling_y>(
+        pred_0, pred_stride_0, height, pred_1, pred_stride_1, mask_ptr,
+        mask_stride, dst, dst_stride);
+    return;
+  }
+  const uint8_t* mask = mask_ptr;
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      const uint16x8_t pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
+          mask + (x << subsampling_x), mask_stride);
+      // 64 - mask
+      const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+      const uint16x8_t pred_val_0 = vld1q_u16(pred_0 + x);
+      uint16x8_t pred_val_1 = vld1q_u16(pred_1 + x);
+      if (is_inter_intra) {
+        // An offset to cancel offsets used in compound predictor generation
+        // that make intermediate computations non negative.
+        const uint16x8_t single_round_offset =
+            vdupq_n_u16((1 << kBitdepth8) + (1 << (kBitdepth8 - 1)));
+        pred_val_1 =
+            vmovl_u8(vqmovn_u16(vqsubq_u16(pred_val_1, single_round_offset)));
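+        // For 8bpp this offset is (1 << 8) + (1 << 7) = 384. vqsubq_u16
+        // saturates at zero and the vqmovn_u16/vmovl_u8 pair clamps the
+        // de-offset intra prediction to [0, 255].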
+      }
+      uint8x8_t result;
+      if (is_inter_intra) {
+        const uint16x8_t weighted_pred_0 = vmulq_u16(pred_mask_0, pred_val_0);
+        // weighted_pred0 + weighted_pred1
+        const uint16x8_t weighted_combo =
+            vmlaq_u16(weighted_pred_0, pred_mask_1, pred_val_1);
+        result = vrshrn_n_u16(weighted_combo, 6);
+      } else {
+        // int res = (mask_value * prediction_0[x] +
+        //      (64 - mask_value) * prediction_1[x]) >> 6;
+        const uint32x4_t weighted_pred_0_lo =
+            vmull_u16(vget_low_u16(pred_mask_0), vget_low_u16(pred_val_0));
+        const uint32x4_t weighted_pred_0_hi =
+            vmull_u16(vget_high_u16(pred_mask_0), vget_high_u16(pred_val_0));
+        const uint32x4_t weighted_combo_lo =
+            vmlal_u16(weighted_pred_0_lo, vget_low_u16(pred_mask_1),
+                      vget_low_u16(pred_val_1));
+        const uint32x4_t weighted_combo_hi =
+            vmlal_u16(weighted_pred_0_hi, vget_high_u16(pred_mask_1),
+                      vget_high_u16(pred_val_1));
+
+        const int16x8_t compound_round_offset =
+            vdupq_n_s16((1 << (kBitdepth8 + 4)) + (1 << (kBitdepth8 + 3)));
+        // res -= compound_round_offset;
+        // dst[x] = static_cast<Pixel>(
+        //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+        //           (1 << kBitdepth8) - 1));
+        result =
+            vqrshrun_n_s16(vsubq_s16(vreinterpretq_s16_u16(vcombine_u16(
+                                         vshrn_n_u32(weighted_combo_lo, 6),
+                                         vshrn_n_u32(weighted_combo_hi, 6))),
+                                     compound_round_offset),
+                           4);
+      }
+      vst1_u8(dst + x, result);
+
+      x += 8;
+    } while (x < width);
+    dst += dst_stride;
+    pred_0 += pred_stride_0;
+    pred_1 += pred_stride_1;
+    mask += mask_stride << subsampling_y;
+  } while (++y < height);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
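+  // mask_blend[i][j]: i selects the subsampling (0: 4:4:4, 1: 4:2:2,
+  // 2: 4:2:0) and j selects the inter-intra variant.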
+  dsp->mask_blend[0][0] = MaskBlend_NEON<false, 0, 0>;
+  dsp->mask_blend[1][0] = MaskBlend_NEON<false, 1, 0>;
+  dsp->mask_blend[2][0] = MaskBlend_NEON<false, 1, 1>;
+  dsp->mask_blend[0][1] = MaskBlend_NEON<true, 0, 0>;
+  dsp->mask_blend[1][1] = MaskBlend_NEON<true, 1, 0>;
+  dsp->mask_blend[2][1] = MaskBlend_NEON<true, 1, 1>;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void MaskBlendInit_NEON() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void MaskBlendInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/mask_blend_neon.h b/libgav1/src/dsp/arm/mask_blend_neon.h
new file mode 100644
index 0000000..eed54dd
--- /dev/null
+++ b/libgav1/src/dsp/arm/mask_blend_neon.h
@@ -0,0 +1,25 @@
+#ifndef LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::mask_blend. This function is not thread-safe.
+void MaskBlendInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlend444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlend422 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlend420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlendInterIntra444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlendInterIntra422 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlendInterIntra420 LIBGAV1_DSP_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
diff --git a/libgav1/src/dsp/arm/obmc_neon.cc b/libgav1/src/dsp/arm/obmc_neon.cc
new file mode 100644
index 0000000..88c42c0
--- /dev/null
+++ b/libgav1/src/dsp/arm/obmc_neon.cc
@@ -0,0 +1,387 @@
+#include "src/dsp/dsp.h"
+#include "src/dsp/obmc.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+#include "src/dsp/obmc.inc"
+
+inline uint8x8_t Load2(const uint8_t* src) {
+  uint16_t tmp;
+  memcpy(&tmp, src, 2);
+  uint16x4_t result = vcreate_u16(tmp);
+  return vreinterpret_u8_u16(result);
+}
+
+template <int lane>
+inline void StoreLane2(uint8_t* dst, uint8x8_t src) {
+  const uint16_t out_val = vget_lane_u16(vreinterpret_u16_u8(src), lane);
+  memcpy(dst, &out_val, 2);
+}
+
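+// Blends a single 4-pixel row: per pixel this computes
+// (pred_mask * pred + obmc_pred_mask * obmc_pred + 32) >> 6, i.e. a rounding
+// 6-bit weighted average of the two predictors.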
+inline void WriteObmcLine4(uint8_t* const pred, const uint8_t* const obmc_pred,
+                           const uint8x8_t pred_mask,
+                           const uint8x8_t obmc_pred_mask) {
+  const uint8x8_t pred_val = LoadLo4(pred, vdup_n_u8(0));
+  const uint8x8_t obmc_pred_val = LoadLo4(obmc_pred, vdup_n_u8(0));
+  const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
+  const uint8x8_t result =
+      vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+  StoreLo4(pred, result);
+}
+
+template <bool from_left>
+inline void OverlapBlend2xH_NEON(uint8_t* const prediction,
+                                 const ptrdiff_t prediction_stride,
+                                 const int height,
+                                 const uint8_t* const obmc_prediction,
+                                 const ptrdiff_t obmc_prediction_stride) {
+  uint8_t* pred = prediction;
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  const uint8_t* obmc_pred = obmc_prediction;
+  uint8x8_t pred_mask;
+  uint8x8_t obmc_pred_mask;
+  int compute_height;
+  const int mask_offset = height - 2;
+  if (from_left) {
+    pred_mask = Load2(kObmcMask);
+    obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    compute_height = height;
+  } else {
+    // Weights for the last line are all 64, which is a no-op.
+    compute_height = height - 1;
+  }
+  int y = 0;
+  do {
+    if (!from_left) {
+      pred_mask = vdup_n_u8(kObmcMask[mask_offset + y]);
+      obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    }
+    const uint8x8_t pred_val = Load2(pred);
+    const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
+    const uint8x8_t obmc_pred_val = Load2(obmc_pred);
+    const uint8x8_t result =
+        vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+    StoreLane2<0>(pred, result);
+
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+  } while (++y != compute_height);
+}
+
+inline void OverlapBlendFromLeft4xH_NEON(
+    uint8_t* const prediction, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  const uint8x8_t pred_mask = LoadLo4(kObmcMask + 2, vdup_n_u8(0));
+  // 64 - mask
+  const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+  int y = 0;
+  do {
+    WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    y += 2;
+  } while (y != height);
+}
+
+inline void OverlapBlendFromLeft8xH_NEON(
+    uint8_t* const prediction, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  const uint8x8_t pred_mask = vld1_u8(kObmcMask + 6);
+  // 64 - mask
+  const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+  int y = 0;
+  do {
+    const uint8x8_t pred_val = vld1_u8(pred);
+    const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
+    const uint8x8_t obmc_pred_val = vld1_u8(obmc_pred);
+    const uint8x8_t result =
+        vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+
+    vst1_u8(pred, result);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+  } while (++y != height);
+}
+
+void OverlapBlendFromLeft_NEON(void* const prediction,
+                               const ptrdiff_t prediction_stride,
+                               const int width, const int height,
+                               const void* const obmc_prediction,
+                               const ptrdiff_t obmc_prediction_stride) {
+  auto* pred = static_cast<uint8_t*>(prediction);
+  const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+
+  if (width == 2) {
+    OverlapBlend2xH_NEON<true>(pred, prediction_stride, height, obmc_pred,
+                               obmc_prediction_stride);
+    return;
+  }
+  if (width == 4) {
+    OverlapBlendFromLeft4xH_NEON(pred, prediction_stride, height, obmc_pred,
+                                 obmc_prediction_stride);
+    return;
+  }
+  if (width == 8) {
+    OverlapBlendFromLeft8xH_NEON(pred, prediction_stride, height, obmc_pred,
+                                 obmc_prediction_stride);
+    return;
+  }
+  const uint8x16_t mask_inverter = vdupq_n_u8(64);
+  const uint8_t* mask = kObmcMask + width - 2;
+  int x = 0;
+  do {
+    pred = static_cast<uint8_t*>(prediction) + x;
+    obmc_pred = static_cast<const uint8_t*>(obmc_prediction) + x;
+    const uint8x16_t pred_mask = vld1q_u8(mask + x);
+    // 64 - mask
+    const uint8x16_t obmc_pred_mask = vsubq_u8(mask_inverter, pred_mask);
+    int y = 0;
+    do {
+      const uint8x16_t pred_val = vld1q_u8(pred);
+      const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred);
+      const uint16x8_t weighted_pred_lo =
+          vmull_u8(vget_low_u8(pred_mask), vget_low_u8(pred_val));
+      const uint8x8_t result_lo =
+          vrshrn_n_u16(vmlal_u8(weighted_pred_lo, vget_low_u8(obmc_pred_mask),
+                                vget_low_u8(obmc_pred_val)),
+                       6);
+      const uint16x8_t weighted_pred_hi =
+          vmull_u8(vget_high_u8(pred_mask), vget_high_u8(pred_val));
+      const uint8x8_t result_hi =
+          vrshrn_n_u16(vmlal_u8(weighted_pred_hi, vget_high_u8(obmc_pred_mask),
+                                vget_high_u8(obmc_pred_val)),
+                       6);
+      vst1q_u8(pred, vcombine_u8(result_lo, result_hi));
+
+      pred += prediction_stride;
+      obmc_pred += obmc_prediction_stride;
+    } while (++y < height);
+    x += 16;
+  } while (x < width);
+}
+
+inline void OverlapBlendFromTop4x4_NEON(uint8_t* const prediction,
+                                        const ptrdiff_t prediction_stride,
+                                        const uint8_t* const obmc_prediction,
+                                        const ptrdiff_t obmc_prediction_stride,
+                                        const int height) {
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+  uint8x8_t pred_mask = vdup_n_u8(kObmcMask[height - 2]);
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+  WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  if (height == 2) {
+    return;
+  }
+
+  pred_mask = vdup_n_u8(kObmcMask[3]);
+  obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+  WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  pred_mask = vdup_n_u8(kObmcMask[4]);
+  obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+  WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+}
+
+inline void OverlapBlendFromTop4xH_NEON(
+    uint8_t* const prediction, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  if (height < 8) {
+    OverlapBlendFromTop4x4_NEON(prediction, prediction_stride, obmc_prediction,
+                                obmc_prediction_stride, height);
+    return;
+  }
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+  const uint8_t* mask = kObmcMask + height - 2;
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  int y = 0;
+  // Compute 6 lines for height 8, or 12 lines for height 16. The remaining
+  // lines are unchanged as the corresponding mask value is 64.
+  do {
+    uint8x8_t pred_mask = vdup_n_u8(mask[y]);
+    uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    pred_mask = vdup_n_u8(mask[y + 1]);
+    obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    pred_mask = vdup_n_u8(mask[y + 2]);
+    obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    pred_mask = vdup_n_u8(mask[y + 3]);
+    obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    pred_mask = vdup_n_u8(mask[y + 4]);
+    obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    pred_mask = vdup_n_u8(mask[y + 5]);
+    obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    WriteObmcLine4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    // Advance to the next six mask values.
+    y += 6;
+  } while (y < height - 4);
+}
+
+inline void OverlapBlendFromTop8xH_NEON(
+    uint8_t* const prediction, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  const uint8_t* mask = kObmcMask + height - 2;
+  const int compute_height = height - (height >> 2);
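+  // For example, height 8 blends 6 rows and height 16 blends 12 rows; the
+  // remaining rows keep a mask value of 64 and are left unchanged.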
+  int y = 0;
+  do {
+    const uint8x8_t pred_mask = vdup_n_u8(mask[y]);
+    // 64 - mask
+    const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    const uint8x8_t pred_val = vld1_u8(pred);
+    const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
+    const uint8x8_t obmc_pred_val = vld1_u8(obmc_pred);
+    const uint8x8_t result =
+        vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+
+    vst1_u8(pred, result);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+  } while (++y != compute_height);
+}
+
+void OverlapBlendFromTop_NEON(void* const prediction,
+                              const ptrdiff_t prediction_stride,
+                              const int width, const int height,
+                              const void* const obmc_prediction,
+                              const ptrdiff_t obmc_prediction_stride) {
+  auto* pred = static_cast<uint8_t*>(prediction);
+  const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+
+  if (width == 2) {
+    OverlapBlend2xH_NEON<false>(pred, prediction_stride, height, obmc_pred,
+                                obmc_prediction_stride);
+    return;
+  }
+  if (width == 4) {
+    OverlapBlendFromTop4xH_NEON(pred, prediction_stride, height, obmc_pred,
+                                obmc_prediction_stride);
+    return;
+  }
+
+  if (width == 8) {
+    OverlapBlendFromTop8xH_NEON(pred, prediction_stride, height, obmc_pred,
+                                obmc_prediction_stride);
+    return;
+  }
+
+  const uint8_t* mask = kObmcMask + height - 2;
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  // Stop when the mask value becomes 64; the remaining rows are unchanged.
+  // (The 4xH path handles this implicitly.)
+  const int compute_height = height - (height >> 2);
+  int y = 0;
+  do {
+    const uint8x8_t pred_mask = vdup_n_u8(mask[y]);
+    // 64 - mask
+    const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
+    int x = 0;
+    do {
+      const uint8x16_t pred_val = vld1q_u8(pred + x);
+      const uint8x16_t obmc_pred_val = vld1q_u8(obmc_pred + x);
+      const uint16x8_t weighted_pred_lo =
+          vmull_u8(pred_mask, vget_low_u8(pred_val));
+      const uint8x8_t result_lo =
+          vrshrn_n_u16(vmlal_u8(weighted_pred_lo, obmc_pred_mask,
+                                vget_low_u8(obmc_pred_val)),
+                       6);
+      const uint16x8_t weighted_pred_hi =
+          vmull_u8(pred_mask, vget_high_u8(pred_val));
+      const uint8x8_t result_hi =
+          vrshrn_n_u16(vmlal_u8(weighted_pred_hi, obmc_pred_mask,
+                                vget_high_u8(obmc_pred_val)),
+                       6);
+      vst1q_u8(pred + x, vcombine_u8(result_lo, result_hi));
+
+      x += 16;
+    } while (x < width);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+  } while (++y < compute_height);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_NEON;
+  dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_NEON;
+}
+
+}  // namespace
+
+void ObmcInit_NEON() { Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void ObmcInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/obmc_neon.h b/libgav1/src/dsp/arm/obmc_neon.h
new file mode 100644
index 0000000..77a9fa8
--- /dev/null
+++ b/libgav1/src/dsp/arm/obmc_neon.h
@@ -0,0 +1,22 @@
+#ifndef LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::obmc_blend. This function is not thread-safe.
+void ObmcInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+// If NEON is enabled, signal the NEON implementation should be used.
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_DSP_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
diff --git a/libgav1/src/dsp/arm/warp_neon.cc b/libgav1/src/dsp/arm/warp_neon.cc
new file mode 100644
index 0000000..7ab3d27
--- /dev/null
+++ b/libgav1/src/dsp/arm/warp_neon.cc
@@ -0,0 +1,321 @@
+#include "src/dsp/dsp.h"
+#include "src/dsp/warp.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// Number of extra bits of precision in warped filtering.
+constexpr int kWarpedDiffPrecisionBits = 10;
+
+void Warp_NEON(const void* const source, const ptrdiff_t source_stride,
+               const int source_width, const int source_height,
+               const int* const warp_params, const int subsampling_x,
+               const int subsampling_y, const uint8_t inter_round_bits_vertical,
+               const int block_start_x, const int block_start_y,
+               const int block_width, const int block_height,
+               const int16_t alpha, const int16_t beta, const int16_t gamma,
+               const int16_t delta, uint16_t* dest,
+               const ptrdiff_t dest_stride) {
+  constexpr int bitdepth = 8;
+  // Intermediate_result is the output of the horizontal filtering and rounding.
+  // The range is within 13 (= bitdepth + kFilterBits + 1 -
+  // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t type
+  // so that we can multiply it by kWarpedFilters (which has signed values)
+  // using vmlal_s16().
+  int16_t intermediate_result[15][8];  // 15 rows, 8 columns.
+  const int horizontal_offset = 1 << (bitdepth + kFilterBits - 1);
+  const int vertical_offset =
+      1 << (bitdepth + 2 * kFilterBits - kInterRoundBitsHorizontal);
+  const auto* const src = static_cast<const uint8_t*>(source);
+
+  assert(block_width >= 8);
+  assert(block_height >= 8);
+
+  const uint8x16_t index_vec = vcombine_u8(vcreate_u8(0x0706050403020100),
+                                           vcreate_u8(0x0f0e0d0c0b0a0908));
+
+  // Warp process applies for each 8x8 block (or smaller).
+  int start_y = block_start_y;
+  do {
+    int start_x = block_start_x;
+    do {
+      const int src_x = (start_x + 4) << subsampling_x;
+      const int src_y = (start_y + 4) << subsampling_y;
+      const int dst_x =
+          src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0];
+      const int dst_y =
+          src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1];
+      const int x4 = dst_x >> subsampling_x;
+      const int y4 = dst_y >> subsampling_y;
+      const int ix4 = x4 >> kWarpedModelPrecisionBits;
+      const int iy4 = y4 >> kWarpedModelPrecisionBits;
+
+      // Horizontal filter.
+      int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+      for (int y = -7; y < 8; ++y) {
+        // TODO(chengchen):
+        // Because of warping, the row index can fall outside the frame
+        // boundary, so it must be clipped. However, can we remove or reduce
+        // the use of clipping? Also, special cases exist: for example, if
+        // iy4 - 7 >= source_height or iy4 + 7 < 0, no filtering is needed.
+        const int row = Clip3(iy4 + y, 0, source_height - 1);
+        const uint8_t* const src_row = src + row * source_stride;
+        // Check for two simple special cases.
+        if (ix4 - 7 >= source_width - 1) {
+          // Every sample is equal to src_row[source_width - 1]. Since the sum
+          // of the warped filter coefficients is 128 (= 2^7), the filtering is
+          // equivalent to multiplying src_row[source_width - 1] by 128.
+          const int16_t s =
+              (horizontal_offset >> kInterRoundBitsHorizontal) +
+              (src_row[source_width - 1] << (7 - kInterRoundBitsHorizontal));
+          const int16x8_t sum = vdupq_n_s16(s);
+          vst1q_s16(&intermediate_result[y + 7][0], sum);
+          sx4 += beta;
+          continue;
+        }
+        if (ix4 + 7 <= 0) {
+          // Every sample is equal to src_row[0]. Since the sum of the warped
+          // filter coefficients is 128 (= 2^7), the filtering is equivalent to
+          // multiplying src_row[0] by 128.
+          const int16_t s = (horizontal_offset >> kInterRoundBitsHorizontal) +
+                            (src_row[0] << (7 - kInterRoundBitsHorizontal));
+          const int16x8_t sum = vdupq_n_s16(s);
+          vst1q_s16(&intermediate_result[y + 7][0], sum);
+          sx4 += beta;
+          continue;
+        }
+        // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+        // read but is ignored.
+        //
+        // NOTE: This may read up to 13 bytes before src_row[0] or up to 14
+        // bytes after src_row[source_width - 1]. There must be enough padding
+        // before and after the |source| buffer.
+        static_assert(kBorderPixels >= 14, "");
+        uint8x16_t src_row_u8 = vld1q_u8(&src_row[ix4 - 7]);
+        // If clipping is needed, duplicate the border samples.
+        //
+        // Here is the correspondence between the index for the src_row
+        // buffer and the index for the src_row_u8 vector:
+        //
+        // src_row index    : (ix4 - 7) (ix4 - 6) ... (ix4 + 6) (ix4 + 7)
+        // src_row_u8 index :     0         1     ...     13        14
+        if (ix4 - 7 < 0) {
+          const int out_of_boundary_left = -(ix4 - 7);
+          const uint8x16_t cmp_vec = vdupq_n_u8(out_of_boundary_left);
+          const uint8x16_t vec_dup = vdupq_n_u8(src_row[0]);
+          const uint8x16_t mask_val = vcltq_u8(index_vec, cmp_vec);
+          src_row_u8 = vbslq_u8(mask_val, vec_dup, src_row_u8);
+        }
+        if (ix4 + 7 > source_width - 1) {
+          const int out_of_boundary_right = ix4 + 8 - source_width;
+          const uint8x16_t cmp_vec = vdupq_n_u8(14 - out_of_boundary_right);
+          const uint8x16_t vec_dup = vdupq_n_u8(src_row[source_width - 1]);
+          const uint8x16_t mask_val = vcgtq_u8(index_vec, cmp_vec);
+          src_row_u8 = vbslq_u8(mask_val, vec_dup, src_row_u8);
+        }
+        const int16x8_t src_row_low_s16 =
+            vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(src_row_u8)));
+        const int16x8_t src_row_high_s16 =
+            vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(src_row_u8)));
+        int sx = sx4 - MultiplyBy4(alpha);
+        int16x8_t filter[8];
+        for (int x = 0; x < 8; ++x) {
+          const int offset =
+              RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
+              kWarpedPixelPrecisionShifts;
+          filter[x] = vld1q_s16(kWarpedFilters[offset]);
+          sx += alpha;
+        }
+        Transpose8x8(&filter[0], &filter[1], &filter[2], &filter[3], &filter[4],
+                     &filter[5], &filter[6], &filter[7]);
+        // For 8 bit, the range of sum is within uint16_t, if we add a
+        // horizontal offset. horizontal_offset guarantees sum is nonnegative.
+        //
+        // Proof:
+        // Given that the minimum (most negative) sum of the negative filter
+        // coefficients is -47 and the maximum sum of the positive filter
+        // coefficients is 175, the range of the horizontal filter output is
+        //   -47 * 255 <= output <= 175 * 255
+        // Since -2^14 < -47 * 255, adding 2^14 (= horizontal_offset) to the
+        // horizontal filter output produces a positive value:
+        //   0 < output + 2^14 <= 175 * 255 + 2^14
+        // The final rounding right shift by 3 (= kInterRoundBitsHorizontal)
+        // bits adds 2^2 to the sum:
+        //   0 < output + 2^14 + 2^2 <= 175 * 255 + 2^14 + 2^2 = 61013
+        // Since 61013 < 2^16, the final sum (right before the right shift by 3
+        // bits) will not overflow uint16_t. In addition, the value after the
+        // right shift by 3 bits is in the following range:
+        //   0 <= intermediate_result[y][x] < 2^13
+        // This property is used in determining the range of the vertical
+        // filtering output. [End of proof.]
+        //
+        // We can do signed int16_t arithmetic and just treat the final result
+        // as uint16_t when we shift it right.
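+        // The unrolled loop below computes, for each column x,
+        //   sum(x) = horizontal_offset + sum over k of
+        //            (k-th warped tap for x) * src_row[ix4 - 7 + x + k]
+        // followed by the rounding right shift by kInterRoundBitsHorizontal.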
+        int16x8_t sum = vdupq_n_s16(horizontal_offset);
+        // Unrolled k = 0..7 loop. We need to manually unroll the loop because
+        // the third argument (an index value) to vextq_s16() must be a
+        // constant (immediate).
+        // k = 0.
+        int16x8_t src_row_v_s16 = src_row_low_s16;
+        sum = vmlaq_s16(sum, filter[0], src_row_v_s16);
+        // k = 1.
+        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 1);
+        sum = vmlaq_s16(sum, filter[1], src_row_v_s16);
+        // k = 2.
+        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 2);
+        sum = vmlaq_s16(sum, filter[2], src_row_v_s16);
+        // k = 3.
+        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 3);
+        sum = vmlaq_s16(sum, filter[3], src_row_v_s16);
+        // k = 4.
+        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 4);
+        sum = vmlaq_s16(sum, filter[4], src_row_v_s16);
+        // k = 5.
+        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 5);
+        sum = vmlaq_s16(sum, filter[5], src_row_v_s16);
+        // k = 6.
+        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 6);
+        sum = vmlaq_s16(sum, filter[6], src_row_v_s16);
+        // k = 7.
+        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 7);
+        sum = vmlaq_s16(sum, filter[7], src_row_v_s16);
+        // End of unrolled k = 0..7 loop.
+        // Treat sum as unsigned for the right shift.
+        sum = vreinterpretq_s16_u16(vrshrq_n_u16(vreinterpretq_u16_s16(sum),
+                                                 kInterRoundBitsHorizontal));
+        vst1q_s16(&intermediate_result[y + 7][0], sum);
+        sx4 += beta;
+      }
+
+      // Vertical filter.
+      uint16_t* dst_row = dest + start_x - block_start_x;
+      int sy4 =
+          (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+      for (int y = -4; y < 4; ++y) {
+        int sy = sy4 - MultiplyBy4(gamma);
+        int16x8_t filter[8];
+        for (int x = 0; x < 8; ++x) {
+          const int offset =
+              RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+              kWarpedPixelPrecisionShifts;
+          filter[x] = vld1q_s16(kWarpedFilters[offset]);
+          sy += gamma;
+        }
+        Transpose8x8(&filter[0], &filter[1], &filter[2], &filter[3], &filter[4],
+                     &filter[5], &filter[6], &filter[7]);
+        // Similar to horizontal_offset, vertical_offset guarantees sum before
+        // shifting is nonnegative.
+        //
+        // Proof:
+        // The range of an entry in intermediate_result is
+        //   0 <= intermediate_result[y][x] < 2^13
+        // The range of the vertical filter output is
+        //   -47 * 2^13 < output < 175 * 2^13
+        // Since -2^19 < -47 * 2^13, adding 2^19 (= vertical_offset) to the
+        // vertical filter output produces a positive value:
+        //   0 < output + 2^19 < 175 * 2^13 + 2^19
+        // The final rounding right shift by either 7 or 11 bits adds at most
+        // 2^10 to the sum:
+        //   0 < output + 2^19 + rounding < 175 * 2^13 + 2^19 + 2^10 = 1958912
+        // Since 1958912 = 0x1DE400 < 2^22, shifting it right by 7 or 11 bits
+        // brings the value under 2^15, which fits in uint16_t.
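+        // The loop below accumulates, for each column x,
+        //   sum(x) = vertical_offset + sum over k of
+        //            (k-th warped tap for x) * intermediate_result[y + 4 + k][x]
+        // in 32 bits before the rounding right shift below.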
+        int32x4_t sum_low = vdupq_n_s32(vertical_offset);
+        int32x4_t sum_high = sum_low;
+        for (int k = 0; k < 8; ++k) {
+          const int16x8_t intermediate =
+              vld1q_s16(intermediate_result[y + 4 + k]);
+          sum_low = vmlal_s16(sum_low, vget_low_s16(filter[k]),
+                              vget_low_s16(intermediate));
+          sum_high = vmlal_s16(sum_high, vget_high_s16(filter[k]),
+                               vget_high_s16(intermediate));
+        }
+        assert(inter_round_bits_vertical == 7 ||
+               inter_round_bits_vertical == 11);
+        // Since inter_round_bits_vertical can be 7 or 11, and all the narrowing
+        // shift intrinsics require the shift argument to be a constant
+        // (immediate), we have two options:
+        // 1. Call a non-narrowing shift, followed by a narrowing extract.
+        // 2. Call a narrowing shift (with a constant shift of 7 or 11) in an
+        //    if-else statement.
+#if defined(__aarch64__)
+        // This version is slightly faster for arm64 (1106 ms vs 1112 ms).
+        // This version is slower for 32-bit arm (1235 ms vs 1149 ms).
+        const int32x4_t shift = vdupq_n_s32(-inter_round_bits_vertical);
+        const uint32x4_t sum_low_shifted =
+            vrshlq_u32(vreinterpretq_u32_s32(sum_low), shift);
+        const uint32x4_t sum_high_shifted =
+            vrshlq_u32(vreinterpretq_u32_s32(sum_high), shift);
+        const uint16x4_t sum_low_16 = vmovn_u32(sum_low_shifted);
+        const uint16x4_t sum_high_16 = vmovn_u32(sum_high_shifted);
+#else   // !defined(__aarch64__)
+        // This version is faster for 32-bit arm.
+        // This version is slightly slower for arm64.
+        uint16x4_t sum_low_16;
+        uint16x4_t sum_high_16;
+        if (inter_round_bits_vertical == 7) {
+          sum_low_16 = vrshrn_n_u32(vreinterpretq_u32_s32(sum_low), 7);
+          sum_high_16 = vrshrn_n_u32(vreinterpretq_u32_s32(sum_high), 7);
+        } else {
+          sum_low_16 = vrshrn_n_u32(vreinterpretq_u32_s32(sum_low), 11);
+          sum_high_16 = vrshrn_n_u32(vreinterpretq_u32_s32(sum_high), 11);
+        }
+#endif  // defined(__aarch64__)
+        // vst1q_u16 can also be used:
+        //   vst1q_u16(dst_row, vcombine_u16(sum_low_16, sum_high_16));
+        // But it is slightly slower for arm64 (the same speed for 32-bit arm).
+        //
+        // vst1_u16_x2 could be used, but it is also slightly slower for arm64
+        // and causes a bus error for 32-bit arm. Also, it is not supported by
+        // gcc 7.2.0.
+        vst1_u16(dst_row, sum_low_16);
+        vst1_u16(dst_row + 4, sum_high_16);
+        dst_row += dest_stride;
+        sy4 += delta;
+      }
+      start_x += 8;
+    } while (start_x < block_start_x + block_width);
+    dest += 8 * dest_stride;
+    start_y += 8;
+  } while (start_y < block_start_y + block_height);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  dsp->warp = Warp_NEON;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void WarpInit_NEON() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+#else   // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void WarpInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/warp_neon.h b/libgav1/src/dsp/arm/warp_neon.h
new file mode 100644
index 0000000..c5e1bc0
--- /dev/null
+++ b/libgav1/src/dsp/arm/warp_neon.h
@@ -0,0 +1,20 @@
+#ifndef LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::warp. This function is not thread-safe.
+void WarpInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_Warp LIBGAV1_DSP_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
diff --git a/libgav1/src/dsp/average_blend.cc b/libgav1/src/dsp/average_blend.cc
index 4b2b7d7..a5de6d3 100644
--- a/libgav1/src/dsp/average_blend.cc
+++ b/libgav1/src/dsp/average_blend.cc
@@ -12,26 +12,29 @@
 namespace {
 
 template <int bitdepth, typename Pixel>
-void AverageBlending_C(const uint16_t* prediction_0,
-                       const ptrdiff_t prediction_stride_0,
-                       const uint16_t* prediction_1,
-                       const ptrdiff_t prediction_stride_1,
-                       const int inter_post_round_bit, const int width,
-                       const int height, void* const dest,
-                       const ptrdiff_t dest_stride) {
+void AverageBlend_C(const uint16_t* prediction_0,
+                    const ptrdiff_t prediction_stride_0,
+                    const uint16_t* prediction_1,
+                    const ptrdiff_t prediction_stride_1, const int width,
+                    const int height, void* const dest,
+                    const ptrdiff_t dest_stride) {
   // An offset to cancel offsets used in compound predictor generation that
   // make intermediate computations non negative.
-  const int compound_round_offset =
-      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
+  constexpr int compound_round_offset =
+      (2 << (bitdepth + 4)) + (2 << (bitdepth + 3));
+  // 7.11.3.2 Rounding variables derivation process
+  //   2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7))
+  constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
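+  // For example, at 8bpp the two predictions together carry an offset of
+  // (2 << 12) + (2 << 11) = 12288, and their sum is rounded down by
+  // inter_post_round_bits + 1 = 5 bits (the extra bit averages the two
+  // predictions).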
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
 
   for (int y = 0; y < height; ++y) {
     for (int x = 0; x < width; ++x) {
-      int res = (prediction_0[x] + prediction_1[x]) >> 1;
+      // prediction range: 8bpp: [0, 15471] 10bpp: [0, 61983] 12bpp: [0, 62007].
+      int res = prediction_0[x] + prediction_1[x];
       res -= compound_round_offset;
       dst[x] = static_cast<Pixel>(
-          Clip3(RightShiftWithRounding(res, inter_post_round_bit), 0,
+          Clip3(RightShiftWithRounding(res, inter_post_round_bits + 1), 0,
                 (1 << bitdepth) - 1));
     }
     dst += dst_stride;
@@ -43,14 +46,30 @@
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
-  dsp->average_blend = AverageBlending_C<8, uint8_t>;
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  dsp->average_blend = AverageBlend_C<8, uint8_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp8bpp_AverageBlend
+  dsp->average_blend = AverageBlend_C<8, uint8_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 
 #if LIBGAV1_MAX_BITDEPTH >= 10
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
   assert(dsp != nullptr);
-  dsp->average_blend = AverageBlending_C<10, uint16_t>;
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+#ifndef LIBGAV1_Dsp10bpp_AverageBlend
+  dsp->average_blend = AverageBlend_C<10, uint16_t>;
+#endif
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp10bpp_AverageBlend
+  dsp->average_blend = AverageBlend_C<10, uint16_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 #endif
 
diff --git a/libgav1/src/dsp/average_blend.h b/libgav1/src/dsp/average_blend.h
index e126d39..f91a102 100644
--- a/libgav1/src/dsp/average_blend.h
+++ b/libgav1/src/dsp/average_blend.h
@@ -1,11 +1,28 @@
 #ifndef LIBGAV1_SRC_DSP_AVERAGE_BLEND_H_
 #define LIBGAV1_SRC_DSP_AVERAGE_BLEND_H_
 
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/average_blend_neon.h"
+
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/average_blend_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
+
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::average_blend.
-// This function is not thread-safe.
+// Initializes Dsp::average_blend. This function is not thread-safe.
 void AverageBlendInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/cdef.cc b/libgav1/src/dsp/cdef.cc
index e3095d5..e412342 100644
--- a/libgav1/src/dsp/cdef.cc
+++ b/libgav1/src/dsp/cdef.cc
@@ -13,6 +13,8 @@
 namespace dsp {
 namespace {
 
+constexpr uint16_t kCdefLargeValue = 30000;
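+// kCdefLargeValue is expected in the padded border of the 16-bit source so
+// that out-of-frame taps can be skipped with a single comparison below.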
+
 constexpr int16_t kDivisionTable[] = {0,   840, 420, 280, 210,
                                       168, 140, 120, 105};
 
@@ -32,14 +34,6 @@
          Clip3(threshold - (std::abs(diff) >> damping), 0, std::abs(diff));
 }
 
-// 5.11.52.
-bool InsideFrame(int x, int y, int subsampling_x, int subsampling_y,
-                 int rows4x4, int columns4x4) {
-  const int row = DivideBy4(LeftShift(y, subsampling_y));
-  const int column = DivideBy4(LeftShift(x, subsampling_x));
-  return row >= 0 && row < rows4x4 && column >= 0 && column < columns4x4;
-}
-
 int32_t Square(int32_t x) { return x * x; }
 
 template <int bitdepth, typename Pixel>
@@ -102,53 +96,41 @@
   *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10;
 }
 
+// Filters the source block. It does not check whether a candidate pixel is
+// inside the frame; instead it requires the source to be padded with a
+// constant large value (kCdefLargeValue) at the boundary, and the input must
+// be uint16_t.
 template <int bitdepth, typename Pixel>
-void CdefFiltering_C(const void* const source, const ptrdiff_t source_stride,
-                     const int rows4x4, const int columns4x4, const int curr_x,
-                     const int curr_y, const int subsampling_x,
-                     const int subsampling_y, const int primary_strength,
-                     const int secondary_strength, const int damping,
-                     const int direction, void* const dest,
-                     const ptrdiff_t dest_stride) {
+void CdefFilter_C(const void* const source, const ptrdiff_t source_stride,
+                  const int rows4x4, const int columns4x4, const int curr_x,
+                  const int curr_y, const int subsampling_x,
+                  const int subsampling_y, const int primary_strength,
+                  const int secondary_strength, const int damping,
+                  const int direction, void* const dest,
+                  const ptrdiff_t dest_stride) {
   const int coeff_shift = bitdepth - 8;
   const int plane_width = MultiplyBy4(columns4x4) >> subsampling_x;
   const int plane_height = MultiplyBy4(rows4x4) >> subsampling_y;
   const int block_width = std::min(8 >> subsampling_x, plane_width - curr_x);
   const int block_height = std::min(8 >> subsampling_y, plane_height - curr_y);
-  const auto* src = static_cast<const Pixel*>(source);
-  const ptrdiff_t src_stride = source_stride / sizeof(Pixel);
+  const auto* src = static_cast<const uint16_t*>(source);
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
   for (int y = 0; y < block_height; ++y) {
     for (int x = 0; x < block_width; ++x) {
       int16_t sum = 0;
-      const Pixel pixel_value = src[x];
-      Pixel max_value = pixel_value;
-      Pixel min_value = pixel_value;
+      const uint16_t pixel_value = src[x];
+      uint16_t max_value = pixel_value;
+      uint16_t min_value = pixel_value;
       for (int k = 0; k < 2; ++k) {
         const int signs[] = {-1, 1};
         for (const int& sign : signs) {
           int dy = sign * kCdefDirections[direction][k][0];
           int dx = sign * kCdefDirections[direction][k][1];
-          int y0 = curr_y + y + dy;
-          int x0 = curr_x + x + dx;
-          // TODO(chengchen): Optimize cdef data fetching.
-          // Cdef needs to get pixel values from 3x3 neighborhood.
-          // It could happen that the target position is out of the frame.
-          // When it's out of frame, that pixel should not be taken into
-          // calculation.
-          // In libaom's implementation, borders are padded around the whole
-          // frame such that out of frame access gets a large value. The
-          // large value is defined as 30000. This implementation has a problem
-          // because 8-bit input can't represent 30000. It has to allocate a
-          // 16-bit frame buffer to set large values for the borders.
-          // In this implementation, we detect whether it's out of frame,
-          // which is not friendly for SIMD implementation.
-          // We can avoid the extra frame buffer by allocating a 16-bit block
-          // buffer, like the implementation of loop restoration.
-          if (InsideFrame(x0, y0, subsampling_x, subsampling_y, rows4x4,
-                          columns4x4)) {
-            const Pixel value = src[dy * src_stride + dx + x];
+          uint16_t value = src[dy * source_stride + dx + x];
+          // Note: the summation can ignore the condition check in SIMD
+          // implementation, because Constrain() will return 0 when
+          // value == kCdefLargeValue.
+          if (value != kCdefLargeValue) {
             sum += Constrain(value - pixel_value, primary_strength, damping) *
                    kPrimaryTaps[(primary_strength >> coeff_shift) & 1][k];
             max_value = std::max(value, max_value);
@@ -158,11 +140,10 @@
           for (const int& offset : offsets) {
             dy = sign * kCdefDirections[(direction + offset) & 7][k][0];
             dx = sign * kCdefDirections[(direction + offset) & 7][k][1];
-            y0 = curr_y + y + dy;
-            x0 = curr_x + x + dx;
-            if (InsideFrame(x0, y0, subsampling_x, subsampling_y, rows4x4,
-                            columns4x4)) {
-              const Pixel value = src[dy * src_stride + dx + x];
+            value = src[dy * source_stride + dx + x];
+            // Note: the summation can ignore the condition check in SIMD
+            // implementation.
+            if (value != kCdefLargeValue) {
               sum +=
                   Constrain(value - pixel_value, secondary_strength, damping) *
                   kSecondaryTaps[(primary_strength >> coeff_shift) & 1][k];
@@ -173,10 +154,10 @@
         }
       }
 
-      dst[x] = Clip3(pixel_value + ((8 + sum - (sum < 0)) >> 4), min_value,
-                     max_value);
+      dst[x] = static_cast<Pixel>(Clip3(
+          pixel_value + ((8 + sum - (sum < 0)) >> 4), min_value, max_value));
     }
-    src += src_stride;
+    src += source_stride;
     dst += dst_stride;
   }
 }
@@ -186,11 +167,14 @@
   assert(dsp != nullptr);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->cdef_direction = CdefDirection_C<8, uint8_t>;
-  dsp->cdef_filter = CdefFiltering_C<8, uint8_t>;
+  dsp->cdef_filter = CdefFilter_C<8, uint8_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp8bpp_CdefDirection
   dsp->cdef_direction = CdefDirection_C<8, uint8_t>;
-  dsp->cdef_filter = CdefFiltering_C<8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_CdefFilter
+  dsp->cdef_filter = CdefFilter_C<8, uint8_t>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -201,11 +185,14 @@
   assert(dsp != nullptr);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->cdef_direction = CdefDirection_C<10, uint16_t>;
-  dsp->cdef_filter = CdefFiltering_C<10, uint16_t>;
+  dsp->cdef_filter = CdefFilter_C<10, uint16_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp10bpp_CdefDirection
   dsp->cdef_direction = CdefDirection_C<10, uint16_t>;
-  dsp->cdef_filter = CdefFiltering_C<10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_CdefFilter
+  dsp->cdef_filter = CdefFilter_C<10, uint16_t>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
diff --git a/libgav1/src/dsp/cdef.h b/libgav1/src/dsp/cdef.h
index 5da4aba..3c1ab80 100644
--- a/libgav1/src/dsp/cdef.h
+++ b/libgav1/src/dsp/cdef.h
@@ -4,7 +4,8 @@
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::cdef_direction. This function is not thread-safe.
+// Initializes Dsp::cdef_direction and Dsp::cdef_filter. This function is not
+// thread-safe.
 void CdefInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/common.h b/libgav1/src/dsp/common.h
index e9842a2..ddd67e4 100644
--- a/libgav1/src/dsp/common.h
+++ b/libgav1/src/dsp/common.h
@@ -6,6 +6,7 @@
 
 #include "src/dsp/constants.h"
 #include "src/utils/constants.h"
+#include "src/utils/memory.h"
 
 namespace libgav1 {
 
@@ -24,10 +25,10 @@
   static const int kVertical = 0;
   static const int kHorizontal = 1;
 
-  alignas(16) int16_t filter[2][kSubPixelTaps];
+  alignas(kMaxAlignment) int16_t filter[2][kSubPixelTaps];
 };
 
-struct RestorationUnitInfo {
+struct RestorationUnitInfo : public Allocable {
   LoopRestorationType type;
   SgrProjInfo sgr_proj_info;
   WienerInfo wiener_info;
@@ -42,7 +43,6 @@
   // For wiener filter.
   uint16_t* wiener_buffer;
   ptrdiff_t wiener_buffer_stride;
-  int inter_round_bits[2];
 };
 
 // Section 6.8.20.
diff --git a/libgav1/src/dsp/constants.h b/libgav1/src/dsp/constants.h
index f042150..0889f8f 100644
--- a/libgav1/src/dsp/constants.h
+++ b/libgav1/src/dsp/constants.h
@@ -15,6 +15,11 @@
   // 2^kSmoothWeightScale.
   kSmoothWeightScale = 8,
   kCflLumaBufferStride = 32,
+  // InterRound0, Section 7.11.3.2.
+  kInterRoundBitsHorizontal = 3,  // 8 & 10-bit.
+  kInterRoundBitsHorizontal12bpp = 5,
+  kInterRoundBitsVertical = 11,  // 8 & 10-bit, single prediction.
+  kInterRoundBitsVertical12bpp = 9,
 };  // anonymous enum
 
 extern const int8_t kFilterIntraTaps[kNumFilterIntraPredictors][8][8];
diff --git a/libgav1/src/dsp/convolve.cc b/libgav1/src/dsp/convolve.cc
index 0fdebb7..8c409c8 100644
--- a/libgav1/src/dsp/convolve.cc
+++ b/libgav1/src/dsp/convolve.cc
@@ -33,12 +33,15 @@
 }
 
 template <int bitdepth, typename Pixel>
-void Convolve2DScaleSingle_C(
+void ConvolveScale2D_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int horizontal_filter_index, const int vertical_filter_index,
-    const uint8_t inter_round_bits[2], const int subpixel_x,
+    const uint8_t inter_round_bits_vertical, const int subpixel_x,
     const int subpixel_y, const int step_x, const int step_y, const int width,
     const int height, void* prediction, const ptrdiff_t pred_stride) {
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
   const int intermediate_height =
       (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
        kScaleSubPixelBits) +
@@ -76,7 +79,7 @@
       }
       assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
       intermediate[x] = static_cast<int16_t>(
-          RightShiftWithRounding(sum, inter_round_bits[0]));
+          RightShiftWithRounding(sum, kRoundBitsHorizontal));
     }
     src += src_stride;
     intermediate += intermediate_stride;
@@ -84,7 +87,7 @@
   // Vertical filter.
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
-  const int offset_bits = bitdepth + 2 * kFilterBits - inter_round_bits[0];
+  const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
   for (int y = 0, p = subpixel_y & 1023; y < height; ++y, p += step_y) {
     const int filter_id = (p >> 6) & kSubPixelMask;
     for (int x = 0; x < width; ++x) {
@@ -98,7 +101,7 @@
       }
       assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
       dest[x] = static_cast<Pixel>(
-          Clip3(RightShiftWithRounding(sum, inter_round_bits[1]) -
+          Clip3(RightShiftWithRounding(sum, inter_round_bits_vertical) -
                     single_round_offset,
                 0, max_pixel_value));
     }
@@ -107,14 +110,15 @@
 }
 
 template <int bitdepth, typename Pixel>
-void Convolve2DScale_C(const void* const reference,
-                       const ptrdiff_t reference_stride,
-                       const int horizontal_filter_index,
-                       const int vertical_filter_index,
-                       const uint8_t inter_round_bits[2], const int subpixel_x,
-                       const int subpixel_y, const int step_x, const int step_y,
-                       const int width, const int height,
-                       void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveCompoundScale2D_C(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int horizontal_filter_index, const int vertical_filter_index,
+    const uint8_t inter_round_bits_vertical, const int subpixel_x,
+    const int subpixel_y, const int step_x, const int step_y, const int width,
+    const int height, void* prediction, const ptrdiff_t pred_stride) {
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
   const int intermediate_height =
       (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
        kScaleSubPixelBits) +
@@ -149,7 +153,7 @@
       }
       assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
       intermediate[x] = static_cast<int16_t>(
-          RightShiftWithRounding(sum, inter_round_bits[0]));
+          RightShiftWithRounding(sum, kRoundBitsHorizontal));
     }
     src += src_stride;
     intermediate += intermediate_stride;
@@ -157,7 +161,7 @@
   // Vertical filter.
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
-  const int offset_bits = bitdepth + 2 * kFilterBits - inter_round_bits[0];
+  const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
   for (int y = 0, p = subpixel_y & 1023; y < height; ++y, p += step_y) {
     const int filter_id = (p >> 6) & kSubPixelMask;
     for (int x = 0; x < width; ++x) {
@@ -171,25 +175,30 @@
       }
       assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
       dest[x] = static_cast<uint16_t>(
-          RightShiftWithRounding(sum, inter_round_bits[1]));
+          RightShiftWithRounding(sum, inter_round_bits_vertical));
     }
     dest += pred_stride;
   }
 }
 
 template <int bitdepth, typename Pixel>
-void Convolve2D_C(const void* const reference, const ptrdiff_t reference_stride,
-                  const int horizontal_filter_index,
-                  const int vertical_filter_index,
-                  const uint8_t inter_round_bits[2], const int subpixel_x,
-                  const int subpixel_y, const int /*step_x*/,
-                  const int /*step_y*/, const int width, const int height,
-                  void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveCompound2D_C(const void* const reference,
+                          const ptrdiff_t reference_stride,
+                          const int horizontal_filter_index,
+                          const int vertical_filter_index,
+                          const uint8_t inter_round_bits_vertical,
+                          const int subpixel_x, const int subpixel_y,
+                          const int /*step_x*/, const int /*step_y*/,
+                          const int width, const int height, void* prediction,
+                          const ptrdiff_t pred_stride) {
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
   const int intermediate_height = height + kSubPixelTaps - 1;
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
   int16_t intermediate_result[kMaxSuperBlockSizeInPixels *
-                              (2 * kMaxSuperBlockSizeInPixels + 8)];
+                              (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
   const int intermediate_stride = kMaxSuperBlockSizeInPixels;
 
   // Horizontal filter.
@@ -213,7 +222,7 @@
       }
       assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
       intermediate[x] = static_cast<int16_t>(
-          RightShiftWithRounding(sum, inter_round_bits[0]));
+          RightShiftWithRounding(sum, kRoundBitsHorizontal));
     }
     src += src_stride;
     intermediate += intermediate_stride;
@@ -222,7 +231,7 @@
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
   filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
-  const int offset_bits = bitdepth + 2 * kFilterBits - inter_round_bits[0];
+  const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
   for (int y = 0; y < height; ++y) {
     for (int x = 0; x < width; ++x) {
       // An offset to guarantee the sum is non negative.
@@ -233,32 +242,34 @@
       }
       assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
       dest[x] = static_cast<uint16_t>(
-          RightShiftWithRounding(sum, inter_round_bits[1]));
+          RightShiftWithRounding(sum, inter_round_bits_vertical));
     }
     dest += pred_stride;
     intermediate += intermediate_stride;
   }
 }
 
-// This function is a simplified version of Convolve2D_C.
-// It is called when it is single prediction mode, where only horizontal
-// filtering is required.
+// This function is a simplified version of ConvolveCompound2D_C.
+// It is called in single prediction mode, where both horizontal and
+// vertical filtering are required.
 // The output is the single prediction of the block, clipped to valid pixel
 // range.
 template <int bitdepth, typename Pixel>
-void Convolve2DSingle_C(const void* const reference,
-                        const ptrdiff_t reference_stride,
-                        const int horizontal_filter_index,
-                        const int vertical_filter_index,
-                        const uint8_t inter_round_bits[2], const int subpixel_x,
-                        const int subpixel_y, const int /*step_x*/,
-                        const int /*step_y*/, const int width, const int height,
-                        void* prediction, const ptrdiff_t pred_stride) {
+void Convolve2D_C(const void* const reference, const ptrdiff_t reference_stride,
+                  const int horizontal_filter_index,
+                  const int vertical_filter_index,
+                  const uint8_t inter_round_bits_vertical, const int subpixel_x,
+                  const int subpixel_y, const int /*step_x*/,
+                  const int /*step_y*/, const int width, const int height,
+                  void* prediction, const ptrdiff_t pred_stride) {
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
   const int intermediate_height = height + kSubPixelTaps - 1;
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
   int16_t intermediate_result[kMaxSuperBlockSizeInPixels *
-                              (2 * kMaxSuperBlockSizeInPixels + 8)];
+                              (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
   const int intermediate_stride = kMaxSuperBlockSizeInPixels;
   const int single_round_offset = (1 << bitdepth) + (1 << (bitdepth - 1));
   const int max_pixel_value = (1 << bitdepth) - 1;
@@ -285,7 +296,7 @@
       }
       assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
       intermediate[x] = static_cast<int16_t>(
-          RightShiftWithRounding(sum, inter_round_bits[0]));
+          RightShiftWithRounding(sum, kRoundBitsHorizontal));
     }
     src += src_stride;
     intermediate += intermediate_stride;
@@ -294,7 +305,7 @@
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
   filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
-  const int offset_bits = bitdepth + 2 * kFilterBits - inter_round_bits[0];
+  const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
   for (int y = 0; y < height; ++y) {
     for (int x = 0; x < width; ++x) {
       // An offset to guarantee the sum is non negative.
@@ -305,7 +316,7 @@
       }
       assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
       dest[x] = static_cast<Pixel>(
-          Clip3(RightShiftWithRounding(sum, inter_round_bits[1]) -
+          Clip3(RightShiftWithRounding(sum, inter_round_bits_vertical) -
                     single_round_offset,
                 0, max_pixel_value));
     }
@@ -324,13 +335,16 @@
                           const ptrdiff_t reference_stride,
                           const int horizontal_filter_index,
                           const int /*vertical_filter_index*/,
-                          const uint8_t inter_round_bits[2],
+                          const uint8_t /*inter_round_bits_vertical*/,
                           const int subpixel_x, const int /*subpixel_y*/,
                           const int /*step_x*/, const int /*step_y*/,
                           const int width, const int height, void* prediction,
                           const ptrdiff_t pred_stride) {
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
-  const int bits = kFilterBits - inter_round_bits[0];
+  const int bits = kFilterBits - kRoundBitsHorizontal;
   const auto* src = static_cast<const Pixel*>(reference) - kHorizontalOffset;
   const ptrdiff_t src_stride = reference_stride / sizeof(Pixel);
   auto* dest = static_cast<Pixel*>(prediction);
@@ -343,7 +357,7 @@
       for (int k = 0; k < kSubPixelTaps; ++k) {
         sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
       }
-      sum = RightShiftWithRounding(sum, inter_round_bits[0]);
+      sum = RightShiftWithRounding(sum, kRoundBitsHorizontal);
       dest[x] = static_cast<Pixel>(
           Clip3(RightShiftWithRounding(sum, bits), 0, max_pixel_value));
     }
@@ -362,7 +376,7 @@
                         const ptrdiff_t reference_stride,
                         const int /*horizontal_filter_index*/,
                         const int vertical_filter_index,
-                        const uint8_t /*inter_round_bits*/[2],
+                        const uint8_t /*inter_round_bits_vertical*/,
                         const int /*subpixel_x*/, const int subpixel_y,
                         const int /*step_x*/, const int /*step_y*/,
                         const int width, const int height, void* prediction,
@@ -374,6 +388,18 @@
   auto* dest = static_cast<Pixel*>(prediction);
   const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel);
   const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
+  // First filter is always a copy.
+  if (filter_id == 0) {
+    // Move |src| down to the actual values, not to the start of the context.
+    src = static_cast<const Pixel*>(reference);
+    int y = 0;
+    do {
+      memcpy(dest, src, width * sizeof(src[0]));
+      src += src_stride;
+      dest += dest_stride;
+    } while (++y < height);
+    return;
+  }
   const int max_pixel_value = (1 << bitdepth) - 1;
   for (int y = 0; y < height; ++y) {
     for (int x = 0; x < width; ++x) {
@@ -395,7 +421,7 @@
                     const ptrdiff_t reference_stride,
                     const int /*horizontal_filter_index*/,
                     const int /*vertical_filter_index*/,
-                    const uint8_t /*inter_round_bits*/[2],
+                    const uint8_t /*inter_round_bits_vertical*/,
                     const int /*subpixel_x*/, const int /*subpixel_y*/,
                     const int /*step_x*/, const int /*step_y*/, const int width,
                     const int height, void* prediction,
@@ -414,7 +440,7 @@
                             const ptrdiff_t reference_stride,
                             const int /*horizontal_filter_index*/,
                             const int /*vertical_filter_index*/,
-                            const uint8_t /*inter_round_bits*/[2],
+                            const uint8_t /*inter_round_bits_vertical*/,
                             const int /*subpixel_x*/, const int /*subpixel_y*/,
                             const int /*step_x*/, const int /*step_y*/,
                             const int width, const int height, void* prediction,
@@ -433,7 +459,7 @@
   }
 }
 
-// This function is a simplified version of Convolve2D_C.
+// This function is a simplified version of ConvolveCompound2D_C.
 // It is called in compound prediction mode, where only horizontal
 // filtering is required.
 // The output is not clipped to valid pixel range. Its output will be
@@ -442,16 +468,19 @@
 void ConvolveCompoundHorizontal_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int horizontal_filter_index, const int /*vertical_filter_index*/,
-    const uint8_t inter_round_bits[2], const int subpixel_x,
+    const uint8_t inter_round_bits_vertical, const int subpixel_x,
     const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
     const int width, const int height, void* prediction,
     const ptrdiff_t pred_stride) {
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   const auto* src = static_cast<const Pixel*>(reference) - kHorizontalOffset;
   const ptrdiff_t src_stride = reference_stride / sizeof(Pixel);
   auto* dest = static_cast<uint16_t*>(prediction);
   const int filter_id = (subpixel_x >> 6) & kSubPixelMask;
-  const int bits_shift = kFilterBits - inter_round_bits[1];
+  const int bits_shift = kFilterBits - inter_round_bits_vertical;
   const int compound_round_offset =
       (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
   for (int y = 0; y < height; ++y) {
@@ -460,7 +489,7 @@
       for (int k = 0; k < kSubPixelTaps; ++k) {
         sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
       }
-      sum = RightShiftWithRounding(sum, inter_round_bits[0]) << bits_shift;
+      sum = RightShiftWithRounding(sum, kRoundBitsHorizontal) << bits_shift;
       dest[x] = sum + compound_round_offset;
     }
     src += src_stride;
@@ -468,7 +497,7 @@
   }
 }
 
-// This function is a simplified version of Convolve2D_C.
+// This function is a simplified version of ConvolveCompound2D_C.
 // It is called in compound prediction mode, where only vertical
 // filtering is required.
 // The output is not clipped to valid pixel range. Its output will be
@@ -478,18 +507,21 @@
                                 const ptrdiff_t reference_stride,
                                 const int /*horizontal_filter_index*/,
                                 const int vertical_filter_index,
-                                const uint8_t inter_round_bits[2],
+                                const uint8_t inter_round_bits_vertical,
                                 const int /*subpixel_x*/, const int subpixel_y,
                                 const int /*step_x*/, const int /*step_y*/,
                                 const int width, const int height,
                                 void* prediction, const ptrdiff_t pred_stride) {
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
   const ptrdiff_t src_stride = reference_stride / sizeof(Pixel);
   const auto* src =
       static_cast<const Pixel*>(reference) - kVerticalOffset * src_stride;
   auto* dest = static_cast<uint16_t*>(prediction);
   const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
-  const int bits_shift = kFilterBits - inter_round_bits[0];
+  const int bits_shift = kFilterBits - kRoundBitsHorizontal;
   const int compound_round_offset =
       (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
   for (int y = 0; y < height; ++y) {
@@ -500,7 +532,7 @@
                src[k * src_stride + x];
       }
       dest[x] = RightShiftWithRounding(LeftShift(sum, bits_shift),
-                                       inter_round_bits[1]) +
+                                       inter_round_bits_vertical) +
                 compound_round_offset;
     }
     src += src_stride;
@@ -518,7 +550,7 @@
 void ConvolveIntraBlockCopy2D_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const uint8_t /*inter_round_bits*/[2], const int /*subpixel_x*/,
+    const uint8_t /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
     const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
     const int width, const int height, void* prediction,
     const ptrdiff_t pred_stride) {
@@ -526,13 +558,14 @@
   const ptrdiff_t src_stride = reference_stride / sizeof(Pixel);
   auto* dest = reinterpret_cast<Pixel*>(prediction);
   const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel);
+  const int intermediate_height = height + 1;
   uint16_t intermediate_result[kMaxSuperBlockSizeInPixels *
-                               (2 * kMaxSuperBlockSizeInPixels + 8)];
+                               (kMaxSuperBlockSizeInPixels + 1)];
   uint16_t* intermediate = intermediate_result;
   // Note: allow vertical access to height + 1. Because this function is only
   // for u/v plane of intra block copy, such access is guaranteed to be within
   // the prediction block.
-  for (int y = 0; y <= height; ++y) {
+  for (int y = 0; y < intermediate_height; ++y) {
     for (int x = 0; x < width; ++x) {
       intermediate[x] = src[x] + src[x + 1];
     }
@@ -562,7 +595,7 @@
 void ConvolveIntraBlockCopy1D_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const uint8_t /*inter_round_bits*/[2], const int /*subpixel_x*/,
+    const uint8_t /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
     const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
     const int width, const int height, void* prediction,
     const ptrdiff_t pred_stride) {
@@ -584,15 +617,16 @@
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->convolve[0][0][0][0] = ConvolveCopy_C<8, uint8_t>;
   dsp->convolve[0][0][0][1] = ConvolveHorizontal_C<8, uint8_t>;
   dsp->convolve[0][0][1][0] = ConvolveVertical_C<8, uint8_t>;
-  dsp->convolve[0][0][1][1] = Convolve2DSingle_C<8, uint8_t>;
+  dsp->convolve[0][0][1][1] = Convolve2D_C<8, uint8_t>;
 
   dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_C<8, uint8_t>;
   dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_C<8, uint8_t>;
   dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_C<8, uint8_t>;
-  dsp->convolve[0][1][1][1] = Convolve2D_C<8, uint8_t>;
+  dsp->convolve[0][1][1][1] = ConvolveCompound2D_C<8, uint8_t>;
 
   dsp->convolve[1][0][0][0] = ConvolveCopy_C<8, uint8_t>;
   dsp->convolve[1][0][0][1] =
@@ -606,23 +640,78 @@
   dsp->convolve[1][1][1][0] = nullptr;
   dsp->convolve[1][1][1][1] = nullptr;
 
-  dsp->convolve_scale[0] = Convolve2DScaleSingle_C<8, uint8_t>;
-  dsp->convolve_scale[1] = Convolve2DScale_C<8, uint8_t>;
+  dsp->convolve_scale[0] = ConvolveScale2D_C<8, uint8_t>;
+  dsp->convolve_scale[1] = ConvolveCompoundScale2D_C<8, uint8_t>;
+#else  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCopy
+  dsp->convolve[0][0][0][0] = ConvolveCopy_C<8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ConvolveHorizontal
+  dsp->convolve[0][0][0][1] = ConvolveHorizontal_C<8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ConvolveVertical
+  dsp->convolve[0][0][1][0] = ConvolveVertical_C<8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_Convolve2D
+  dsp->convolve[0][0][1][1] = Convolve2D_C<8, uint8_t>;
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundCopy
+  dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_C<8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal
+  dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_C<8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundVertical
+  dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_C<8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompound2D
+  dsp->convolve[0][1][1][1] = ConvolveCompound2D_C<8, uint8_t>;
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopy
+  dsp->convolve[1][0][0][0] = ConvolveCopy_C<8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyHorizontal
+  dsp->convolve[1][0][0][1] =
+      ConvolveIntraBlockCopy1D_C<8, uint8_t, /*is_horizontal=*/true>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyVertical
+  dsp->convolve[1][0][1][0] =
+      ConvolveIntraBlockCopy1D_C<8, uint8_t, /*is_horizontal=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopy2D
+  dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_C<8, uint8_t>;
+#endif
+
+  dsp->convolve[1][1][0][0] = nullptr;
+  dsp->convolve[1][1][0][1] = nullptr;
+  dsp->convolve[1][1][1][0] = nullptr;
+  dsp->convolve[1][1][1][1] = nullptr;
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveScale2D
+  dsp->convolve_scale[0] = ConvolveScale2D_C<8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D
+  dsp->convolve_scale[1] = ConvolveCompoundScale2D_C<8, uint8_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 
 #if LIBGAV1_MAX_BITDEPTH >= 10
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
   assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->convolve[0][0][0][0] = ConvolveCopy_C<10, uint16_t>;
   dsp->convolve[0][0][0][1] = ConvolveHorizontal_C<10, uint16_t>;
   dsp->convolve[0][0][1][0] = ConvolveVertical_C<10, uint16_t>;
-  dsp->convolve[0][0][1][1] = Convolve2DSingle_C<10, uint16_t>;
+  dsp->convolve[0][0][1][1] = Convolve2D_C<10, uint16_t>;
 
   dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_C<10, uint16_t>;
   dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_C<10, uint16_t>;
   dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_C<10, uint16_t>;
-  dsp->convolve[0][1][1][1] = Convolve2D_C<10, uint16_t>;
+  dsp->convolve[0][1][1][1] = ConvolveCompound2D_C<10, uint16_t>;
 
   dsp->convolve[1][0][0][0] = ConvolveCopy_C<10, uint16_t>;
   dsp->convolve[1][0][0][1] =
@@ -636,8 +725,62 @@
   dsp->convolve[1][1][1][0] = nullptr;
   dsp->convolve[1][1][1][1] = nullptr;
 
-  dsp->convolve_scale[0] = Convolve2DScaleSingle_C<10, uint16_t>;
-  dsp->convolve_scale[1] = Convolve2DScale_C<10, uint16_t>;
+  dsp->convolve_scale[0] = ConvolveScale2D_C<10, uint16_t>;
+  dsp->convolve_scale[1] = ConvolveCompoundScale2D_C<10, uint16_t>;
+#else  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+#ifndef LIBGAV1_Dsp10bpp_ConvolveCopy
+  dsp->convolve[0][0][0][0] = ConvolveCopy_C<10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ConvolveHorizontal
+  dsp->convolve[0][0][0][1] = ConvolveHorizontal_C<10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ConvolveVertical
+  dsp->convolve[0][0][1][0] = ConvolveVertical_C<10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_Convolve2D
+  dsp->convolve[0][0][1][1] = Convolve2D_C<10, uint16_t>;
+#endif
+
+#ifndef LIBGAV1_Dsp10bpp_ConvolveCompoundCopy
+  dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_C<10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ConvolveCompoundHorizontal
+  dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_C<10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ConvolveCompoundVertical
+  dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_C<10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ConvolveCompound2D
+  dsp->convolve[0][1][1][1] = ConvolveCompound2D_C<10, uint16_t>;
+#endif
+
+#ifndef LIBGAV1_Dsp10bpp_ConvolveIntraBlockCopy
+  dsp->convolve[1][0][0][0] = ConvolveCopy_C<10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ConvolveIntraBlockHorizontal
+  dsp->convolve[1][0][0][1] =
+      ConvolveIntraBlockCopy1D_C<10, uint16_t, /*is_horizontal=*/true>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ConvolveIntraBlockVertical
+  dsp->convolve[1][0][1][0] =
+      ConvolveIntraBlockCopy1D_C<10, uint16_t, /*is_horizontal=*/false>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ConvolveIntraBlock2D
+  dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_C<10, uint16_t>;
+#endif
+
+  dsp->convolve[1][1][0][0] = nullptr;
+  dsp->convolve[1][1][0][1] = nullptr;
+  dsp->convolve[1][1][1][0] = nullptr;
+  dsp->convolve[1][1][1][1] = nullptr;
+
+#ifndef LIBGAV1_Dsp10bpp_ConvolveScale2D
+  dsp->convolve_scale[0] = ConvolveScale2D_C<10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ConvolveCompoundScale2D
+  dsp->convolve_scale[1] = ConvolveCompoundScale2D_C<10, uint16_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 #endif
 
diff --git a/libgav1/src/dsp/convolve.h b/libgav1/src/dsp/convolve.h
index 9839bb2..545e6ee 100644
--- a/libgav1/src/dsp/convolve.h
+++ b/libgav1/src/dsp/convolve.h
@@ -1,9 +1,29 @@
 #ifndef LIBGAV1_SRC_DSP_CONVOLVE_H_
 #define LIBGAV1_SRC_DSP_CONVOLVE_H_
 
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/convolve_neon.h"
+
+// x86:
+// Note: includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/convolve_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
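
The headers included above advertise which entry points they implement via
LIBGAV1_Dsp8bpp_* / LIBGAV1_Dsp10bpp_* defines; the #ifndef ladders added to
the Init functions in this patch then skip the C fallback for any entry point a
SIMD header has claimed. A minimal compilable sketch of that pattern (the macro
value and the sketch function are illustrative, not copied from the real
headers):

// Normally a header such as convolve_sse4.h defines this to claim the 8bpp
// ConvolveHorizontal entry point; the value used here is illustrative only.
#define LIBGAV1_Dsp8bpp_ConvolveHorizontal 1

inline void Init8bppSketch() {
#ifndef LIBGAV1_Dsp8bpp_ConvolveHorizontal
  // Only compiled when no SIMD header claimed the entry point; the real
  // Init8bpp() in convolve.cc assigns ConvolveHorizontal_C<8, uint8_t> here.
#endif
}
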
+
 namespace libgav1 {
 namespace dsp {
 
+// Initializes Dsp::convolve and Dsp::convolve_scale. This function is not
+// thread-safe.
 void ConvolveInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/distance_weighted_blend.cc b/libgav1/src/dsp/distance_weighted_blend.cc
index fda84ee..05e8e3f 100644
--- a/libgav1/src/dsp/distance_weighted_blend.cc
+++ b/libgav1/src/dsp/distance_weighted_blend.cc
@@ -12,27 +12,31 @@
 namespace {
 
 template <int bitdepth, typename Pixel>
-void DistanceWeightedBlending_C(const uint16_t* prediction_0,
-                                const ptrdiff_t prediction_stride_0,
-                                const uint16_t* prediction_1,
-                                const ptrdiff_t prediction_stride_1,
-                                const uint8_t weight_0, const uint8_t weight_1,
-                                const int inter_post_round_bit, const int width,
-                                const int height, void* const dest,
-                                const ptrdiff_t dest_stride) {
+void DistanceWeightedBlend_C(const uint16_t* prediction_0,
+                             const ptrdiff_t prediction_stride_0,
+                             const uint16_t* prediction_1,
+                             const ptrdiff_t prediction_stride_1,
+                             const uint8_t weight_0, const uint8_t weight_1,
+                             const int width, const int height,
+                             void* const dest, const ptrdiff_t dest_stride) {
   // An offset to cancel offsets used in compound predictor generation that
   // make intermediate computations non negative.
-  const int compound_round_offset =
-      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
+  constexpr int compound_round_offset =
+      (16 << (bitdepth + 4)) + (16 << (bitdepth + 3));
+  // 7.11.3.2 Rounding variables derivation process
+  //   2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7))
+  constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
 
   for (int y = 0; y < height; ++y) {
     for (int x = 0; x < width; ++x) {
-      int res = (prediction_0[x] * weight_0 + prediction_1[x] * weight_1) >> 4;
+      // prediction range: 8bpp: [0, 15471] 10bpp: [0, 61983] 12bpp: [0, 62007].
+      // weight_0 + weight_1 = 16.
+      int res = prediction_0[x] * weight_0 + prediction_1[x] * weight_1;
       res -= compound_round_offset;
       dst[x] = static_cast<Pixel>(
-          Clip3(RightShiftWithRounding(res, inter_post_round_bit), 0,
+          Clip3(RightShiftWithRounding(res, inter_post_round_bits + 4), 0,
                 (1 << bitdepth) - 1));
     }
     dst += dst_stride;
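
A small standalone sketch (not part of the patch) that exercises the 8bpp
arithmetic documented in the hunk above: the post-round shift is
2 * FILTER_BITS - (InterRound0 + InterRound1), the weights always sum to 16,
and blending two predictors that carry the same value cancels the compound
offset exactly. The predictor encoding offset + 16 * v is an assumption made
only to keep the check concrete.

#include <cassert>

namespace {

// Mirrors RightShiftWithRounding() from the source for non-negative inputs.
constexpr int RightShiftWithRounding(int value, int bits) {
  return (value + (1 << (bits - 1))) >> bits;
}

// 7.11.3.2: 2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7)).
static_assert(2 * 7 - (3 + 7) == 4, "8/10bpp inter_post_round_bits");
static_assert(2 * 7 - (5 + 7) == 2, "12bpp inter_post_round_bits");

}  // namespace

int main() {
  // 8bpp: per-sample offset added by the compound convolve paths, and the
  // 16x larger offset removed by the blend.
  constexpr int kCompoundOffset = (1 << (8 + 4)) + (1 << (8 + 3));  // 6144
  constexpr int kBlendOffset = (16 << (8 + 4)) + (16 << (8 + 3));   // 98304
  constexpr int kInterPostRoundBits = 4;
  const int weight_0 = 9;
  const int weight_1 = 7;  // weight_0 + weight_1 == 16
  for (int v = 0; v < 256; ++v) {
    const int p = kCompoundOffset + 16 * v;  // illustrative predictor encoding
    const int res = p * weight_0 + p * weight_1 - kBlendOffset;
    assert(RightShiftWithRounding(res, kInterPostRoundBits + 4) == v);
    static_cast<void>(res);
  }
  return 0;
}
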
@@ -44,14 +48,28 @@
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
-  dsp->distance_weighted_blend = DistanceWeightedBlending_C<8, uint8_t>;
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  dsp->distance_weighted_blend = DistanceWeightedBlend_C<8, uint8_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp8bpp_DistanceWeightedBlend
+  dsp->distance_weighted_blend = DistanceWeightedBlend_C<8, uint8_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 
 #if LIBGAV1_MAX_BITDEPTH >= 10
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
   assert(dsp != nullptr);
-  dsp->distance_weighted_blend = DistanceWeightedBlending_C<10, uint16_t>;
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  dsp->distance_weighted_blend = DistanceWeightedBlend_C<10, uint16_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp10bpp_DistanceWeightedBlend
+  dsp->distance_weighted_blend = DistanceWeightedBlend_C<10, uint16_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 #endif
 
diff --git a/libgav1/src/dsp/distance_weighted_blend.h b/libgav1/src/dsp/distance_weighted_blend.h
index d71f09a..b37bd64 100644
--- a/libgav1/src/dsp/distance_weighted_blend.h
+++ b/libgav1/src/dsp/distance_weighted_blend.h
@@ -1,9 +1,28 @@
 #ifndef LIBGAV1_SRC_DSP_DISTANCE_WEIGHTED_BLEND_H_
 #define LIBGAV1_SRC_DSP_DISTANCE_WEIGHTED_BLEND_H_
 
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/distance_weighted_blend_neon.h"
+
+// x86:
+// Note: includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/distance_weighted_blend_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
+
 namespace libgav1 {
 namespace dsp {
 
+// Initializes Dsp::distance_weighted_blend. This function is not thread-safe.
 void DistanceWeightedBlendInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/dsp.cc b/libgav1/src/dsp/dsp.cc
index c28692b..c4bfc96 100644
--- a/libgav1/src/dsp/dsp.cc
+++ b/libgav1/src/dsp/dsp.cc
@@ -2,9 +2,6 @@
 
 #include <mutex>  // NOLINT (unapproved c++11 header)
 
-#include "src/dsp/arm/intrapred_neon.h"
-#include "src/dsp/arm/loop_filter_neon.h"
-#include "src/dsp/arm/loop_restoration_neon.h"
 #include "src/dsp/average_blend.h"
 #include "src/dsp/cdef.h"
 #include "src/dsp/convolve.h"
@@ -16,15 +13,9 @@
 #include "src/dsp/inverse_transform.h"
 #include "src/dsp/loop_filter.h"
 #include "src/dsp/loop_restoration.h"
-#include "src/dsp/mask_blending.h"
+#include "src/dsp/mask_blend.h"
 #include "src/dsp/obmc.h"
 #include "src/dsp/warp.h"
-#include "src/dsp/x86/intra_edge_sse4.h"
-#include "src/dsp/x86/intrapred_smooth_sse4.h"
-#include "src/dsp/x86/intrapred_sse4.h"
-#include "src/dsp/x86/inverse_transform_sse4.h"
-#include "src/dsp/x86/loop_filter_sse4.h"
-#include "src/dsp/x86/loop_restoration_sse4.h"
 
 namespace libgav1 {
 namespace dsp_internal {
@@ -62,12 +53,15 @@
     InverseTransformInit_C();
     LoopFilterInit_C();
     LoopRestorationInit_C();
-    MaskBlendingInit_C();
+    MaskBlendInit_C();
     ObmcInit_C();
     WarpInit_C();
 #if LIBGAV1_ENABLE_SSE4_1
     const uint32_t cpu_features = GetCpuInfo();
     if ((cpu_features & kSSE4_1) != 0) {
+      AverageBlendInit_SSE4_1();
+      ConvolveInit_SSE4_1();
+      DistanceWeightedBlendInit_SSE4_1();
       IntraEdgeInit_SSE4_1();
       IntraPredInit_SSE4_1();
       IntraPredCflInit_SSE4_1();
@@ -75,16 +69,25 @@
       InverseTransformInit_SSE4_1();
       LoopFilterInit_SSE4_1();
       LoopRestorationInit_SSE4_1();
+      ObmcInit_SSE4_1();
     }
 #endif  // LIBGAV1_ENABLE_SSE4_1
 #if LIBGAV1_ENABLE_NEON
+    AverageBlendInit_NEON();
+    ConvolveInit_NEON();
+    DistanceWeightedBlendInit_NEON();
+    IntraEdgeInit_NEON();
     IntraPredCflInit_NEON();
     IntraPredDirectionalInit_NEON();
     IntraPredFilterIntraInit_NEON();
     IntraPredInit_NEON();
     IntraPredSmoothInit_NEON();
+    InverseTransformInit_NEON();
     LoopFilterInit_NEON();
     LoopRestorationInit_NEON();
+    MaskBlendInit_NEON();
+    ObmcInit_NEON();
+    WarpInit_NEON();
 #endif  // LIBGAV1_ENABLE_NEON
   });
 }
diff --git a/libgav1/src/dsp/dsp.h b/libgav1/src/dsp/dsp.h
index 8c7deb3..631a42c 100644
--- a/libgav1/src/dsp/dsp.h
+++ b/libgav1/src/dsp/dsp.h
@@ -306,7 +306,7 @@
 // signals the direction of the transform loop. |non_zero_coeff_count| is the
 // number of non zero coefficients in the block.
 using InverseTransformAddFunc = void (*)(TransformType tx_type,
-                                         TransformSize tx_size, int8_t bitdepth,
+                                         TransformSize tx_size,
                                          void* src_buffer, int start_x,
                                          int start_y, void* dst_frame,
                                          bool is_row, int non_zero_coeff_count);
@@ -339,7 +339,7 @@
 // plane.
 // |primary_strength|, |secondary_strength|, and |damping| are Cdef filtering
 // parameters.
-// |direction| is the filtering diretion.
+// |direction| is the filtering direction.
 // |dest| is the output buffer. |dest_stride| is given in bytes.
 using CdefFilteringFunc = void (*)(const void* source, ptrdiff_t source_stride,
                                    int rows4x4, int columns4x4, int curr_x,
@@ -362,7 +362,7 @@
     RestorationBuffer* buffer);
 
 // Index 0 is Wiener Filter.
-// Index 1 is Self Guilded Restoration Filter.
+// Index 1 is Self Guided Restoration Filter.
 // This can be accessed as LoopRestorationType - 2.
 using LoopRestorationFuncs = LoopRestorationFunc[2];
 
@@ -373,8 +373,9 @@
 // |vertical_filter_index|/|horizontal_filter_index| is the index to
 // retrieve the type of filter to be applied for vertical/horizontal direction
 // from the filter lookup table 'kSubPixelFilters'.
-// |inter_round_bits| is rounding prediction used in horizontal
-// (inter_round_bits[0]) and vertical (inter_round_bits[1]) filtering.
+// |inter_round_bits_vertical| is the rounding precision used after vertical
+// filtering (7 or 11). kInterRoundBitsHorizontal &
+// kInterRoundBitsHorizontal12bpp can be used after the horizontal pass.
 // |subpixel_x| and |subpixel_y| are starting positions in units of 1/1024.
 // |step_x| and |step_y| are step sizes in units of 1/1024 of a pixel.
 // |width| and |height| are width and height of the block to be filtered.
@@ -384,15 +385,22 @@
 using ConvolveFunc = void (*)(const void* reference, ptrdiff_t reference_stride,
                               int vertical_filter_index,
                               int horizontal_filter_index,
-                              const uint8_t inter_round_bits[2], int subpixel_x,
-                              int subpixel_y, int step_x, int step_y, int width,
-                              int height, void* prediction,
-                              ptrdiff_t pred_stride);
+                              const uint8_t inter_round_bits_vertical,
+                              int subpixel_x, int subpixel_y, int step_x,
+                              int step_y, int width, int height,
+                              void* prediction, ptrdiff_t pred_stride);
 
 // Convolve functions signature. Each points to one convolve function with
 // a specific setting:
 // ConvolveFunc[is_intra_block_copy][is_compound][has_vertical_filter]
 // [has_horizontal_filter].
+// If is_compound is false, the prediction is clipped to the valid pixel range.
+// If is_compound is true, the range of prediction is:
+//   8bpp: [0, 15471]
+//   10bpp: [0, 61983]
+//   12bpp: [0, 62007]
+// See:
+// https://docs.google.com/document/d/1f5YlLk02ETNxpilvsmjBtWgDXjtZYO33hjl6bAdvmxc
 using ConvolveFuncs = ConvolveFunc[2][2][2][2];
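
Since the 4-D table is easy to mis-index, a short sketch of how a caller would
pick an entry; GetConvolveFuncSketch is a hypothetical helper written only for
illustration and assumes the Dsp struct declared later in this header:

inline ConvolveFunc GetConvolveFuncSketch(const Dsp& dsp,
                                          bool is_intra_block_copy,
                                          bool is_compound,
                                          bool has_vertical_filter,
                                          bool has_horizontal_filter) {
  // Index order follows the comment above.
  return dsp.convolve[is_intra_block_copy ? 1 : 0][is_compound ? 1 : 0]
                     [has_vertical_filter ? 1 : 0]
                     [has_horizontal_filter ? 1 : 0];
}
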
 
 // Convolve functions signature for scaling version.
@@ -407,18 +415,15 @@
 // |prediction_0| is the first input block.
 // |prediction_1| is the second input block.
 // |prediction_stride_0| and |prediction_stride_1| are corresponding strides.
-// |inter_post_round_bit| is a rounding bit. It is required since the value
-// range of inputs is scaled in the inter frame prediction process.
 // |width| and |height| are the same for the first and second input blocks.
 // The valid range of block size is [8x8, 128x128] for the luma plane.
 // |dest| is the output buffer. |dest_stride| is the output buffer stride.
-using AverageBlendingFunc = void (*)(const uint16_t* prediction_0,
-                                     ptrdiff_t prediction_stride_0,
-                                     const uint16_t* prediction_1,
-                                     ptrdiff_t prediction_stride_1,
-                                     int inter_post_round_bit, int width,
-                                     int height, void* dest,
-                                     ptrdiff_t dest_stride);
+using AverageBlendFunc = void (*)(const uint16_t* prediction_0,
+                                  ptrdiff_t prediction_stride_0,
+                                  const uint16_t* prediction_1,
+                                  ptrdiff_t prediction_stride_1, int width,
+                                  int height, void* dest,
+                                  ptrdiff_t dest_stride);
 
 // Distance weighted blending function signature.
 // Weights are generated in Section 7.11.3.15.
@@ -432,25 +437,25 @@
 // distance of the first reference frame and the current frame.
 // |weight_1| is the weight for the second block. It is derived from the
 // relative distance of the second reference frame and the current frame.
-// |inter_post_round_bit| is a rounding bit. It is required since the value
-// range of inputs is scaled in the inter frame prediction process.
 // |width| and |height| are the same for the first and second input blocks.
 // The valid range of block size is [8x8, 128x128] for the luma plane.
 // |dest| is the output buffer. |dest_stride| is the output buffer stride.
-using DistanceWeightedBlendFunc =
-    void (*)(const uint16_t* prediction_0, ptrdiff_t prediction_stride_0,
-             const uint16_t* prediction_1, ptrdiff_t prediction_stride_1,
-             uint8_t weight_0, uint8_t weight_1, int inter_post_round_bit,
-             int width, int height, void* dest, ptrdiff_t dest_stride);
+using DistanceWeightedBlendFunc = void (*)(const uint16_t* prediction_0,
+                                           ptrdiff_t prediction_stride_0,
+                                           const uint16_t* prediction_1,
+                                           ptrdiff_t prediction_stride_1,
+                                           uint8_t weight_0, uint8_t weight_1,
+                                           int width, int height, void* dest,
+                                           ptrdiff_t dest_stride);
 
 // Mask blending function signature. Section 7.11.3.14.
-// This function takes two blocks and produces a blended output stored onto the
-// dest. The blending is a weighted average process, controlled by
-// values of the mask.
+// This function takes two blocks and produces a blended output stored into the
+// output block |dest|. The blending is a weighted average process, controlled
+// by values of the mask.
 // |prediction_0| is the first input block. When prediction mode is inter_intra
 // (or wedge_inter_intra), this refers to the inter frame prediction.
 // |prediction_stride_0| is the stride, given in units of uint16_t.
-// |prediction_1| is the second input block. When prediciton mode is inter_intra
+// |prediction_1| is the second input block. When prediction mode is inter_intra
 // (or wedge_inter_intra), this refers to the intra frame prediction.
 // |prediction_stride_1| is the stride, given in units of uint16_t.
 // |mask| is an integer array, whose value indicates the weight of the blending.
@@ -466,19 +471,24 @@
 // prediction blocks is from intra prediction of current frame. Otherwise, two
 // prediction blocks are both inter frame predictions.
 // |is_wedge_inter_intra| indicates if the mask is for the wedge prediction.
-// |inter_post_round_bits| is the rounding bits.
 // |dest| is the output block.
 // |dest_stride| is the corresponding stride for dest.
-using MaskBlendFunc =
-    void (*)(const uint16_t* prediction_0, ptrdiff_t prediction_stride_0,
-             const uint16_t* prediction_1, ptrdiff_t prediction_stride_1,
-             const uint8_t* mask, ptrdiff_t mask_stride, int width, int height,
-             int subsampling_x, int subsampling_y, bool is_inter_intra,
-             bool is_wedge_inter_intra, int inter_post_round_bits, void* dest,
-             ptrdiff_t dest_stride);
+using MaskBlendFunc = void (*)(const uint16_t* prediction_0,
+                               ptrdiff_t prediction_stride_0,
+                               const uint16_t* prediction_1,
+                               ptrdiff_t prediction_stride_1,
+                               const uint8_t* mask, ptrdiff_t mask_stride,
+                               int width, int height, void* dest,
+                               ptrdiff_t dest_stride);
 
-// Blending function signature. Section 7.11.3.10.
-// This function takes two blocks and produces a blended output stored onto the
+// Mask blending functions signature. Each points to one function with
+// a specific setting:
+// MaskBlendFunc[subsampling_x + subsampling_y][is_inter_intra].
+using MaskBlendFuncs = MaskBlendFunc[3][2];
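
The combined subsampling index can be non-obvious: subsampling_x +
subsampling_y is 0 for 4:4:4, 1 for 4:2:2 and 2 for 4:2:0. A hypothetical
helper, for illustration only, assuming the Dsp struct declared later in this
header:

inline MaskBlendFunc GetMaskBlendFuncSketch(const Dsp& dsp, int subsampling_x,
                                            int subsampling_y,
                                            bool is_inter_intra) {
  return dsp.mask_blend[subsampling_x + subsampling_y][is_inter_intra ? 1 : 0];
}
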
+
+// Obmc (overlapped block motion compensation) blending function signature.
+// Section 7.11.3.10.
+// This function takes two blocks and produces a blended output stored into the
 // first input block. The blending is a weighted average process, controlled by
 // values of the mask.
 // Obmc is not a compound mode. It is different from other compound blending,
@@ -488,15 +498,13 @@
 // |prediction| is the first input block, which will be overwritten.
 // |prediction_stride| is the stride, given in bytes.
 // |width|, |height| are the same for both input blocks.
-// |blending_direction|, 0 stands for the second block is above the
-// first block; 1 stands for the second block is to the left of the first block.
-// |mask| is an integer array, whose value indicates the weight of the blending.
 // |obmc_prediction| is the second input block.
 // |obmc_prediction_stride| is its stride, given in bytes.
 using ObmcBlendFunc = void (*)(void* prediction, ptrdiff_t prediction_stride,
-                               int width, int height, int blending_direction,
-                               const uint8_t* mask, const void* obmc_prediction,
+                               int width, int height,
+                               const void* obmc_prediction,
                                ptrdiff_t obmc_prediction_stride);
+using ObmcBlendFuncs = ObmcBlendFunc[kNumObmcDirections];
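
Per the removed comment, the two directions are: the overlapping block sits
above the current block, or to its left. A sketch of a caller applying both
blends, for illustration only; the index-to-direction mapping is taken from
that removed comment and the real enumerator names may differ:

inline void ObmcBlendBothSketch(const Dsp& dsp, void* pred,
                                ptrdiff_t pred_stride, int width, int height,
                                const void* above_pred, ptrdiff_t above_stride,
                                const void* left_pred, ptrdiff_t left_stride) {
  // Direction 0: the second block is above the first block.
  dsp.obmc_blend[0](pred, pred_stride, width, height, above_pred, above_stride);
  // Direction 1: the second block is to the left of the first block.
  dsp.obmc_blend[1](pred, pred_stride, width, height, left_pred, left_stride);
}
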
 
 // Warp function signature. Section 7.11.3.5.
 // This function applies warp filtering for each 8x8 block inside the current
@@ -512,19 +520,27 @@
 //     z .  y'  =   m4 m5 m1 *  y
 //          1]      m6 m7 1)    1]
 // |subsampling_x/y| is the current frame's plane subsampling factor.
-// |inter_round_bits| is rounding prediction used in horizontal
-// (inter_round_bits[0]) and vertical (inter_round_bits[1]) filtering.
+// |inter_round_bits_vertical| is the rounding precision used after vertical
+// filtering (7 or 11). kInterRoundBitsHorizontal &
+// kInterRoundBitsHorizontal12bpp can be used for the horizontal pass.
 // |block_start_x| and |block_start_y| are the starting position of the current
 // coding block.
 // |block_width| and |block_height| are width and height of the current coding
-// block.
-// |alpha|, |beta|, |gamma|, |delta| are warp parameters.
+// block. |block_width| and |block_height| are at least 8.
+// |alpha|, |beta|, |gamma|, |delta| are valid warp parameters. See the
+// comments in the definition of struct GlobalMotion for the range of their
+// values.
 // |dest| is the output buffer. It is a predictor, whose type is int16_t.
 // |dest_stride| is the stride, in units of int16_t.
+//
+// NOTE: The ARM NEON implementation of WarpFunc may read up to 13 bytes before
+// the |source| buffer or up to 14 bytes after the |source| buffer. Therefore,
+// there must be enough padding before and after the |source| buffer.
 using WarpFunc = void (*)(const void* source, ptrdiff_t source_stride,
                           int source_width, int source_height,
                           const int* warp_params, int subsampling_x,
-                          int subsampling_y, const uint8_t inter_round_bits[2],
+                          int subsampling_y,
+                          const uint8_t inter_round_bits_vertical,
                           int block_start_x, int block_start_y, int block_width,
                           int block_height, int16_t alpha, int16_t beta,
                           int16_t gamma, int16_t delta, uint16_t* dest,
@@ -576,10 +592,10 @@
   LoopRestorationFuncs loop_restorations;
   ConvolveFuncs convolve;
   ConvolveScaleFuncs convolve_scale;
-  AverageBlendingFunc average_blend;
+  AverageBlendFunc average_blend;
   DistanceWeightedBlendFunc distance_weighted_blend;
-  MaskBlendFunc mask_blend;
-  ObmcBlendFunc obmc_blend;
+  MaskBlendFuncs mask_blend;
+  ObmcBlendFuncs obmc_blend;
   WarpFunc warp;
   FilmGrainSynthesisFunc film_grain_synthesis;
 };
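
Regarding the NOTE on WarpFunc above: a standalone caller has to guarantee 13
readable bytes before and 14 after the |source| region. A minimal sketch of
such a setup (illustrative only; the decoder's own frame buffers already carry
border padding):

#include <cstddef>
#include <cstdint>
#include <vector>

uint8_t* MakePaddedWarpSource(int source_width, int source_height,
                              std::vector<uint8_t>* storage) {
  constexpr int kBytesReadableBefore = 13;  // from the NOTE on WarpFunc
  constexpr int kBytesReadableAfter = 14;   // from the NOTE on WarpFunc
  storage->resize(kBytesReadableBefore +
                  static_cast<std::size_t>(source_width) * source_height +
                  kBytesReadableAfter);
  return storage->data() + kBytesReadableBefore;
}
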
diff --git a/libgav1/src/dsp/film_grain.cc b/libgav1/src/dsp/film_grain.cc
index cce73d3..b26b517 100644
--- a/libgav1/src/dsp/film_grain.cc
+++ b/libgav1/src/dsp/film_grain.cc
@@ -240,14 +240,28 @@
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->film_grain_synthesis = FilmGrainSynthesis_C<8>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp8bpp_FilmGrainSynthesis
+  dsp->film_grain_synthesis = FilmGrainSynthesis_C<8>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 
 #if LIBGAV1_MAX_BITDEPTH >= 10
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
   assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->film_grain_synthesis = FilmGrainSynthesis_C<10>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp10bpp_FilmGrainSynthesis
+  dsp->film_grain_synthesis = FilmGrainSynthesis_C<10>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 #endif
 
@@ -297,34 +311,41 @@
   const int grain_center = 128 << (bitdepth - 8);
   grain_min_ = -grain_center;
   grain_max_ = grain_center - 1;
+}
 
+template <int bitdepth>
+bool FilmGrain<bitdepth>::Init() {
   // Section 7.18.3.3. Generate grain process.
-  GenerateLumaGrain(params, luma_grain_);
-  ApplyAutoRegressiveFilterToLumaGrain(params, grain_min_, grain_max_,
+  GenerateLumaGrain(params_, luma_grain_);
+  ApplyAutoRegressiveFilterToLumaGrain(params_, grain_min_, grain_max_,
                                        luma_grain_);
-  if (!is_monochrome) {
-    GenerateChromaGrains(params, chroma_width_, chroma_height_, u_grain_,
+  if (!is_monochrome_) {
+    GenerateChromaGrains(params_, chroma_width_, chroma_height_, u_grain_,
                          v_grain_);
     ApplyAutoRegressiveFilterToChromaGrains(
-        params, grain_min_, grain_max_, luma_grain_, subsampling_x,
-        subsampling_y, chroma_width_, chroma_height_, u_grain_, v_grain_);
+        params_, grain_min_, grain_max_, luma_grain_, subsampling_x_,
+        subsampling_y_, chroma_width_, chroma_height_, u_grain_, v_grain_);
   }
 
   // Section 7.18.3.4. Scaling lookup initialization process.
-  InitializeScalingLookupTable(params.num_y_points, params.point_y_value,
-                               params.point_y_scaling, scaling_lut_y_);
-  if (!is_monochrome) {
-    if (params.chroma_scaling_from_luma) {
-      static_assert(sizeof(scaling_lut_y_) == 256, "");
-      memcpy(scaling_lut_u_, scaling_lut_y_, sizeof(scaling_lut_y_));
-      memcpy(scaling_lut_v_, scaling_lut_y_, sizeof(scaling_lut_y_));
+  InitializeScalingLookupTable(params_.num_y_points, params_.point_y_value,
+                               params_.point_y_scaling, scaling_lut_y_);
+  if (!is_monochrome_) {
+    if (params_.chroma_scaling_from_luma) {
+      scaling_lut_u_ = scaling_lut_y_;
+      scaling_lut_v_ = scaling_lut_y_;
     } else {
-      InitializeScalingLookupTable(params.num_u_points, params.point_u_value,
-                                   params.point_u_scaling, scaling_lut_u_);
-      InitializeScalingLookupTable(params.num_v_points, params.point_v_value,
-                                   params.point_v_scaling, scaling_lut_v_);
+      scaling_lut_chroma_buffer_.reset(new (std::nothrow) uint8_t[256 * 2]);
+      if (scaling_lut_chroma_buffer_ == nullptr) return false;
+      scaling_lut_u_ = &scaling_lut_chroma_buffer_[0];
+      scaling_lut_v_ = &scaling_lut_chroma_buffer_[256];
+      InitializeScalingLookupTable(params_.num_u_points, params_.point_u_value,
+                                   params_.point_u_scaling, scaling_lut_u_);
+      InitializeScalingLookupTable(params_.num_v_points, params_.point_v_value,
+                                   params_.point_v_scaling, scaling_lut_v_);
     }
   }
+  return true;
 }
 
 // Section 7.18.3.2.
@@ -816,6 +837,10 @@
     const void* source_plane_v, ptrdiff_t source_stride_v, void* dest_plane_y,
     ptrdiff_t dest_stride_y, void* dest_plane_u, ptrdiff_t dest_stride_u,
     void* dest_plane_v, ptrdiff_t dest_stride_v) {
+  if (!Init()) {
+    LIBGAV1_DLOG(ERROR, "Init() failed.");
+    return false;
+  }
   if (!AllocateNoiseStripes()) {
     LIBGAV1_DLOG(ERROR, "AllocateNoiseStripes() failed.");
     return false;
diff --git a/libgav1/src/dsp/film_grain.h b/libgav1/src/dsp/film_grain.h
index 307dbb2..abc7312 100644
--- a/libgav1/src/dsp/film_grain.h
+++ b/libgav1/src/dsp/film_grain.h
@@ -77,6 +77,8 @@
   using Pixel =
       typename std::conditional<bitdepth == 8, uint8_t, uint16_t>::type;
 
+  bool Init();
+
   // Allocates noise_stripe_, which points to memory owned by noise_buffer_.
   bool AllocateNoiseStripes();
 
@@ -127,8 +129,12 @@
   GrainType v_grain_[kMaxChromaHeight * kMaxChromaWidth];
   // Scaling lookup tables.
   uint8_t scaling_lut_y_[256];
-  uint8_t scaling_lut_u_[256];
-  uint8_t scaling_lut_v_[256];
+  uint8_t* scaling_lut_u_ = nullptr;
+  uint8_t* scaling_lut_v_ = nullptr;
+  // If allocated, this buffer is 256 * 2 bytes long and scaling_lut_u_ and
+  // scaling_lut_v_ point into this buffer. Otherwise, scaling_lut_u_ and
+  // scaling_lut_v_ point to scaling_lut_y_.
+  std::unique_ptr<uint8_t[]> scaling_lut_chroma_buffer_;
 
   // A two-dimensional array of noise data. Generated for each 32 luma sample
   // high stripe of the image. The first dimension is called luma_num. The
diff --git a/libgav1/src/dsp/intra_edge.cc b/libgav1/src/dsp/intra_edge.cc
index a29ce3c..7f0bc4c 100644
--- a/libgav1/src/dsp/intra_edge.cc
+++ b/libgav1/src/dsp/intra_edge.cc
@@ -59,6 +59,7 @@
   dsp->intra_edge_filter = IntraEdgeFilter_C<uint8_t>;
   dsp->intra_edge_upsampler = IntraEdgeUpsampler_C<8, uint8_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp8bpp_IntraEdgeFilter
   dsp->intra_edge_filter = IntraEdgeFilter_C<uint8_t>;
 #endif
@@ -76,6 +77,7 @@
   dsp->intra_edge_filter = IntraEdgeFilter_C<uint16_t>;
   dsp->intra_edge_upsampler = IntraEdgeUpsampler_C<10, uint16_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp10bpp_IntraEdgeFilter
   dsp->intra_edge_filter = IntraEdgeFilter_C<uint16_t>;
 #endif
diff --git a/libgav1/src/dsp/intra_edge.h b/libgav1/src/dsp/intra_edge.h
index 24fc96c..13664c7 100644
--- a/libgav1/src/dsp/intra_edge.h
+++ b/libgav1/src/dsp/intra_edge.h
@@ -1,6 +1,24 @@
 #ifndef LIBGAV1_SRC_DSP_INTRA_EDGE_H_
 #define LIBGAV1_SRC_DSP_INTRA_EDGE_H_
 
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/intra_edge_neon.h"
+
+// x86:
+// Note: includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/intra_edge_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
+
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/intrapred.h b/libgav1/src/dsp/intrapred.h
index 7b6d801..3038372 100644
--- a/libgav1/src/dsp/intrapred.h
+++ b/libgav1/src/dsp/intrapred.h
@@ -3,18 +3,28 @@
 
 // Pull in LIBGAV1_DspXXX defines representing the implementation status
 // of each function. The resulting value of each can be used by each module to
-// determine whether an implementation is needed at compile time. The order of
-// includes is important as each tests for a superior version before setting
-// the base.
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
 
+// ARM:
 #include "src/dsp/arm/intrapred_neon.h"
+
+// x86:
+// Note: includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
 #include "src/dsp/x86/intrapred_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::intra_predictors and Dsp::filter_intra_predictor with base
-// implementations. This function is not thread-safe.
+// Initializes Dsp::intra_predictors, Dsp::directional_intra_predictor_zone*,
+// Dsp::cfl_intra_predictors, Dsp::cfl_subsamplers and
+// Dsp::filter_intra_predictor. This function is not thread-safe.
 void IntraPredInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/inverse_transform.cc b/libgav1/src/dsp/inverse_transform.cc
index 66110af..bd9c324 100644
--- a/libgav1/src/dsp/inverse_transform.cc
+++ b/libgav1/src/dsp/inverse_transform.cc
@@ -17,8 +17,6 @@
 // Include the constants and utility functions inside the anonymous namespace.
 #include "src/dsp/inverse_transform.inc"
 
-constexpr uint8_t kTransformRowShift[kNumTransformSizes] = {
-    0, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2};
 constexpr uint8_t kTransformColumnShift = 4;
 
 int32_t RangeCheckValue(int32_t value, int8_t range) {
@@ -32,7 +30,7 @@
                  value, range);
     assert(min <= value && value <= max);
   }
-#endif  // LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECKING
+#endif  // LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK
   static_cast<void>(range);
   return value;
 }
@@ -599,9 +597,6 @@
 // because the multiplicative factor of inverse identity transforms is at most
 // 4 (2 bits) and |shift| is always 4.
 
-constexpr int32_t kIdentity4Multiplier /* round(2^12 * sqrt(2)) */ = 5793;
-constexpr int32_t kIdentity16Multiplier /* 2 * round(2^12 * sqrt(2)) */ = 11586;
-
 template <typename Residual>
 void Identity4Row_C(void* dest, const void* source, int8_t shift) {
   assert(shift == 0 || shift == 1);
@@ -732,32 +727,34 @@
   temp[1] = src[3] >> shift;
   temp[0] += temp[2];
   temp[3] -= temp[1];
+  // This signed right shift must be an arithmetic shift.
   Residual e = (temp[0] - temp[3]) >> 1;
-  temp[1] = e - temp[1];
-  temp[2] = e - temp[2];
-  temp[0] -= temp[1];
-  temp[3] += temp[2];
-  memcpy(dst, temp, sizeof(temp));
+  dst[1] = e - temp[1];
+  dst[2] = e - temp[2];
+  dst[0] = temp[0] - dst[1];
+  dst[3] = temp[3] + dst[2];
 }
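
The comment above about the signed right shift is load-bearing: right-shifting
a negative signed value is implementation-defined before C++20 (and guaranteed
to be arithmetic from C++20 on). A one-line compile-time check of that
assumption, as a sketch:

static_assert((-4 >> 1) == -2, "signed right shift must be arithmetic");
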
 
 //------------------------------------------------------------------------------
 // row/column transform loop
-constexpr int16_t kTransformRowMultiplier /* round(2^12 / sqrt(2)) */ = 2896;
 
 using InverseTransform1DFunc = void (*)(void* dst, const void* src,
                                         int8_t range);
 
-template <typename Residual, typename Pixel, Transform1D transform1d_type,
+template <int bitdepth, typename Residual, typename Pixel,
+          Transform1D transform1d_type,
           InverseTransform1DFunc row_transform1d_func,
           InverseTransform1DFunc column_transform1d_func = row_transform1d_func>
 void TransformLoop_C(TransformType tx_type, TransformSize tx_size,
-                     int8_t bitdepth, void* src_buffer, int start_x,
-                     int start_y, void* dst_frame, bool is_row,
-                     int non_zero_coeff_count) {
+                     void* src_buffer, int start_x, int start_y,
+                     void* dst_frame, bool is_row, int non_zero_coeff_count) {
   constexpr bool lossless = transform1d_type == k1DTransformWht;
   constexpr bool is_identity = transform1d_type == k1DTransformIdentity;
-  const int tx_width = kTransformWidth[tx_size];
-  const int tx_height = kTransformHeight[tx_size];
+  // The transform size of the WHT is always 4x4. Setting tx_width and
+  // tx_height to the constant 4 for the WHT speeds the code up.
+  assert(!lossless || tx_size == kTransformSize4x4);
+  const int tx_width = lossless ? 4 : kTransformWidth[tx_size];
+  const int tx_height = lossless ? 4 : kTransformHeight[tx_size];
   const int tx_width_log2 = kTransformWidthLog2[tx_size];
   const int tx_height_log2 = kTransformHeightLog2[tx_size];
   auto* frame = reinterpret_cast<Array2DView<Pixel>*>(dst_frame);
@@ -825,9 +822,9 @@
   // transforms, this will be equal to the clamping range.
   const int8_t column_clamp_range = lossless ? 0 : std::max(bitdepth + 6, 16);
   const bool flip_rows = transform1d_type == k1DTransformAdst &&
-                         ((1U << tx_type) & kTransformFlipRowsMask) != 0;
+                         kTransformFlipRowsMask.Contains(tx_type);
   const bool flip_columns =
-      !lossless && ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
+      !lossless && kTransformFlipColumnsMask.Contains(tx_type);
   const int min_value = 0;
   const int max_value = (1 << bitdepth) - 1;
   // Note: 64 is the maximum size of a 1D transform buffer (the largest
@@ -858,45 +855,54 @@
 
 //------------------------------------------------------------------------------
 
-template <typename Residual, typename Pixel>
+template <int bitdepth, typename Residual, typename Pixel>
 void InitAll(Dsp* const dsp) {
   // Maximum transform size for Dct is 64.
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformDct] =
-      TransformLoop_C<Residual, Pixel, k1DTransformDct, Dct_C<Residual, 2>>;
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+                      Dct_C<Residual, 2>>;
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformDct] =
-      TransformLoop_C<Residual, Pixel, k1DTransformDct, Dct_C<Residual, 3>>;
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+                      Dct_C<Residual, 3>>;
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformDct] =
-      TransformLoop_C<Residual, Pixel, k1DTransformDct, Dct_C<Residual, 4>>;
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+                      Dct_C<Residual, 4>>;
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformDct] =
-      TransformLoop_C<Residual, Pixel, k1DTransformDct, Dct_C<Residual, 5>>;
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+                      Dct_C<Residual, 5>>;
   dsp->inverse_transforms[k1DTransformSize64][k1DTransformDct] =
-      TransformLoop_C<Residual, Pixel, k1DTransformDct, Dct_C<Residual, 6>>;
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+                      Dct_C<Residual, 6>>;
 
   // Maximum transform size for Adst is 16.
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformAdst] =
-      TransformLoop_C<Residual, Pixel, k1DTransformAdst, Adst4_C<Residual>>;
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
+                      Adst4_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformAdst] =
-      TransformLoop_C<Residual, Pixel, k1DTransformAdst, Adst8_C<Residual>>;
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
+                      Adst8_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformAdst] =
-      TransformLoop_C<Residual, Pixel, k1DTransformAdst, Adst16_C<Residual>>;
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
+                      Adst16_C<Residual>>;
 
   // Maximum transform size for Identity transform is 32.
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformIdentity] =
-      TransformLoop_C<Residual, Pixel, k1DTransformIdentity,
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
                       Identity4Row_C<Residual>, Identity4Column_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformIdentity] =
-      TransformLoop_C<Residual, Pixel, k1DTransformIdentity,
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
                       Identity8Row_C<Residual>, Identity8Column_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformIdentity] =
-      TransformLoop_C<Residual, Pixel, k1DTransformIdentity,
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
                       Identity16Row_C<Residual>, Identity16Column_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformIdentity] =
-      TransformLoop_C<Residual, Pixel, k1DTransformIdentity,
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
                       Identity32Row_C<Residual>, Identity32Column_C<Residual>>;
 
   // Maximum transform size for Wht is 4.
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformWht] =
-      TransformLoop_C<Residual, Pixel, k1DTransformWht, Wht4_C<Residual>>;
+      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformWht,
+                      Wht4_C<Residual>>;
 }
 
 void Init8bpp() {
@@ -908,63 +914,63 @@
     }
   }
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  InitAll<int16_t, uint8_t>(dsp);
+  InitAll<8, int16_t, uint8_t>(dsp);
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformDct] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 2>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 2>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformDct] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 3>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 3>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformDct] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 4>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 4>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformDct] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 5>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 5>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize64][k1DTransformDct] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 6>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 6>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformAdst] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformAdst, Adst4_C<int16_t>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, Adst4_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformAdst] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformAdst, Adst8_C<int16_t>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, Adst8_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformAdst] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformAdst, Adst16_C<int16_t>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, Adst16_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformIdentity] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformIdentity,
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
                       Identity4Row_C<int16_t>, Identity4Column_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformIdentity] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformIdentity,
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
                       Identity8Row_C<int16_t>, Identity8Column_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformIdentity] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformIdentity,
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
                       Identity16Row_C<int16_t>, Identity16Column_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformIdentity] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformIdentity,
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
                       Identity32Row_C<int16_t>, Identity32Column_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformWht] =
-      TransformLoop_C<int16_t, uint8_t, k1DTransformWht, Wht4_C<int16_t>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformWht, Wht4_C<int16_t>>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -979,63 +985,71 @@
     }
   }
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  InitAll<int32_t, uint16_t>(dsp);
+  InitAll<10, int32_t, uint16_t>(dsp);
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformDct] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformDct, Dct_C<int32_t, 2>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+                      Dct_C<int32_t, 2>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformDct] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformDct, Dct_C<int32_t, 3>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+                      Dct_C<int32_t, 3>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformDct] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformDct, Dct_C<int32_t, 4>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+                      Dct_C<int32_t, 4>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize32_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformDct] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformDct, Dct_C<int32_t, 5>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+                      Dct_C<int32_t, 5>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize64_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize64][k1DTransformDct] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformDct, Dct_C<int32_t, 6>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+                      Dct_C<int32_t, 6>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformAdst] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformAdst, Adst4_C<int32_t>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
+                      Adst4_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformAdst] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformAdst, Adst8_C<int32_t>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
+                      Adst8_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformAdst] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformAdst, Adst16_C<int32_t>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
+                      Adst16_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformIdentity] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformIdentity,
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
                       Identity4Row_C<int32_t>, Identity4Column_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformIdentity] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformIdentity,
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
                       Identity8Row_C<int32_t>, Identity8Column_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformIdentity] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformIdentity,
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
                       Identity16Row_C<int32_t>, Identity16Column_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize32_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformIdentity] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformIdentity,
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
                       Identity32Row_C<int32_t>, Identity32Column_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformWht
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformWht] =
-      TransformLoop_C<int32_t, uint16_t, k1DTransformWht, Wht4_C<int32_t>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformWht, Wht4_C<int32_t>>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -1048,6 +1062,11 @@
 #if LIBGAV1_MAX_BITDEPTH >= 10
   Init10bpp();
 #endif
+
+  // Local functions that may be unused depending on the optimizations
+  // available.
+  static_cast<void>(RangeCheckValue);
+  static_cast<void>(kBitReverseLookup);
 }
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/inverse_transform.h b/libgav1/src/dsp/inverse_transform.h
index 63bc3c8..28faec7 100644
--- a/libgav1/src/dsp/inverse_transform.h
+++ b/libgav1/src/dsp/inverse_transform.h
@@ -1,6 +1,24 @@
 #ifndef LIBGAV1_SRC_DSP_INVERSE_TRANSFORM_H_
 #define LIBGAV1_SRC_DSP_INVERSE_TRANSFORM_H_
 
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/inverse_transform_neon.h"
+
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/inverse_transform_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
+
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/inverse_transform.inc b/libgav1/src/dsp/inverse_transform.inc
index 3b5b052..3bffa23 100644
--- a/libgav1/src/dsp/inverse_transform.inc
+++ b/libgav1/src/dsp/inverse_transform.inc
@@ -35,3 +35,16 @@
 // The value for index i is derived as:
 // round(sqrt(2) * sin(i * pi / 9) * 2 / 3 * (1 << 12)).
 constexpr int16_t kAdst4Multiplier[4] = {1321, 2482, 3344, 3803};
+
+constexpr uint8_t kTransformRowShift[kNumTransformSizes] = {
+    0, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2};
+
+constexpr bool kShouldRound[kNumTransformSizes] = {
+    false, true,  false, true, false, true, false, false, true, false,
+    true,  false, false, true, false, true, false, true,  false};
+
+constexpr int16_t kIdentity4Multiplier /* round(2^12 * sqrt(2)) */ = 0x16A1;
+constexpr int16_t kIdentity4MultiplierFraction /* round(2^12 * (sqrt(2) - 1))*/
+    = 0x6A1;
+constexpr int16_t kIdentity16Multiplier /* 2 * round(2^12 * sqrt(2)) */ = 11586;
+constexpr int16_t kTransformRowMultiplier /* round(2^12 / sqrt(2)) */ = 2896;
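The comments on the new multipliers give their closed forms; a small sketch, assuming only <cmath>, that checks they round to the quoted fixed-point values:

    #include <cassert>
    #include <cmath>

    int main() {
      const double root2 = std::sqrt(2.0);
      assert(static_cast<int>(std::round(4096 * root2)) == 0x16A1);         // 5793
      assert(static_cast<int>(std::round(4096 * (root2 - 1.0))) == 0x6A1);  // 1697
      assert(2 * static_cast<int>(std::round(4096 * root2)) == 11586);
      assert(static_cast<int>(std::round(4096 / root2)) == 2896);
      return 0;
    }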
diff --git a/libgav1/src/dsp/loop_filter.h b/libgav1/src/dsp/loop_filter.h
index ee33b91..c9e602c 100644
--- a/libgav1/src/dsp/loop_filter.h
+++ b/libgav1/src/dsp/loop_filter.h
@@ -3,18 +3,26 @@
 
 // Pull in LIBGAV1_DspXXX defines representing the implementation status
 // of each function. The resulting value of each can be used by each module to
-// determine whether an implementation is needed at compile time. The order of
-// includes is important as each tests for a superior version before setting
-// the base.
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
 
+// ARM:
 #include "src/dsp/arm/loop_filter_neon.h"
+
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
 #include "src/dsp/x86/loop_filter_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::loop_filters with base implementations. This function
-// is not thread-safe.
+// Initializes Dsp::loop_filters. This function is not thread-safe.
 void LoopFilterInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/loop_restoration.cc b/libgav1/src/dsp/loop_restoration.cc
index 5f092c5..0a07c0e 100644
--- a/libgav1/src/dsp/loop_restoration.cc
+++ b/libgav1/src/dsp/loop_restoration.cc
@@ -90,15 +90,14 @@
     const RestorationUnitInfo& restoration_info, ptrdiff_t source_stride,
     ptrdiff_t dest_stride, int width, int height,
     RestorationBuffer* const buffer) {
-  const int* const inter_round_bits = buffer->inter_round_bits;
-  if (bitdepth == 12) {
-    assert(inter_round_bits[0] == 5 && inter_round_bits[1] == 9);
-  } else {
-    assert(inter_round_bits[0] == 3 && inter_round_bits[1] == 11);
-  }
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
+  constexpr int kRoundBitsVertical =
+      (bitdepth == 12) ? kInterRoundBitsVertical12bpp : kInterRoundBitsVertical;
   int16_t filter[kSubPixelTaps - 1];
   const int limit =
-      (1 << (bitdepth + 1 + kWienerFilterBits - inter_round_bits[0])) - 1;
+      (1 << (bitdepth + 1 + kWienerFilterBits - kRoundBitsHorizontal)) - 1;
   const auto* src = static_cast<const Pixel*>(source);
   auto* dst = static_cast<Pixel*>(dest);
   source_stride /= sizeof(Pixel);
@@ -117,10 +116,7 @@
       for (int k = 0; k < kSubPixelTaps - 1; ++k) {
         sum += filter[k] * src[x + k];
       }
-      const int rounded_sum = RightShiftWithRounding(sum, inter_round_bits[0]);
-      // TODO(chengchen): make sure the horizontal and vertical rounding offset
-      // is correct and whether they ensure rounded_sum is non-negative.
-      // If yes, replace Clip3() with std::min().
+      const int rounded_sum = RightShiftWithRounding(sum, kRoundBitsHorizontal);
       wiener_buffer[x] = static_cast<uint16_t>(Clip3(rounded_sum, 0, limit));
     }
     src += source_stride;
@@ -129,7 +125,7 @@
   wiener_buffer = buffer->wiener_buffer;
   // vertical filtering.
   PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical, filter);
-  const int vertical_rounding = -(1 << (bitdepth + inter_round_bits[1] - 1));
+  const int vertical_rounding = -(1 << (bitdepth + kRoundBitsVertical - 1));
   for (int y = 0; y < height; ++y) {
     for (int x = 0; x < width; ++x) {
       // sum needs 32 bits.
@@ -137,7 +133,7 @@
       for (int k = 0; k < kSubPixelTaps - 1; ++k) {
         sum += filter[k] * wiener_buffer[k * buffer_stride + x];
       }
-      const int rounded_sum = RightShiftWithRounding(sum, inter_round_bits[1]);
+      const int rounded_sum = RightShiftWithRounding(sum, kRoundBitsVertical);
       dst[x] = static_cast<Pixel>(Clip3(rounded_sum, 0, (1 << bitdepth) - 1));
     }
     dst += dest_stride;
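A standalone sketch of the compile-time round-bit selection introduced above. The numeric values are the ones implied by the asserts this change removes (3 and 11 for 8/10-bit, 5 and 9 for 12-bit); they are re-declared locally here only for illustration.

    // Assumed values, mirroring libgav1's kInterRoundBits* constants.
    constexpr int kRoundBitsH = 3, kRoundBitsH12bpp = 5;
    constexpr int kRoundBitsV = 11, kRoundBitsV12bpp = 9;

    template <int bitdepth>
    constexpr int WienerRoundBitsHorizontal() {
      return (bitdepth == 12) ? kRoundBitsH12bpp : kRoundBitsH;
    }
    template <int bitdepth>
    constexpr int WienerRoundBitsVertical() {
      return (bitdepth == 12) ? kRoundBitsV12bpp : kRoundBitsV;
    }

    // Horizontal and vertical round bits sum to 14 at every bitdepth, matching
    // the runtime asserts this change removes.
    static_assert(
        WienerRoundBitsHorizontal<8>() + WienerRoundBitsVertical<8>() == 14, "");
    static_assert(
        WienerRoundBitsHorizontal<12>() + WienerRoundBitsVertical<12>() == 14, "");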
diff --git a/libgav1/src/dsp/loop_restoration.h b/libgav1/src/dsp/loop_restoration.h
index 59c51a3..0fa4e0f 100644
--- a/libgav1/src/dsp/loop_restoration.h
+++ b/libgav1/src/dsp/loop_restoration.h
@@ -1,14 +1,28 @@
 #ifndef LIBGAV1_SRC_DSP_LOOP_RESTORATION_H_
 #define LIBGAV1_SRC_DSP_LOOP_RESTORATION_H_
 
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
 #include "src/dsp/arm/loop_restoration_neon.h"
+
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
 #include "src/dsp/x86/loop_restoration_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::loop_restorations with base implementations. This function
-// is not thread-safe.
+// Initializes Dsp::loop_restorations. This function is not thread-safe.
 void LoopRestorationInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/mask_blend.cc b/libgav1/src/dsp/mask_blend.cc
new file mode 100644
index 0000000..a6e6805
--- /dev/null
+++ b/libgav1/src/dsp/mask_blend.cc
@@ -0,0 +1,161 @@
+#include "src/dsp/mask_blend.h"
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+template <int bitdepth, typename Pixel, bool is_inter_intra, int subsampling_x,
+          int subsampling_y>
+void MaskBlend_C(const uint16_t* prediction_0,
+                 const ptrdiff_t prediction_stride_0,
+                 const uint16_t* prediction_1,
+                 const ptrdiff_t prediction_stride_1, const uint8_t* mask,
+                 const ptrdiff_t mask_stride, const int width, const int height,
+                 void* dest, const ptrdiff_t dest_stride) {
+  assert(mask != nullptr);
+  auto* dst = static_cast<Pixel*>(dest);
+  const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
+  constexpr int step_y = subsampling_y ? 2 : 1;
+  const uint8_t* mask_next_row = mask + mask_stride;
+  // An offset to cancel offsets used in single predictor generation that
+  // make intermediate computations non negative.
+  const int single_round_offset = (1 << bitdepth) + (1 << (bitdepth - 1));
+  // An offset to cancel offsets used in compound predictor generation that
+  // make intermediate computations non negative.
+  const int compound_round_offset =
+      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
+  // 7.11.3.2 Rounding variables derivation process
+  //   2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7))
+  constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
+  for (int y = 0; y < height; ++y) {
+    for (int x = 0; x < width; ++x) {
+      uint8_t mask_value;
+      if ((subsampling_x | subsampling_y) == 0) {
+        mask_value = mask[x];
+      } else if (subsampling_x == 1 && subsampling_y == 0) {
+        mask_value = static_cast<uint8_t>(RightShiftWithRounding(
+            mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1], 1));
+      } else {
+        assert(subsampling_x == 1 && subsampling_y == 1);
+        mask_value = static_cast<uint8_t>(RightShiftWithRounding(
+            mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1] +
+                mask_next_row[MultiplyBy2(x)] +
+                mask_next_row[MultiplyBy2(x) + 1],
+            2));
+      }
+
+      if (is_inter_intra) {
+        // In inter intra prediction mode, the intra prediction (prediction_1)
+        // values are valid pixel values: [0, (1 << bitdepth) - 1], while the
+        // inter prediction values come from subpixel prediction of another
+        // frame, which involves interpolation and rounding. Therefore
+        // prediction_0 has to be clipped.
+        dst[x] = static_cast<Pixel>(RightShiftWithRounding(
+            mask_value * prediction_1[x] +
+                (64 - mask_value) * Clip3(prediction_0[x] - single_round_offset,
+                                          0, (1 << bitdepth) - 1),
+            6));
+      } else {
+        int res = (mask_value * prediction_0[x] +
+                   (64 - mask_value) * prediction_1[x]) >>
+                  6;
+        res -= compound_round_offset;
+        dst[x] = static_cast<Pixel>(
+            Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+                  (1 << bitdepth) - 1));
+      }
+    }
+    dst += dst_stride;
+    mask += mask_stride * step_y;
+    mask_next_row += mask_stride * step_y;
+    prediction_0 += prediction_stride_0;
+    prediction_1 += prediction_stride_1;
+  }
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  dsp->mask_blend[0][0] = MaskBlend_C<8, uint8_t, false, 0, 0>;
+  dsp->mask_blend[1][0] = MaskBlend_C<8, uint8_t, false, 1, 0>;
+  dsp->mask_blend[2][0] = MaskBlend_C<8, uint8_t, false, 1, 1>;
+  dsp->mask_blend[0][1] = MaskBlend_C<8, uint8_t, true, 0, 0>;
+  dsp->mask_blend[1][1] = MaskBlend_C<8, uint8_t, true, 1, 0>;
+  dsp->mask_blend[2][1] = MaskBlend_C<8, uint8_t, true, 1, 1>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp8bpp_MaskBlend444
+  dsp->mask_blend[0][0] = MaskBlend_C<8, uint8_t, false, 0, 0>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_MaskBlend422
+  dsp->mask_blend[1][0] = MaskBlend_C<8, uint8_t, false, 1, 0>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_MaskBlend420
+  dsp->mask_blend[2][0] = MaskBlend_C<8, uint8_t, false, 1, 1>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_MaskBlendInterIntra444
+  dsp->mask_blend[0][1] = MaskBlend_C<8, uint8_t, true, 0, 0>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_MaskBlendInterIntra422
+  dsp->mask_blend[1][1] = MaskBlend_C<8, uint8_t, true, 1, 0>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_MaskBlendInterIntra420
+  dsp->mask_blend[2][1] = MaskBlend_C<8, uint8_t, true, 1, 1>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
+  assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  dsp->mask_blend[0][0] = MaskBlend_C<10, uint16_t, false, 0, 0>;
+  dsp->mask_blend[1][0] = MaskBlend_C<10, uint16_t, false, 1, 0>;
+  dsp->mask_blend[2][0] = MaskBlend_C<10, uint16_t, false, 1, 1>;
+  dsp->mask_blend[0][1] = MaskBlend_C<10, uint16_t, true, 0, 0>;
+  dsp->mask_blend[1][1] = MaskBlend_C<10, uint16_t, true, 1, 0>;
+  dsp->mask_blend[2][1] = MaskBlend_C<10, uint16_t, true, 1, 1>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp10bpp_MaskBlend444
+  dsp->mask_blend[0][0] = MaskBlend_C<10, uint16_t, false, 0, 0>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_MaskBlend422
+  dsp->mask_blend[1][0] = MaskBlend_C<10, uint16_t, false, 1, 0>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_MaskBlend420
+  dsp->mask_blend[2][0] = MaskBlend_C<10, uint16_t, false, 1, 1>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra444
+  dsp->mask_blend[0][1] = MaskBlend_C<10, uint16_t, true, 0, 0>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra422
+  dsp->mask_blend[1][1] = MaskBlend_C<10, uint16_t, true, 1, 0>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra420
+  dsp->mask_blend[2][1] = MaskBlend_C<10, uint16_t, true, 1, 1>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+}
+#endif
+
+}  // namespace
+
+void MaskBlendInit_C() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  Init10bpp();
+#endif
+}
+
+}  // namespace dsp
+}  // namespace libgav1
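The three rounding constants computed at the top of MaskBlend_C follow directly from the bitdepth; a small numeric sketch for the 8-bit case (values shown only for illustration):

    #include <cassert>

    int main() {
      constexpr int bitdepth = 8;
      constexpr int single_round_offset = (1 << bitdepth) + (1 << (bitdepth - 1));
      constexpr int compound_round_offset =
          (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
      // 2 * FILTER_BITS(7) - (InterRound0(3) + InterRound1(7)).
      constexpr int inter_post_round_bits = 2 * 7 - (3 + 7);
      assert(single_round_offset == 384);
      assert(compound_round_offset == 6144);
      assert(inter_post_round_bits == 4);
      return 0;
    }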
diff --git a/libgav1/src/dsp/mask_blend.h b/libgav1/src/dsp/mask_blend.h
new file mode 100644
index 0000000..5c53270
--- /dev/null
+++ b/libgav1/src/dsp/mask_blend.h
@@ -0,0 +1,23 @@
+#ifndef LIBGAV1_SRC_DSP_MASK_BLEND_H_
+#define LIBGAV1_SRC_DSP_MASK_BLEND_H_
+
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/mask_blend_neon.h"
+
+// IWYU pragma: end_exports
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::mask_blend. This function is not thread-safe.
+void MaskBlendInit_C();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_DSP_MASK_BLEND_H_
diff --git a/libgav1/src/dsp/mask_blending.cc b/libgav1/src/dsp/mask_blending.cc
deleted file mode 100644
index b250d7c..0000000
--- a/libgav1/src/dsp/mask_blending.cc
+++ /dev/null
@@ -1,108 +0,0 @@
-#include "src/dsp/mask_blending.h"
-
-#include <cassert>
-#include <cstddef>
-#include <cstdint>
-
-#include "src/dsp/dsp.h"
-#include "src/utils/common.h"
-
-namespace libgav1 {
-namespace dsp {
-namespace {
-
-template <int bitdepth, typename Pixel>
-void MaskBlending_C(
-    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
-    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
-    const uint8_t* mask, const ptrdiff_t mask_stride, const int width,
-    const int height, const int subsampling_x, const int subsampling_y,
-    const bool is_inter_intra, const bool is_wedge_inter_intra,
-    const int inter_post_round_bits, void* dest, const ptrdiff_t dest_stride) {
-  auto* dst = static_cast<Pixel*>(dest);
-  const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
-  const int step_y = subsampling_y ? 2 : 1;
-  const int mask_step_y =
-      (is_inter_intra && !is_wedge_inter_intra) ? 1 : step_y;
-  const uint8_t* mask_next_row = mask + mask_stride;
-  // An offset to cancel offsets used in single predictor generation that
-  // make intermediate computations non negative.
-  const int single_round_offset = (1 << bitdepth) + (1 << (bitdepth - 1));
-  // An offset to cancel offsets used in compound predictor generation that
-  // make intermediate computations non negative.
-  const int compound_round_offset =
-      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
-      uint8_t mask_value;
-      if (((subsampling_x | subsampling_y) == 0) ||
-          (is_inter_intra && !is_wedge_inter_intra)) {
-        mask_value = mask[x];
-      } else if (subsampling_x == 1 && subsampling_y == 0) {
-        mask_value = static_cast<uint8_t>(RightShiftWithRounding(
-            mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1], 1));
-      } else if (subsampling_x == 0 && subsampling_y == 1) {
-        mask_value = static_cast<uint8_t>(
-            RightShiftWithRounding(mask[x] + mask_next_row[x], 1));
-      } else {
-        mask_value = static_cast<uint8_t>(RightShiftWithRounding(
-            mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1] +
-                mask_next_row[MultiplyBy2(x)] +
-                mask_next_row[MultiplyBy2(x) + 1],
-            2));
-      }
-
-      if (is_inter_intra) {
-        // In inter intra prediction mode, the intra prediction (prediction_1)
-        // values are valid pixel values: [0, (1 << bitdepth) - 1].
-        // While the inter prediction values come from subpixel prediction
-        // from another frame, which involves interpolation and rounding.
-        // Therefore prediction_0 has to be clipped.
-        dst[x] = static_cast<Pixel>(RightShiftWithRounding(
-            mask_value * prediction_1[x] +
-                (64 - mask_value) * Clip3(prediction_0[x] - single_round_offset,
-                                          0, (1 << bitdepth) - 1),
-            6));
-      } else {
-        int res = (mask_value * prediction_0[x] +
-                   (64 - mask_value) * prediction_1[x]) >>
-                  6;
-        res -= compound_round_offset;
-        dst[x] = static_cast<Pixel>(
-            Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
-                  (1 << bitdepth) - 1));
-      }
-    }
-    dst += dst_stride;
-    mask += mask_stride * mask_step_y;
-    mask_next_row += mask_stride * step_y;
-    prediction_0 += prediction_stride_0;
-    prediction_1 += prediction_stride_1;
-  }
-}
-
-void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
-  assert(dsp != nullptr);
-  dsp->mask_blend = MaskBlending_C<8, uint8_t>;
-}
-
-#if LIBGAV1_MAX_BITDEPTH >= 10
-void Init10bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
-  assert(dsp != nullptr);
-  dsp->mask_blend = MaskBlending_C<10, uint16_t>;
-}
-#endif
-
-}  // namespace
-
-void MaskBlendingInit_C() {
-  Init8bpp();
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  Init10bpp();
-#endif
-}
-
-}  // namespace dsp
-}  // namespace libgav1
diff --git a/libgav1/src/dsp/mask_blending.h b/libgav1/src/dsp/mask_blending.h
deleted file mode 100644
index 2d339db..0000000
--- a/libgav1/src/dsp/mask_blending.h
+++ /dev/null
@@ -1,12 +0,0 @@
-#ifndef LIBGAV1_SRC_DSP_MASK_BLENDING_H_
-#define LIBGAV1_SRC_DSP_MASK_BLENDING_H_
-
-namespace libgav1 {
-namespace dsp {
-
-void MaskBlendingInit_C();
-
-}  // namespace dsp
-}  // namespace libgav1
-
-#endif  // LIBGAV1_SRC_DSP_MASK_BLENDING_H_
diff --git a/libgav1/src/dsp/obmc.cc b/libgav1/src/dsp/obmc.cc
index ed0ca42..6c81852 100644
--- a/libgav1/src/dsp/obmc.cc
+++ b/libgav1/src/dsp/obmc.cc
@@ -6,29 +6,52 @@
 
 #include "src/dsp/dsp.h"
 #include "src/utils/common.h"
+#include "src/utils/constants.h"
 
 namespace libgav1 {
 namespace dsp {
 namespace {
 
-// 7.11.3.10.
+#include "src/dsp/obmc.inc"
+
+// 7.11.3.10 (from top samples).
 template <typename Pixel>
-void OverlapBlending_C(void* const prediction,
-                       const ptrdiff_t prediction_stride, const int width,
-                       const int height, const int blending_direction,
-                       const uint8_t* const mask,
-                       const void* const obmc_prediction,
-                       const ptrdiff_t obmc_prediction_stride) {
-  // 0 == kBlendFromAbove, 1 == kBlendFromLeft.
-  assert(blending_direction == 0 || blending_direction == 1);
+void OverlapBlendVertical_C(void* const prediction,
+                            const ptrdiff_t prediction_stride, const int width,
+                            const int height, const void* const obmc_prediction,
+                            const ptrdiff_t obmc_prediction_stride) {
   auto* pred = static_cast<Pixel*>(prediction);
   const ptrdiff_t pred_stride = prediction_stride / sizeof(Pixel);
   const auto* obmc_pred = static_cast<const Pixel*>(obmc_prediction);
   const ptrdiff_t obmc_pred_stride = obmc_prediction_stride / sizeof(Pixel);
+  const uint8_t* const mask = kObmcMask + height - 2;
 
   for (int y = 0; y < height; ++y) {
+    const uint8_t mask_value = mask[y];
     for (int x = 0; x < width; ++x) {
-      const uint8_t mask_value = (blending_direction == 0) ? mask[y] : mask[x];
+      pred[x] = static_cast<Pixel>(RightShiftWithRounding(
+          mask_value * pred[x] + (64 - mask_value) * obmc_pred[x], 6));
+    }
+    pred += pred_stride;
+    obmc_pred += obmc_pred_stride;
+  }
+}
+
+// 7.11.3.10 (from left samples).
+template <typename Pixel>
+void OverlapBlendHorizontal_C(void* const prediction,
+                              const ptrdiff_t prediction_stride,
+                              const int width, const int height,
+                              const void* const obmc_prediction,
+                              const ptrdiff_t obmc_prediction_stride) {
+  auto* pred = static_cast<Pixel*>(prediction);
+  const ptrdiff_t pred_stride = prediction_stride / sizeof(Pixel);
+  const auto* obmc_pred = static_cast<const Pixel*>(obmc_prediction);
+  const ptrdiff_t obmc_pred_stride = obmc_prediction_stride / sizeof(Pixel);
+  const uint8_t* const mask = kObmcMask + width - 2;
+  for (int y = 0; y < height; ++y) {
+    for (int x = 0; x < width; ++x) {
+      const uint8_t mask_value = mask[x];
       pred[x] = static_cast<Pixel>(RightShiftWithRounding(
           mask_value * pred[x] + (64 - mask_value) * obmc_pred[x], 6));
     }
@@ -40,14 +63,38 @@
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
-  dsp->obmc_blend = OverlapBlending_C<uint8_t>;
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendVertical_C<uint8_t>;
+  dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendHorizontal_C<uint8_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp8bpp_ObmcVertical
+  dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendVertical_C<uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ObmcHorizontal
+  dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendHorizontal_C<uint8_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 
 #if LIBGAV1_MAX_BITDEPTH >= 10
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
   assert(dsp != nullptr);
-  dsp->obmc_blend = OverlapBlending_C<uint16_t>;
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendVertical_C<uint16_t>;
+  dsp->obmc_blend[kObmcDirectionHorizontal] =
+      OverlapBlendHorizontal_C<uint16_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp10bpp_ObmcVertical
+  dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendVertical_C<uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_ObmcHorizontal
+  dsp->obmc_blend[kObmcDirectionHorizontal] =
+      OverlapBlendHorizontal_C<uint16_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 #endif
 
diff --git a/libgav1/src/dsp/obmc.h b/libgav1/src/dsp/obmc.h
index 97e4b69..4a6d9cb 100644
--- a/libgav1/src/dsp/obmc.h
+++ b/libgav1/src/dsp/obmc.h
@@ -1,9 +1,28 @@
 #ifndef LIBGAV1_SRC_DSP_OBMC_H_
 #define LIBGAV1_SRC_DSP_OBMC_H_
 
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/obmc_neon.h"
+
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/obmc_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
+
 namespace libgav1 {
 namespace dsp {
 
+// Initializes Dsp::obmc_blend. This function is not thread-safe.
 void ObmcInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/obmc.inc b/libgav1/src/dsp/obmc.inc
new file mode 100644
index 0000000..6da0001
--- /dev/null
+++ b/libgav1/src/dsp/obmc.inc
@@ -0,0 +1,18 @@
+// Constants and utility functions used for overlap blend implementations.
+// This file is included inside an anonymous namespace in the files where these
+// are necessary.
+
+// This is a flat array of masks for each block dimension from 2 to 32. The
+// starting index for each length is length-2.
+constexpr uint8_t kObmcMask[62] = {
+    // Obmc Mask 2
+    45, 64,
+    // Obmc Mask 4
+    39, 50, 59, 64,
+    // Obmc Mask 8
+    36, 42, 48, 53, 57, 61, 64, 64,
+    // Obmc Mask 16
+    34, 37, 40, 43, 46, 49, 52, 54, 56, 58, 60, 61, 64, 64, 64, 64,
+    // Obmc Mask 32
+    33, 35, 36, 38, 40, 41, 43, 44, 45, 47, 48, 50, 51, 52, 53, 55, 56, 57,
+    58, 59, 60, 60, 61, 62, 64, 64, 64, 64, 64, 64, 64, 64};
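A usage sketch (not from the patch) of how the blend functions above index this table: a blend over `height` rows reads its weights starting at kObmcMask[height - 2], e.g. height 4 uses {39, 50, 59, 64}, and each output is a 6-bit weighted average of the two predictors.

    #include <cassert>
    #include <cstdint>

    int main() {
      // kObmcMask + 4 - 2, copied here for the sketch.
      constexpr uint8_t kObmcMask4[4] = {39, 50, 59, 64};
      const int pred = 100;      // current prediction sample
      const int obmc_pred = 60;  // prediction from the neighboring block
      // Row 0 uses weight 39: RightShiftWithRounding(39*100 + 25*60, 6) == 84.
      const int blended =
          (kObmcMask4[0] * pred + (64 - kObmcMask4[0]) * obmc_pred + 32) >> 6;
      assert(blended == 84);
      return 0;
    }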
diff --git a/libgav1/src/dsp/warp.cc b/libgav1/src/dsp/warp.cc
index 12eaa71..60aab23 100644
--- a/libgav1/src/dsp/warp.cc
+++ b/libgav1/src/dsp/warp.cc
@@ -22,20 +22,26 @@
 void Warp_C(const void* const source, ptrdiff_t source_stride,
             const int source_width, const int source_height,
             const int* const warp_params, const int subsampling_x,
-            const int subsampling_y, const uint8_t inter_round_bits[2],
+            const int subsampling_y, const uint8_t inter_round_bits_vertical,
             const int block_start_x, const int block_start_y,
             const int block_width, const int block_height, const int16_t alpha,
             const int16_t beta, const int16_t gamma, const int16_t delta,
             uint16_t* dest, const ptrdiff_t dest_stride) {
-  // Intermediate_result is the output of the horizontal filtering.
-  // The range is within 16 bits.
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
+  // Intermediate_result is the output of the horizontal filtering and rounding.
+  // The range is within 16 bits (unsigned).
   uint16_t intermediate_result[15][8];  // 15 rows, 8 columns.
-  const int horizontal_offset_bits = bitdepth + kFilterBits - 1;
-  const int vertical_offset_bits =
-      bitdepth + 2 * kFilterBits - inter_round_bits[0];
+  const int horizontal_offset = 1 << (bitdepth + kFilterBits - 1);
+  const int vertical_offset =
+      1 << (bitdepth + 2 * kFilterBits - kRoundBitsHorizontal);
   const auto* const src = static_cast<const Pixel*>(source);
   source_stride /= sizeof(Pixel);
 
+  assert(block_width >= 8);
+  assert(block_height >= 8);
+
   // Warp process applies for each 8x8 block (or smaller).
   for (int start_y = block_start_y; start_y < block_start_y + block_height;
        start_y += 8) {
@@ -50,28 +56,32 @@
       const int x4 = dst_x >> subsampling_x;
       const int y4 = dst_y >> subsampling_y;
       const int ix4 = x4 >> kWarpedModelPrecisionBits;
-      const int sx4 = x4 & ((1 << kWarpedModelPrecisionBits) - 1);
       const int iy4 = y4 >> kWarpedModelPrecisionBits;
-      const int sy4 = y4 & ((1 << kWarpedModelPrecisionBits) - 1);
 
       // Horizontal filter.
+      int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
       for (int y = -7; y < 8; ++y) {
-        // TODO(chenghchen):
+        // TODO(chengchen):
         // Because of warping, the index could be out of frame boundary. Thus
         // clip is needed. However, can we remove or reduce usage of clip?
         // Besides, special cases exist, for example,
-        // if iy4 - 7 >= source_height, there's no need to do the filtering.
+        // if iy4 - 7 >= source_height or iy4 + 7 < 0, there's no need to do the
+        // filtering.
         const int row = Clip3(iy4 + y, 0, source_height - 1);
         const Pixel* const src_row = src + row * source_stride;
+        int sx = sx4 - MultiplyBy4(alpha);
         for (int x = -4; x < 4; ++x) {
-          const int sx = sx4 + alpha * x + beta * y;
           const int offset =
               RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
               kWarpedPixelPrecisionShifts;
+          // Since alpha and beta have been validated by SetupShear(), one can
+          // prove that 0 <= offset <= 3 * 2^6.
+          assert(offset >= 0);
+          assert(offset < 3 * kWarpedPixelPrecisionShifts + 1);
           // For SIMD optimization:
           // For 8 bit, the range of sum is within uint16_t, if we add an
           // horizontal offset:
-          int sum = 1 << horizontal_offset_bits;
+          int sum = horizontal_offset;
           // Horizontal_offset guarantees sum is non negative.
           // If horizontal_offset is used, intermediate_result needs to be
           // uint16_t.
@@ -80,37 +90,64 @@
             const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
             sum += kWarpedFilters[offset][k] * src_row[column];
           }
-          assert(sum >= 0 && sum < (1 << (horizontal_offset_bits + 2)));
+          assert(sum >= 0 && sum < (horizontal_offset << 2));
           intermediate_result[y + 7][x + 4] = static_cast<uint16_t>(
-              RightShiftWithRounding(sum, inter_round_bits[0]));
+              RightShiftWithRounding(sum, kRoundBitsHorizontal));
+          sx += alpha;
         }
+        sx4 += beta;
       }
 
       // Vertical filter.
       uint16_t* dst_row = dest + start_x - block_start_x;
-      for (int y = -4;
-           y < std::min(4, block_start_y + block_height - start_y - 4); ++y) {
-        for (int x = -4;
-             x < std::min(4, block_start_x + block_width - start_x - 4); ++x) {
-          const int sy = sy4 + gamma * x + delta * y;
+      int sy4 =
+          (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+      // The spec says we should use the following loop condition:
+      //   y < std::min(4, block_start_y + block_height - start_y - 4);
+      // We can prove that block_start_y + block_height - start_y >= 8, which
+      // implies std::min(4, block_start_y + block_height - start_y - 4) = 4.
+      // So the loop condition is simply y < 4.
+      //
+      // Proof:
+      //    start_y < block_start_y + block_height
+      // => block_start_y + block_height - start_y > 0
+      // => block_height - (start_y - block_start_y) > 0
+      //
+      // Since block_height >= 8 and is a power of 2, it follows that
+      // block_height is a multiple of 8. start_y - block_start_y is also a
+      // multiple of 8. Therefore their difference is a multiple of 8. Since
+      // their difference is > 0, their difference must be >= 8.
+      for (int y = -4; y < 4; ++y) {
+        int sy = sy4 - MultiplyBy4(gamma);
+        // The spec says we should use the following loop condition:
+        //   x < std::min(4, block_start_x + block_width - start_x - 4);
+        // Similar to the above, we can prove that the loop condition can be
+        // simplified to x < 4.
+        for (int x = -4; x < 4; ++x) {
           const int offset =
               RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
               kWarpedPixelPrecisionShifts;
+          // Since gamma and delta have been validated by SetupShear(), one can
+          // prove that 0 <= offset <= 3 * 2^6.
+          assert(offset >= 0);
+          assert(offset < 3 * kWarpedPixelPrecisionShifts + 1);
           // Similar to horizontal_offset, vertical_offset guarantees sum
           // before shifting is non negative:
-          int sum = 1 << vertical_offset_bits;
+          int sum = vertical_offset;
           for (int k = 0; k < 8; ++k) {
             sum += kWarpedFilters[offset][k] *
-                   intermediate_result[y + k + 4][x + 4];
+                   intermediate_result[y + 4 + k][x + 4];
           }
-          assert(sum >= 0 && sum < (1 << (vertical_offset_bits + 2)));
-          sum = RightShiftWithRounding(sum, inter_round_bits[1]);
+          assert(sum >= 0 && sum < (vertical_offset << 2));
+          sum = RightShiftWithRounding(sum, inter_round_bits_vertical);
           // Warp output is a predictor, whose type is uint16_t.
           // Do not clip it here. The clipping is applied at the stage of
           // final pixel value output.
           dst_row[x + 4] = static_cast<uint16_t>(sum);
+          sy += gamma;
         }
         dst_row += dest_stride;
+        sy4 += delta;
       }
     }
     dest += 8 * dest_stride;
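A small check (illustrative values only) that the strength-reduced index updates introduced in Warp_C reproduce the direct formula sx = sx4 + alpha * x + beta * y that they replace:

    #include <cassert>

    int main() {
      const int base = 123, alpha = 5, beta = -3;  // arbitrary sketch values
      int sx4 = base - beta * 7;                   // as in the patched code
      for (int y = -7; y < 8; ++y) {
        int sx = sx4 - 4 * alpha;
        for (int x = -4; x < 4; ++x) {
          assert(sx == base + alpha * x + beta * y);  // matches the old formula
          sx += alpha;
        }
        sx4 += beta;
      }
      return 0;
    }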
@@ -120,14 +157,28 @@
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->warp = Warp_C<8, uint8_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp8bpp_Warp
+  dsp->warp = Warp_C<8, uint8_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 
 #if LIBGAV1_MAX_BITDEPTH >= 10
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
   assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->warp = Warp_C<10, uint16_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp10bpp_Warp
+  dsp->warp = Warp_C<10, uint16_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 #endif
 
diff --git a/libgav1/src/dsp/warp.h b/libgav1/src/dsp/warp.h
index e2cbf60..cd5f177 100644
--- a/libgav1/src/dsp/warp.h
+++ b/libgav1/src/dsp/warp.h
@@ -1,9 +1,20 @@
 #ifndef LIBGAV1_SRC_DSP_WARP_H_
 #define LIBGAV1_SRC_DSP_WARP_H_
 
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/warp_neon.h"
+
+// IWYU pragma: end_exports
+
 namespace libgav1 {
 namespace dsp {
 
+// Initializes Dsp::warp. This function is not thread-safe.
 void WarpInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/x86/average_blend_sse4.cc b/libgav1/src/dsp/x86/average_blend_sse4.cc
new file mode 100644
index 0000000..0ad0954
--- /dev/null
+++ b/libgav1/src/dsp/x86/average_blend_sse4.cc
@@ -0,0 +1,148 @@
+#include "src/dsp/average_blend.h"
+#include "src/dsp/dsp.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <xmmintrin.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+constexpr int kBitdepth8 = 8;
+constexpr int kInterPostRoundBit = 4;
+// An offset to cancel offsets used in compound predictor generation that
+// make intermediate computations non negative.
+const __m128i kCompoundRoundOffset =
+    _mm_set1_epi16((2 << (kBitdepth8 + 4)) + (2 << (kBitdepth8 + 3)));
+
+inline void AverageBlend4Row(const uint16_t* prediction_0,
+                             const uint16_t* prediction_1, uint8_t* dest) {
+  const __m128i pred_0 = LoadLo8(prediction_0);
+  const __m128i pred_1 = LoadLo8(prediction_1);
+  __m128i res = _mm_add_epi16(pred_0, pred_1);
+  res = _mm_sub_epi16(res, kCompoundRoundOffset);
+  res = RightShiftWithRounding_S16(res, kInterPostRoundBit + 1);
+  Store4(dest, _mm_packus_epi16(res, res));
+}
+
+inline void AverageBlend8Row(const uint16_t* prediction_0,
+                             const uint16_t* prediction_1, uint8_t* dest) {
+  const __m128i pred_0 = LoadUnaligned16(prediction_0);
+  const __m128i pred_1 = LoadUnaligned16(prediction_1);
+  __m128i res = _mm_add_epi16(pred_0, pred_1);
+  res = _mm_sub_epi16(res, kCompoundRoundOffset);
+  res = RightShiftWithRounding_S16(res, kInterPostRoundBit + 1);
+  StoreLo8(dest, _mm_packus_epi16(res, res));
+}
+
+inline void AverageBlendLargeRow(const uint16_t* prediction_0,
+                                 const uint16_t* prediction_1, const int width,
+                                 uint8_t* dest) {
+  int x = 0;
+  do {
+    const __m128i pred_00 = LoadUnaligned16(&prediction_0[x]);
+    const __m128i pred_01 = LoadUnaligned16(&prediction_1[x]);
+    __m128i res0 = _mm_add_epi16(pred_00, pred_01);
+    res0 = _mm_sub_epi16(res0, kCompoundRoundOffset);
+    res0 = RightShiftWithRounding_S16(res0, kInterPostRoundBit + 1);
+    const __m128i pred_10 = LoadUnaligned16(&prediction_0[x + 8]);
+    const __m128i pred_11 = LoadUnaligned16(&prediction_1[x + 8]);
+    __m128i res1 = _mm_add_epi16(pred_10, pred_11);
+    res1 = _mm_sub_epi16(res1, kCompoundRoundOffset);
+    res1 = RightShiftWithRounding_S16(res1, kInterPostRoundBit + 1);
+    StoreUnaligned16(dest + x, _mm_packus_epi16(res0, res1));
+    x += 16;
+  } while (x < width);
+}
+
+void AverageBlend_SSE4_1(const uint16_t* prediction_0,
+                         const ptrdiff_t prediction_stride_0,
+                         const uint16_t* prediction_1,
+                         const ptrdiff_t prediction_stride_1, const int width,
+                         const int height, void* const dest,
+                         const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  int y = height;
+
+  if (width == 4) {
+    do {
+      AverageBlend4Row(prediction_0, prediction_1, dst);
+      dst += dest_stride;
+      prediction_0 += prediction_stride_0;
+      prediction_1 += prediction_stride_1;
+
+      AverageBlend4Row(prediction_0, prediction_1, dst);
+      dst += dest_stride;
+      prediction_0 += prediction_stride_0;
+      prediction_1 += prediction_stride_1;
+
+      y -= 2;
+    } while (y != 0);
+    return;
+  }
+
+  if (width == 8) {
+    do {
+      AverageBlend8Row(prediction_0, prediction_1, dst);
+      dst += dest_stride;
+      prediction_0 += prediction_stride_0;
+      prediction_1 += prediction_stride_1;
+
+      AverageBlend8Row(prediction_0, prediction_1, dst);
+      dst += dest_stride;
+      prediction_0 += prediction_stride_0;
+      prediction_1 += prediction_stride_1;
+
+      y -= 2;
+    } while (y != 0);
+    return;
+  }
+
+  do {
+    AverageBlendLargeRow(prediction_0, prediction_1, width, dst);
+    dst += dest_stride;
+    prediction_0 += prediction_stride_0;
+    prediction_1 += prediction_stride_1;
+
+    AverageBlendLargeRow(prediction_0, prediction_1, width, dst);
+    dst += dest_stride;
+    prediction_0 += prediction_stride_0;
+    prediction_1 += prediction_stride_1;
+
+    y -= 2;
+  } while (y != 0);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+#if DSP_ENABLED_8BPP_SSE4_1(AverageBlend)
+  dsp->average_blend = AverageBlend_SSE4_1;
+#endif
+}
+
+}  // namespace
+
+void AverageBlendInit_SSE4_1() { Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_SSE4_1
+
+namespace libgav1 {
+namespace dsp {
+
+void AverageBlendInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
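A numeric sketch (8-bit, illustrative values only) of the arithmetic in AverageBlend_SSE4_1 above: each prediction carries the compound offset, so twice the offset is subtracted, and the kInterPostRoundBit + 1 shift performs the /2 average and the post round in one step. For these inputs it lands on the same pixel as averaging first and post-rounding second.

    #include <cassert>

    int main() {
      constexpr int offset = (1 << 12) + (1 << 11);  // per-prediction offset, 6144
      const int a = 1000, b = 1032;                  // offset-free intermediates
      const int pred_0 = a + offset, pred_1 = b + offset;
      // Subtract 2 * offset, then round-shift by kInterPostRoundBit + 1 = 5.
      const int blended = ((pred_0 + pred_1 - 2 * offset) + 16) >> 5;
      // Average first, then a 4-bit post round.
      const int reference = (((a + b) >> 1) + 8) >> 4;
      assert(blended == 64 && reference == 64);
      return 0;
    }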
diff --git a/libgav1/src/dsp/x86/average_blend_sse4.h b/libgav1/src/dsp/x86/average_blend_sse4.h
new file mode 100644
index 0000000..26449d1
--- /dev/null
+++ b/libgav1/src/dsp/x86/average_blend_sse4.h
@@ -0,0 +1,25 @@
+#ifndef LIBGAV1_SRC_DSP_X86_AVERAGE_BLEND_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_AVERAGE_BLEND_SSE4_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::average_blend. This function is not thread-safe.
+void AverageBlendInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+// If sse4 is enabled and the baseline isn't set due to a higher level of
+// optimization being enabled, signal the sse4 implementation should be used.
+#if LIBGAV1_ENABLE_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_AverageBlend
+#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_DSP_SSE4_1
+#endif
+
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_AVERAGE_BLEND_SSE4_H_
diff --git a/libgav1/src/dsp/x86/common_sse4.h b/libgav1/src/dsp/x86/common_sse4.h
index e584bcc..305837a 100644
--- a/libgav1/src/dsp/x86/common_sse4.h
+++ b/libgav1/src/dsp/x86/common_sse4.h
@@ -12,10 +12,8 @@
 #include <cstdint>
 #include <cstring>
 
-namespace libgav1 {
-namespace dsp {
-
 #if 0
+#include <cinttypes>
 #include <cstdio>
 
 // Quite useful macro for debugging. Left here for convenience.
@@ -37,8 +35,8 @@
   } else if (size == 32) {
     for (n = 0; n < 4; ++n) fprintf(stderr, "%.8x ", tmp.i32[n]);
   } else {
-    for (n = 0; n < 2; ++n) fprintf(
-        stderr, "%.16llx ", static_cast<uint64_t>(tmp.i64[n]));
+    for (n = 0; n < 2; ++n)
+      fprintf(stderr, "%.16" PRIx64 " ", static_cast<uint64_t>(tmp.i64[n]));
   }
   fprintf(stderr, "\n");
 }
@@ -56,10 +54,33 @@
 #define PX(var) PrintRegX(var, #var);
 #endif  // 0
 
+namespace libgav1 {
+namespace dsp {
+
 //------------------------------------------------------------------------------
 // Load functions.
 
+inline __m128i Load2(const void* src) {
+  int16_t val;
+  memcpy(&val, src, sizeof(val));
+  return _mm_cvtsi32_si128(val);
+}
+
+inline __m128i Load2x2(const void* src1, const void* src2) {
+  uint16_t val1;
+  uint16_t val2;
+  memcpy(&val1, src1, sizeof(val1));
+  memcpy(&val2, src2, sizeof(val2));
+  return _mm_cvtsi32_si128(val1 | (val2 << 16));
+}
+
 inline __m128i Load4(const void* src) {
+  // With newer compilers such as clang 8.0.0 we could use the _mm_loadu_si32
+  // intrinsic; both _mm_loadu_si32(src) and the code here compile to a movss
+  // instruction.
+  //
+  // Until compiler support for _mm_loadu_si32 is widespread, its use is
+  // banned.
   int val;
   memcpy(&val, src, sizeof(val));
   return _mm_cvtsi32_si128(val);
@@ -82,6 +103,11 @@
 //------------------------------------------------------------------------------
 // Store functions.
 
+inline void Store2(void* dst, const __m128i x) {
+  const int val = _mm_cvtsi128_si32(x);
+  memcpy(dst, &val, 2);
+}
+
 inline void Store4(void* dst, const __m128i x) {
   const int val = _mm_cvtsi128_si32(x);
   memcpy(dst, &val, sizeof(val));
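A usage sketch (assumptions: SSE4.1 available, names local to this example) of the memcpy-based helpers above: memcpy keeps unaligned accesses well defined and still compiles down to a single scalar load, so callers need not align src or dst.

    #include <cassert>
    #include <cstdint>
    #include <cstring>
    #include <smmintrin.h>

    int main() {
      const uint8_t src[7] = {0, 1, 2, 3, 4, 5, 6};
      uint8_t dst[4] = {};
      int val;
      std::memcpy(&val, src + 3, sizeof(val));  // Load4(src + 3), unaligned OK
      const __m128i v = _mm_cvtsi32_si128(val);
      const int out = _mm_cvtsi128_si32(v);     // Store4(dst, v)
      std::memcpy(dst, &out, sizeof(out));
      assert(dst[0] == 3 && dst[3] == 6);
      return 0;
    }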
diff --git a/libgav1/src/dsp/x86/convolve_sse4.cc b/libgav1/src/dsp/x86/convolve_sse4.cc
new file mode 100644
index 0000000..b945b09
--- /dev/null
+++ b/libgav1/src/dsp/x86/convolve_sse4.cc
@@ -0,0 +1,21 @@
+#include "src/dsp/convolve.h"
+#include "src/dsp/dsp.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+namespace libgav1 {
+namespace dsp {
+
+void ConvolveInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_SSE4_1
+namespace libgav1 {
+namespace dsp {
+
+void ConvolveInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/convolve_sse4.h b/libgav1/src/dsp/x86/convolve_sse4.h
new file mode 100644
index 0000000..1cc240d
--- /dev/null
+++ b/libgav1/src/dsp/x86/convolve_sse4.h
@@ -0,0 +1,27 @@
+#ifndef LIBGAV1_SRC_DSP_X86_CONVOLVE_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_CONVOLVE_SSE4_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::convolve; see the defines below for specifics. This
+// function is not thread-safe.
+void ConvolveInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+// If sse4 is enabled and the baseline isn't set due to a higher level of
+// optimization being enabled, signal the sse4 implementation should be used.
+#if LIBGAV1_ENABLE_SSE4_1
+
+#ifndef LIBGAV1_Dsp8bpp_Convolve2D
+// #define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_DSP_SSE4_1
+#endif
+
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_CONVOLVE_SSE4_H_
diff --git a/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc
new file mode 100644
index 0000000..3a9673a
--- /dev/null
+++ b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc
@@ -0,0 +1,233 @@
+#include "src/dsp/distance_weighted_blend.h"
+#include "src/dsp/dsp.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <xmmintrin.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+constexpr int kBitdepth8 = 8;
+constexpr int kInterPostRoundBit = 4;
+
+inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
+                                       const __m128i& pred1,
+                                       const __m128i& weights) {
+  const __m128i compound_round_offset32 =
+      _mm_set1_epi32((16 << (kBitdepth8 + 4)) + (16 << (kBitdepth8 + 3)));
+  const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
+  const __m128i mult_lo =
+      _mm_sub_epi32(_mm_madd_epi16(preds_lo, weights), compound_round_offset32);
+  const __m128i result_lo =
+      RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
+
+  const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
+  const __m128i mult_hi =
+      _mm_sub_epi32(_mm_madd_epi16(preds_hi, weights), compound_round_offset32);
+  const __m128i result_hi =
+      RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
+
+  return _mm_packs_epi32(result_lo, result_hi);
+}
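+
+// A scalar sketch of the per-pixel computation above (illustrative only;
+// RightShiftWithRounding and Clip3 refer to the scalar helpers in
+// src/utils/common.h):
+//   int sum = pred_0[i] * weight_0 + pred_1[i] * weight_1;
+//   sum -= (16 << (kBitdepth8 + 4)) + (16 << (kBitdepth8 + 3));
+//   dst[i] = Clip3(RightShiftWithRounding(sum, kInterPostRoundBit + 4),
+//                  0, 255);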
+
+template <int height>
+inline void DistanceWeightedBlend4xH_SSE4_1(
+    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
+    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
+    const uint8_t weight_0, const uint8_t weight_1, void* const dest,
+    const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const uint16_t* pred_0 = prediction_0;
+  const uint16_t* pred_1 = prediction_1;
+  const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
+
+  for (int y = 0; y < height; y += 4) {
+    const __m128i src_00 = LoadLo8(pred_0);
+    const __m128i src_10 = LoadLo8(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    __m128i src_0 = LoadHi8(src_00, pred_0);
+    __m128i src_1 = LoadHi8(src_10, pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
+
+    const __m128i src_01 = LoadLo8(pred_0);
+    const __m128i src_11 = LoadLo8(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    src_0 = LoadHi8(src_01, pred_0);
+    src_1 = LoadHi8(src_11, pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
+
+    const __m128i result_pixels = _mm_packus_epi16(res0, res1);
+    Store4(dst, result_pixels);
+    dst += dest_stride;
+    const int result_1 = _mm_extract_epi32(result_pixels, 1);
+    memcpy(dst, &result_1, sizeof(result_1));
+    dst += dest_stride;
+    const int result_2 = _mm_extract_epi32(result_pixels, 2);
+    memcpy(dst, &result_2, sizeof(result_2));
+    dst += dest_stride;
+    const int result_3 = _mm_extract_epi32(result_pixels, 3);
+    memcpy(dst, &result_3, sizeof(result_3));
+    dst += dest_stride;
+  }
+}
+
+template <int height>
+inline void DistanceWeightedBlend8xH_SSE4_1(
+    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
+    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
+    const uint8_t weight_0, const uint8_t weight_1, void* const dest,
+    const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const uint16_t* pred_0 = prediction_0;
+  const uint16_t* pred_1 = prediction_1;
+  const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
+
+  for (int y = 0; y < height; y += 2) {
+    const __m128i src_00 = LoadUnaligned16(pred_0);
+    const __m128i src_10 = LoadUnaligned16(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
+
+    const __m128i src_01 = LoadUnaligned16(pred_0);
+    const __m128i src_11 = LoadUnaligned16(pred_1);
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+    const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
+
+    const __m128i result_pixels = _mm_packus_epi16(res0, res1);
+    StoreLo8(dst, result_pixels);
+    dst += dest_stride;
+    StoreHi8(dst, result_pixels);
+    dst += dest_stride;
+  }
+}
+
+inline void DistanceWeightedBlendLarge_SSE4_1(
+    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
+    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
+    const uint8_t weight_0, const uint8_t weight_1, const int width,
+    const int height, void* const dest, const ptrdiff_t dest_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const uint16_t* pred_0 = prediction_0;
+  const uint16_t* pred_1 = prediction_1;
+  const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
+
+  int y = height;
+  do {
+    int x = 0;
+    do {
+      const __m128i src_0_lo = LoadUnaligned16(pred_0 + x);
+      const __m128i src_1_lo = LoadUnaligned16(pred_1 + x);
+      const __m128i res_lo =
+          ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
+
+      const __m128i src_0_hi = LoadUnaligned16(pred_0 + x + 8);
+      const __m128i src_1_hi = LoadUnaligned16(pred_1 + x + 8);
+      const __m128i res_hi =
+          ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
+
+      StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
+      x += 16;
+    } while (x < width);
+    dst += dest_stride;
+    pred_0 += prediction_stride_0;
+    pred_1 += prediction_stride_1;
+  } while (--y != 0);
+}
+
+void DistanceWeightedBlend_SSE4_1(
+    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
+    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
+    const uint8_t weight_0, const uint8_t weight_1, const int width,
+    const int height, void* const dest, const ptrdiff_t dest_stride) {
+  if (width == 4) {
+    if (height == 4) {
+      DistanceWeightedBlend4xH_SSE4_1<4>(prediction_0, prediction_stride_0,
+                                         prediction_1, prediction_stride_1,
+                                         weight_0, weight_1, dest, dest_stride);
+    } else if (height == 8) {
+      DistanceWeightedBlend4xH_SSE4_1<8>(prediction_0, prediction_stride_0,
+                                         prediction_1, prediction_stride_1,
+                                         weight_0, weight_1, dest, dest_stride);
+    } else {
+      assert(height == 16);
+      DistanceWeightedBlend4xH_SSE4_1<16>(
+          prediction_0, prediction_stride_0, prediction_1, prediction_stride_1,
+          weight_0, weight_1, dest, dest_stride);
+    }
+    return;
+  }
+
+  if (width == 8) {
+    switch (height) {
+      case 4:
+        DistanceWeightedBlend8xH_SSE4_1<4>(
+            prediction_0, prediction_stride_0, prediction_1,
+            prediction_stride_1, weight_0, weight_1, dest, dest_stride);
+        return;
+      case 8:
+        DistanceWeightedBlend8xH_SSE4_1<8>(
+            prediction_0, prediction_stride_0, prediction_1,
+            prediction_stride_1, weight_0, weight_1, dest, dest_stride);
+        return;
+      case 16:
+        DistanceWeightedBlend8xH_SSE4_1<16>(
+            prediction_0, prediction_stride_0, prediction_1,
+            prediction_stride_1, weight_0, weight_1, dest, dest_stride);
+        return;
+      default:
+        assert(height == 32);
+        DistanceWeightedBlend8xH_SSE4_1<32>(
+            prediction_0, prediction_stride_0, prediction_1,
+            prediction_stride_1, weight_0, weight_1, dest, dest_stride);
+
+        return;
+    }
+  }
+
+  DistanceWeightedBlendLarge_SSE4_1(prediction_0, prediction_stride_0,
+                                    prediction_1, prediction_stride_1, weight_0,
+                                    weight_1, width, height, dest, dest_stride);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+#if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
+  dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
+#endif
+}
+
+}  // namespace
+
+void DistanceWeightedBlendInit_SSE4_1() { Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_SSE4_1
+
+namespace libgav1 {
+namespace dsp {
+
+void DistanceWeightedBlendInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/distance_weighted_blend_sse4.h b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.h
new file mode 100644
index 0000000..831853d
--- /dev/null
+++ b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.h
@@ -0,0 +1,25 @@
+#ifndef LIBGAV1_SRC_DSP_X86_DISTANCE_WEIGHTED_BLEND_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_DISTANCE_WEIGHTED_BLEND_SSE4_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::distance_weighted_blend. This function is not thread-safe.
+void DistanceWeightedBlendInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+// If sse4 is enabled and the baseline isn't set due to a higher level of
+// optimization being enabled, signal the sse4 implementation should be used.
+#if LIBGAV1_ENABLE_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_DistanceWeightedBlend
+#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_DSP_SSE4_1
+#endif
+
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_DISTANCE_WEIGHTED_BLEND_SSE4_H_
diff --git a/libgav1/src/dsp/x86/intra_edge_sse4.cc b/libgav1/src/dsp/x86/intra_edge_sse4.cc
index 7772234..7a20927 100644
--- a/libgav1/src/dsp/x86/intra_edge_sse4.cc
+++ b/libgav1/src/dsp/x86/intra_edge_sse4.cc
@@ -1,6 +1,5 @@
-#include "src/dsp/x86/intra_edge_sse4.h"
-
 #include "src/dsp/dsp.h"
+#include "src/dsp/intra_edge.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -143,6 +142,9 @@
   memcpy(edge, buffer, size);
   auto* dst_buffer = static_cast<uint8_t*>(buffer);
 
+  // Only |size| - 1 elements are filtered; nothing to do when |size| == 1.
+  if (size == 1) return;
+
   int i = 0;
   switch (strength) {
     case 1:
diff --git a/libgav1/src/dsp/x86/intra_edge_sse4.h b/libgav1/src/dsp/x86/intra_edge_sse4.h
index abaad40..7803ce9 100644
--- a/libgav1/src/dsp/x86/intra_edge_sse4.h
+++ b/libgav1/src/dsp/x86/intra_edge_sse4.h
@@ -3,13 +3,12 @@
 
 #include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
-#include "src/dsp/intra_edge.h"
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes intra edge functions with sse4 implementations. This function
-// is not thread-safe.
+// Initializes Dsp::intra_edge_filter and Dsp::intra_edge_upsampler. This
+// function is not thread-safe.
 void IntraEdgeInit_SSE4_1();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc b/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
index 3744fc6..4d97fb0 100644
--- a/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
@@ -1,12 +1,11 @@
 #include "src/dsp/dsp.h"
-#include "src/dsp/x86/intrapred_sse4.h"
+#include "src/dsp/intrapred.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
-#include <emmintrin.h>
 #include <smmintrin.h>
-#include <xmmintrin.h>
 
+#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -168,6 +167,7 @@
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int /*max_luma_width*/, const int /*max_luma_height*/,
     const void* const source, ptrdiff_t stride) {
+  static_assert(block_width_log2 == 4 || block_width_log2 == 5, "");
   const int block_height = 1 << block_height_log2;
   const int block_width = 1 << block_width_log2;
   const __m128i dup16 = _mm_set1_epi32(0x01000100);
@@ -214,9 +214,434 @@
   }
 }
 
+// Takes in two sums of input row pairs, and completes the computation for two
+// output rows.
+inline __m128i StoreLumaResults4_420(const __m128i vertical_sum0,
+                                     const __m128i vertical_sum1,
+                                     int16_t* luma_ptr) {
+  __m128i result = _mm_hadd_epi16(vertical_sum0, vertical_sum1);
+  result = _mm_slli_epi16(result, 1);
+  StoreLo8(luma_ptr, result);
+  StoreHi8(luma_ptr + kCflLumaBufferStride, result);
+  return result;
+}
+
+// Takes two halves of a vertically added pair of rows and completes the
+// computation for one output row.
+inline __m128i StoreLumaResults8_420(const __m128i vertical_sum0,
+                                     const __m128i vertical_sum1,
+                                     int16_t* luma_ptr) {
+  __m128i result = _mm_hadd_epi16(vertical_sum0, vertical_sum1);
+  result = _mm_slli_epi16(result, 1);
+  StoreUnaligned16(luma_ptr, result);
+  return result;
+}
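+
+// A scalar sketch of the 4:2:0 subsampling done by the two helpers above
+// (illustrative only): each output value is the sum of a 2x2 block of source
+// luma pixels, doubled to match the scaling of the generic C path:
+//   luma[y][x] = (src[2 * y][2 * x] + src[2 * y][2 * x + 1] +
+//                 src[2 * y + 1][2 * x] + src[2 * y + 1][2 * x + 1]) << 1;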
+
+template <int block_height_log2>
+void CflSubsampler420_4xH_SSE4_1(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int /*max_luma_width*/, const int max_luma_height,
+    const void* const source, ptrdiff_t stride) {
+  const int block_height = 1 << block_height_log2;
+  const auto* src = static_cast<const uint8_t*>(source);
+  int16_t* luma_ptr = luma[0];
+  const __m128i zero = _mm_setzero_si128();
+  __m128i final_sum = zero;
+  const int luma_height = std::min(block_height, max_luma_height >> 1);
+  int y = 0;
+  do {
+    // Reading twice the output width and widening to 16 bits means one input
+    // row fills the vector.
+    const __m128i samples_row0 = _mm_cvtepu8_epi16(LoadLo8(src));
+    src += stride;
+    const __m128i samples_row1 = _mm_cvtepu8_epi16(LoadLo8(src));
+    src += stride;
+    const __m128i luma_sum01 = _mm_add_epi16(samples_row0, samples_row1);
+
+    const __m128i samples_row2 = _mm_cvtepu8_epi16(LoadLo8(src));
+    src += stride;
+    const __m128i samples_row3 = _mm_cvtepu8_epi16(LoadLo8(src));
+    src += stride;
+    const __m128i luma_sum23 = _mm_add_epi16(samples_row2, samples_row3);
+    __m128i sum = StoreLumaResults4_420(luma_sum01, luma_sum23, luma_ptr);
+    luma_ptr += kCflLumaBufferStride << 1;
+
+    const __m128i samples_row4 = _mm_cvtepu8_epi16(LoadLo8(src));
+    src += stride;
+    const __m128i samples_row5 = _mm_cvtepu8_epi16(LoadLo8(src));
+    src += stride;
+    const __m128i luma_sum45 = _mm_add_epi16(samples_row4, samples_row5);
+
+    const __m128i samples_row6 = _mm_cvtepu8_epi16(LoadLo8(src));
+    src += stride;
+    const __m128i samples_row7 = _mm_cvtepu8_epi16(LoadLo8(src));
+    src += stride;
+    const __m128i luma_sum67 = _mm_add_epi16(samples_row6, samples_row7);
+    sum = _mm_add_epi16(
+        sum, StoreLumaResults4_420(luma_sum45, luma_sum67, luma_ptr));
+    luma_ptr += kCflLumaBufferStride << 1;
+
+    final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum));
+    final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero));
+    y += 4;
+  } while (y < luma_height);
+  const __m128i final_fill = LoadLo8(luma_ptr - kCflLumaBufferStride);
+  const __m128i final_fill_to_sum = _mm_cvtepu16_epi32(final_fill);
+  for (; y < block_height; ++y) {
+    StoreLo8(luma_ptr, final_fill);
+    luma_ptr += kCflLumaBufferStride;
+
+    final_sum = _mm_add_epi32(final_sum, final_fill_to_sum);
+  }
+  final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 8));
+  final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 4));
+
+  __m128i averages = RightShiftWithRounding_U32(
+      final_sum, block_height_log2 + 2 /*log2 of width 4*/);
+
+  averages = _mm_shufflelo_epi16(averages, 0);
+  luma_ptr = luma[0];
+  for (int y = 0; y < block_height; ++y, luma_ptr += kCflLumaBufferStride) {
+    const __m128i samples = LoadLo8(luma_ptr);
+    StoreLo8(luma_ptr, _mm_sub_epi16(samples, averages));
+  }
+}
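+
+// Each subsampler ends the same way (a scalar sketch, illustrative only): the
+// rounded block average is subtracted so that |luma| holds zero-mean values
+// for the CfL predictor.
+//   average = RightShiftWithRounding(sum_of_subsampled_values,
+//                                    block_width_log2 + block_height_log2);
+//   luma[y][x] -= average;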
+
+// This duplicates the last two 16-bit values in |row|.
+inline __m128i LastRowSamples(const __m128i row) {
+  return _mm_shuffle_epi32(row, 0xFF);
+}
+
+// This duplicates the last 16-bit value in |row|.
+inline __m128i LastRowResult(const __m128i row) {
+  const __m128i dup_row = _mm_shufflehi_epi16(row, 0xFF);
+  return _mm_shuffle_epi32(dup_row, 0xFF);
+}
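+
+// LastRowSamples() and LastRowResult() extend the rightmost available samples
+// when |max_luma_width| is smaller than the block's luma width, so the loops
+// below can operate on full vectors without reading past the valid samples.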
+
+template <int block_height_log2, int max_luma_width>
+inline void CflSubsampler420Impl_8xH_SSE4_1(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int /*max_luma_width*/, const int max_luma_height,
+    const void* const source, ptrdiff_t stride) {
+  const int block_height = 1 << block_height_log2;
+  const auto* src = static_cast<const uint8_t*>(source);
+  const __m128i zero = _mm_setzero_si128();
+  __m128i final_sum = zero;
+  int16_t* luma_ptr = luma[0];
+  const int luma_height = std::min(block_height, max_luma_height >> 1);
+  int y = 0;
+
+  do {
+    const __m128i samples_row00 = _mm_cvtepu8_epi16(LoadLo8(src));
+    const __m128i samples_row01 = (max_luma_width == 16)
+                                      ? _mm_cvtepu8_epi16(LoadLo8(src + 8))
+                                      : LastRowSamples(samples_row00);
+    src += stride;
+    const __m128i samples_row10 = _mm_cvtepu8_epi16(LoadLo8(src));
+    const __m128i samples_row11 = (max_luma_width == 16)
+                                      ? _mm_cvtepu8_epi16(LoadLo8(src + 8))
+                                      : LastRowSamples(samples_row10);
+    src += stride;
+    const __m128i luma_sum00 = _mm_add_epi16(samples_row00, samples_row10);
+    const __m128i luma_sum01 = _mm_add_epi16(samples_row01, samples_row11);
+    __m128i sum = StoreLumaResults8_420(luma_sum00, luma_sum01, luma_ptr);
+    luma_ptr += kCflLumaBufferStride;
+
+    const __m128i samples_row20 = _mm_cvtepu8_epi16(LoadLo8(src));
+    const __m128i samples_row21 = (max_luma_width == 16)
+                                      ? _mm_cvtepu8_epi16(LoadLo8(src + 8))
+                                      : LastRowSamples(samples_row20);
+    src += stride;
+    const __m128i samples_row30 = _mm_cvtepu8_epi16(LoadLo8(src));
+    const __m128i samples_row31 = (max_luma_width == 16)
+                                      ? _mm_cvtepu8_epi16(LoadLo8(src + 8))
+                                      : LastRowSamples(samples_row30);
+    src += stride;
+    const __m128i luma_sum10 = _mm_add_epi16(samples_row20, samples_row30);
+    const __m128i luma_sum11 = _mm_add_epi16(samples_row21, samples_row31);
+    sum = _mm_add_epi16(
+        sum, StoreLumaResults8_420(luma_sum10, luma_sum11, luma_ptr));
+    luma_ptr += kCflLumaBufferStride;
+
+    const __m128i samples_row40 = _mm_cvtepu8_epi16(LoadLo8(src));
+    const __m128i samples_row41 = (max_luma_width == 16)
+                                      ? _mm_cvtepu8_epi16(LoadLo8(src + 8))
+                                      : LastRowSamples(samples_row40);
+    src += stride;
+    const __m128i samples_row50 = _mm_cvtepu8_epi16(LoadLo8(src));
+    const __m128i samples_row51 = (max_luma_width == 16)
+                                      ? _mm_cvtepu8_epi16(LoadLo8(src + 8))
+                                      : LastRowSamples(samples_row50);
+    src += stride;
+    const __m128i luma_sum20 = _mm_add_epi16(samples_row40, samples_row50);
+    const __m128i luma_sum21 = _mm_add_epi16(samples_row41, samples_row51);
+    sum = _mm_add_epi16(
+        sum, StoreLumaResults8_420(luma_sum20, luma_sum21, luma_ptr));
+    luma_ptr += kCflLumaBufferStride;
+
+    const __m128i samples_row60 = _mm_cvtepu8_epi16(LoadLo8(src));
+    const __m128i samples_row61 = (max_luma_width == 16)
+                                      ? _mm_cvtepu8_epi16(LoadLo8(src + 8))
+                                      : LastRowSamples(samples_row60);
+    src += stride;
+    const __m128i samples_row70 = _mm_cvtepu8_epi16(LoadLo8(src));
+    const __m128i samples_row71 = (max_luma_width == 16)
+                                      ? _mm_cvtepu8_epi16(LoadLo8(src + 8))
+                                      : LastRowSamples(samples_row70);
+    src += stride;
+    const __m128i luma_sum30 = _mm_add_epi16(samples_row60, samples_row70);
+    const __m128i luma_sum31 = _mm_add_epi16(samples_row61, samples_row71);
+    sum = _mm_add_epi16(
+        sum, StoreLumaResults8_420(luma_sum30, luma_sum31, luma_ptr));
+    luma_ptr += kCflLumaBufferStride;
+
+    final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum));
+    final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero));
+    y += 4;
+  } while (y < luma_height);
+  // Duplicate the final computed row to fill out |block_height| rows.
+  const __m128i final_fill = LoadUnaligned16(luma_ptr - kCflLumaBufferStride);
+  const __m128i final_fill_to_sum0 = _mm_cvtepi16_epi32(final_fill);
+  const __m128i final_fill_to_sum1 =
+      _mm_cvtepi16_epi32(_mm_srli_si128(final_fill, 8));
+  const __m128i final_fill_to_sum =
+      _mm_add_epi32(final_fill_to_sum0, final_fill_to_sum1);
+  for (; y < block_height; ++y) {
+    StoreUnaligned16(luma_ptr, final_fill);
+    luma_ptr += kCflLumaBufferStride;
+
+    final_sum = _mm_add_epi32(final_sum, final_fill_to_sum);
+  }
+  final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 8));
+  final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 4));
+
+  __m128i averages = RightShiftWithRounding_S32(
+      final_sum, block_height_log2 + 3 /*log2 of width 8*/);
+
+  averages = _mm_shufflelo_epi16(averages, 0);
+  averages = _mm_shuffle_epi32(averages, 0);
+  luma_ptr = luma[0];
+  for (int y = 0; y < block_height; ++y, luma_ptr += kCflLumaBufferStride) {
+    const __m128i samples = LoadUnaligned16(luma_ptr);
+    StoreUnaligned16(luma_ptr, _mm_sub_epi16(samples, averages));
+  }
+}
+
+template <int block_height_log2>
+void CflSubsampler420_8xH_SSE4_1(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int max_luma_width, const int max_luma_height,
+    const void* const source, ptrdiff_t stride) {
+  if (max_luma_width == 8) {
+    CflSubsampler420Impl_8xH_SSE4_1<block_height_log2, 8>(
+        luma, max_luma_width, max_luma_height, source, stride);
+  } else {
+    CflSubsampler420Impl_8xH_SSE4_1<block_height_log2, 16>(
+        luma, max_luma_width, max_luma_height, source, stride);
+  }
+}
+
+template <int block_width_log2, int block_height_log2, int max_luma_width>
+inline void CflSubsampler420Impl_WxH_SSE4_1(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int /*max_luma_width*/, const int max_luma_height,
+    const void* const source, ptrdiff_t stride) {
+  const auto* src = static_cast<const uint8_t*>(source);
+  const __m128i zero = _mm_setzero_si128();
+  __m128i final_sum = zero;
+  const int block_height = 1 << block_height_log2;
+  const int luma_height = std::min(block_height, max_luma_height >> 1);
+
+  int16_t* luma_ptr = luma[0];
+  __m128i final_row_result;
+  // Begin first y section, covering width up to 16.
+  int y = 0;
+  do {
+    const uint8_t* src_next = src + stride;
+    const __m128i samples_row0_lo = LoadUnaligned16(src);
+    const __m128i samples_row00 = _mm_cvtepu8_epi16(samples_row0_lo);
+    const __m128i samples_row01 = (max_luma_width >= 16)
+                                      ? _mm_unpackhi_epi8(samples_row0_lo, zero)
+                                      : LastRowSamples(samples_row00);
+    const __m128i samples_row0_hi = LoadUnaligned16(src + 16);
+    const __m128i samples_row02 = (max_luma_width >= 24)
+                                      ? _mm_cvtepu8_epi16(samples_row0_hi)
+                                      : LastRowSamples(samples_row01);
+    const __m128i samples_row03 = (max_luma_width == 32)
+                                      ? _mm_unpackhi_epi8(samples_row0_hi, zero)
+                                      : LastRowSamples(samples_row02);
+    const __m128i samples_row1_lo = LoadUnaligned16(src_next);
+    const __m128i samples_row10 = _mm_cvtepu8_epi16(samples_row1_lo);
+    const __m128i samples_row11 = (max_luma_width >= 16)
+                                      ? _mm_unpackhi_epi8(samples_row1_lo, zero)
+                                      : LastRowSamples(samples_row10);
+    const __m128i samples_row1_hi = LoadUnaligned16(src_next + 16);
+    const __m128i samples_row12 = (max_luma_width >= 24)
+                                      ? _mm_cvtepu8_epi16(samples_row1_hi)
+                                      : LastRowSamples(samples_row11);
+    const __m128i samples_row13 = (max_luma_width == 32)
+                                      ? _mm_unpackhi_epi8(samples_row1_hi, zero)
+                                      : LastRowSamples(samples_row12);
+    const __m128i luma_sum0 = _mm_add_epi16(samples_row00, samples_row10);
+    const __m128i luma_sum1 = _mm_add_epi16(samples_row01, samples_row11);
+    const __m128i luma_sum2 = _mm_add_epi16(samples_row02, samples_row12);
+    const __m128i luma_sum3 = _mm_add_epi16(samples_row03, samples_row13);
+    __m128i sum = StoreLumaResults8_420(luma_sum0, luma_sum1, luma_ptr);
+    final_row_result =
+        StoreLumaResults8_420(luma_sum2, luma_sum3, luma_ptr + 8);
+    sum = _mm_add_epi16(sum, final_row_result);
+    final_sum = _mm_add_epi32(final_sum, _mm_cvtepu16_epi32(sum));
+    final_sum = _mm_add_epi32(final_sum, _mm_unpackhi_epi16(sum, zero));
+    src += stride << 1;
+    luma_ptr += kCflLumaBufferStride;
+  } while (++y < luma_height);
+
+  // Because max_luma_width is at most 32, any values beyond x=16 will
+  // necessarily be duplicated.
+  if (block_width_log2 == 5) {
+    const __m128i wide_fill = LastRowResult(final_row_result);
+    // Multiply the duplicated value by its number of occurrences: 16 per row,
+    // spread over 4 vector lanes, i.e. height * 4 per lane.
+    final_sum = _mm_add_epi32(
+        final_sum,
+        _mm_slli_epi32(_mm_cvtepi16_epi32(wide_fill), block_height_log2 + 2));
+  }
+
+  // Begin second y section.
+  if (y < block_height) {
+    const __m128i final_fill0 =
+        LoadUnaligned16(luma_ptr - kCflLumaBufferStride);
+    const __m128i final_fill1 =
+        LoadUnaligned16(luma_ptr - kCflLumaBufferStride + 8);
+    const __m128i final_inner_sum = _mm_add_epi16(final_fill0, final_fill1);
+    const __m128i final_inner_sum0 = _mm_cvtepu16_epi32(final_inner_sum);
+    const __m128i final_inner_sum1 = _mm_unpackhi_epi16(final_inner_sum, zero);
+    const __m128i final_fill_to_sum =
+        _mm_add_epi32(final_inner_sum0, final_inner_sum1);
+
+    do {
+      StoreUnaligned16(luma_ptr, final_fill0);
+      StoreUnaligned16(luma_ptr + 8, final_fill1);
+      luma_ptr += kCflLumaBufferStride;
+
+      final_sum = _mm_add_epi32(final_sum, final_fill_to_sum);
+    } while (++y < block_height);
+  }  // End second y section.
+
+  final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 8));
+  final_sum = _mm_add_epi32(final_sum, _mm_srli_si128(final_sum, 4));
+
+  __m128i averages = RightShiftWithRounding_S32(
+      final_sum, block_width_log2 + block_height_log2);
+  averages = _mm_shufflelo_epi16(averages, 0);
+  averages = _mm_shuffle_epi32(averages, 0);
+
+  luma_ptr = luma[0];
+  for (int y = 0; y < block_height; ++y, luma_ptr += kCflLumaBufferStride) {
+    const __m128i samples0 = LoadUnaligned16(luma_ptr);
+    StoreUnaligned16(luma_ptr, _mm_sub_epi16(samples0, averages));
+    const __m128i samples1 = LoadUnaligned16(luma_ptr + 8);
+    final_row_result = _mm_sub_epi16(samples1, averages);
+    StoreUnaligned16(luma_ptr + 8, final_row_result);
+  }
+  if (block_width_log2 == 5) {
+    int16_t* wide_luma_ptr = luma[0] + 16;
+    const __m128i wide_fill = LastRowResult(final_row_result);
+    for (int i = 0; i < block_height;
+         ++i, wide_luma_ptr += kCflLumaBufferStride) {
+      StoreUnaligned16(wide_luma_ptr, wide_fill);
+      StoreUnaligned16(wide_luma_ptr + 8, wide_fill);
+    }
+  }
+}
+
+template <int block_width_log2, int block_height_log2>
+void CflSubsampler420_WxH_SSE4_1(
+    int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
+    const int max_luma_width, const int max_luma_height,
+    const void* const source, ptrdiff_t stride) {
+  switch (max_luma_width) {
+    case 8:
+      CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 8>(
+          luma, max_luma_width, max_luma_height, source, stride);
+      return;
+    case 16:
+      CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 16>(
+          luma, max_luma_width, max_luma_height, source, stride);
+      return;
+    case 24:
+      CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 24>(
+          luma, max_luma_width, max_luma_height, source, stride);
+      return;
+    default:
+      assert(max_luma_width == 32);
+      CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 32>(
+          luma, max_luma_width, max_luma_height, source, stride);
+      return;
+  }
+}
+
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] =
+      CflSubsampler420_4xH_SSE4_1<2>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x8_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize4x8][kSubsamplingType420] =
+      CflSubsampler420_4xH_SSE4_1<3>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x16_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize4x16][kSubsamplingType420] =
+      CflSubsampler420_4xH_SSE4_1<4>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x4_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize8x4][kSubsamplingType420] =
+      CflSubsampler420_8xH_SSE4_1<2>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x8_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize8x8][kSubsamplingType420] =
+      CflSubsampler420_8xH_SSE4_1<3>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x16_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize8x16][kSubsamplingType420] =
+      CflSubsampler420_8xH_SSE4_1<4>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize8x32_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize8x32][kSubsamplingType420] =
+      CflSubsampler420_8xH_SSE4_1<5>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x4_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize16x4][kSubsamplingType420] =
+      CflSubsampler420_WxH_SSE4_1<4, 2>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x8_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize16x8][kSubsamplingType420] =
+      CflSubsampler420_WxH_SSE4_1<4, 3>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x16_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize16x16][kSubsamplingType420] =
+      CflSubsampler420_WxH_SSE4_1<4, 4>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize16x32_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize16x32][kSubsamplingType420] =
+      CflSubsampler420_WxH_SSE4_1<4, 5>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x8_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize32x8][kSubsamplingType420] =
+      CflSubsampler420_WxH_SSE4_1<5, 3>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x16_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize32x16][kSubsamplingType420] =
+      CflSubsampler420_WxH_SSE4_1<5, 4>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(TransformSize32x32_CflSubsampler420)
+  dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType420] =
+      CflSubsampler420_WxH_SSE4_1<5, 5>;
+#endif
+  // TODO(b/137035169): enable these once test vector mismatches are fixed.
+#if 0
 #if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_CflSubsampler444)
   dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType444] =
       CflSubsampler444_4xH_SSE4_1<2>;
@@ -273,6 +698,7 @@
   dsp->cfl_subsamplers[kTransformSize32x32][kSubsamplingType444] =
       CflSubsampler444_SSE4_1<5, 5>;
 #endif
+#endif
 #if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_CflIntraPredictor)
   dsp->cfl_intra_predictors[kTransformSize4x4] = CflIntraPredictor_SSE4_1<4, 4>;
 #endif
diff --git a/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc b/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc
index 86504d7..eefe8c6 100644
--- a/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc
@@ -1,7 +1,5 @@
-#include "src/dsp/x86/intrapred_smooth_sse4.h"
-
 #include "src/dsp/dsp.h"
-#include "src/dsp/x86/intrapred_sse4.h"
+#include "src/dsp/intrapred.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -146,7 +144,7 @@
 // pixels[1]: left vector
 // pixels[2]: right_pred vector
 inline void LoadSmoothPixels4(const uint8_t* above, const uint8_t* left,
-                              int height, __m128i* pixels) {
+                              const int height, __m128i* pixels) {
   if (height == 4) {
     pixels[1] = Load4(left);
   } else if (height == 8) {
@@ -542,7 +540,7 @@
   }
 }
 
-void SmoothHorizontal4x4_SSE4_1(void* dest, ptrdiff_t stride,
+void SmoothHorizontal4x4_SSE4_1(void* dest, const ptrdiff_t stride,
                                 const void* top_row, const void* left_column) {
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi32(top_ptr[3]);
@@ -563,7 +561,7 @@
   WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale);
 }
 
-void SmoothHorizontal4x8_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal4x8_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -595,7 +593,7 @@
   WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale);
 }
 
-void SmoothHorizontal4x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal4x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                  const void* const top_row,
                                  const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -647,7 +645,7 @@
   WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale);
 }
 
-void SmoothHorizontal8x4_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal8x4_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -676,7 +674,7 @@
   WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale);
 }
 
-void SmoothHorizontal8x8_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal8x8_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -696,7 +694,7 @@
   }
 }
 
-void SmoothHorizontal8x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal8x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                  const void* const top_row,
                                  const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -724,7 +722,7 @@
   }
 }
 
-void SmoothHorizontal8x32_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal8x32_SSE4_1(void* const dest, const ptrdiff_t stride,
                                  const void* const top_row,
                                  const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -766,7 +764,7 @@
   }
 }
 
-void SmoothHorizontal16x4_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal16x4_SSE4_1(void* const dest, const ptrdiff_t stride,
                                  const void* const top_row,
                                  const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -805,7 +803,7 @@
                               scaled_top_right1, scaled_top_right2, scale);
 }
 
-void SmoothHorizontal16x8_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal16x8_SSE4_1(void* const dest, const ptrdiff_t stride,
                                  const void* const top_row,
                                  const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -832,7 +830,7 @@
   }
 }
 
-void SmoothHorizontal16x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal16x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                   const void* const top_row,
                                   const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -868,7 +866,7 @@
   }
 }
 
-void SmoothHorizontal16x32_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal16x32_SSE4_1(void* const dest, const ptrdiff_t stride,
                                   const void* const top_row,
                                   const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -920,7 +918,7 @@
   }
 }
 
-void SmoothHorizontal16x64_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal16x64_SSE4_1(void* const dest, const ptrdiff_t stride,
                                   const void* const top_row,
                                   const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -950,7 +948,7 @@
   }
 }
 
-void SmoothHorizontal32x8_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal32x8_SSE4_1(void* const dest, const ptrdiff_t stride,
                                  const void* const top_row,
                                  const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -988,7 +986,7 @@
   }
 }
 
-void SmoothHorizontal32x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal32x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                   const void* const top_row,
                                   const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -1037,7 +1035,7 @@
   }
 }
 
-void SmoothHorizontal32x32_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal32x32_SSE4_1(void* const dest, const ptrdiff_t stride,
                                   const void* const top_row,
                                   const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -1106,7 +1104,7 @@
   }
 }
 
-void SmoothHorizontal32x64_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal32x64_SSE4_1(void* const dest, const ptrdiff_t stride,
                                   const void* const top_row,
                                   const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -1147,7 +1145,7 @@
   }
 }
 
-void SmoothHorizontal64x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal64x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                   const void* const top_row,
                                   const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -1222,7 +1220,7 @@
   }
 }
 
-void SmoothHorizontal64x32_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal64x32_SSE4_1(void* const dest, const ptrdiff_t stride,
                                   const void* const top_row,
                                   const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -1325,7 +1323,7 @@
   }
 }
 
-void SmoothHorizontal64x64_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothHorizontal64x64_SSE4_1(void* const dest, const ptrdiff_t stride,
                                   const void* const top_row,
                                   const void* const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
@@ -1389,7 +1387,7 @@
 }
 
 inline void LoadSmoothVerticalPixels4(const uint8_t* above, const uint8_t* left,
-                                      int height, __m128i* pixels) {
+                                      const int height, __m128i* pixels) {
   __m128i top = Load4(above);
   const __m128i bottom_left = _mm_set1_epi16(left[height - 1]);
   top = _mm_cvtepu8_epi16(top);
@@ -1400,8 +1398,8 @@
 // (256-w) counterparts. This is precomputed by the compiler when the weights
 // table is visible to this module. Removing this visibility can cut speed by up
 // to half in both 4xH and 8xH transforms.
-inline void LoadSmoothVerticalWeights4(const uint8_t* weight_array, int height,
-                                       __m128i* weights) {
+inline void LoadSmoothVerticalWeights4(const uint8_t* weight_array,
+                                       const int height, __m128i* weights) {
   const __m128i inverter = _mm_set1_epi16(256);
 
   if (height == 4) {
@@ -1423,7 +1421,8 @@
 }
 
 inline void WriteSmoothVertical4xH(const __m128i* pixel, const __m128i* weight,
-                                   int height, uint8_t* dst, ptrdiff_t stride) {
+                                   const int height, uint8_t* dst,
+                                   const ptrdiff_t stride) {
   const __m128i pred_round = _mm_set1_epi32(128);
   const __m128i mask_increment = _mm_set1_epi16(0x0202);
   const __m128i cvtepu8_epi32 = _mm_set1_epi32(0xC080400);
@@ -1447,7 +1446,7 @@
   }
 }
 
-void SmoothVertical4x4_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical4x4_SSE4_1(void* const dest, const ptrdiff_t stride,
                               const void* const top_row,
                               const void* const left_column) {
   const auto* const left = static_cast<const uint8_t*>(left_column);
@@ -1462,7 +1461,7 @@
   WriteSmoothVertical4xH(&pixels, weights, 4, dst, stride);
 }
 
-void SmoothVertical4x8_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical4x8_SSE4_1(void* const dest, const ptrdiff_t stride,
                               const void* const top_row,
                               const void* const left_column) {
   const auto* const left = static_cast<const uint8_t*>(left_column);
@@ -1477,7 +1476,7 @@
   WriteSmoothVertical4xH(&pixels, weights, 8, dst, stride);
 }
 
-void SmoothVertical4x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical4x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                const void* const top_row,
                                const void* const left_column) {
   const auto* const left = static_cast<const uint8_t*>(left_column);
@@ -1494,7 +1493,7 @@
   WriteSmoothVertical4xH(&pixels, &weights[2], 8, dst, stride);
 }
 
-void SmoothVertical8x4_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical8x4_SSE4_1(void* const dest, const ptrdiff_t stride,
                               const void* const top_row,
                               const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1529,7 +1528,7 @@
   WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, scale);
 }
 
-void SmoothVertical8x8_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical8x8_SSE4_1(void* const dest, const ptrdiff_t stride,
                               const void* const top_row,
                               const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1553,7 +1552,7 @@
   }
 }
 
-void SmoothVertical8x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical8x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                const void* const top_row,
                                const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1592,7 +1591,7 @@
   }
 }
 
-void SmoothVertical8x32_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical8x32_SSE4_1(void* const dest, const ptrdiff_t stride,
                                const void* const top_row,
                                const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1658,7 +1657,7 @@
   }
 }
 
-void SmoothVertical16x4_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical16x4_SSE4_1(void* const dest, const ptrdiff_t stride,
                                const void* const top_row,
                                const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1703,7 +1702,7 @@
                               scale);
 }
 
-void SmoothVertical16x8_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical16x8_SSE4_1(void* const dest, const ptrdiff_t stride,
                                const void* const top_row,
                                const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1731,7 +1730,7 @@
   }
 }
 
-void SmoothVertical16x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical16x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1775,7 +1774,7 @@
   }
 }
 
-void SmoothVertical16x32_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical16x32_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1848,7 +1847,7 @@
   }
 }
 
-void SmoothVertical16x64_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical16x64_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1896,7 +1895,7 @@
   }
 }
 
-void SmoothVertical32x8_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical32x8_SSE4_1(void* const dest, const ptrdiff_t stride,
                                const void* const top_row,
                                const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1931,7 +1930,7 @@
   }
 }
 
-void SmoothVertical32x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical32x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -1984,7 +1983,7 @@
   }
 }
 
-void SmoothVertical32x32_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical32x32_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -2072,7 +2071,7 @@
   }
 }
 
-void SmoothVertical32x64_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical32x64_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -2129,7 +2128,7 @@
   }
 }
 
-void SmoothVertical64x16_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical64x16_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -2201,7 +2200,7 @@
   }
 }
 
-void SmoothVertical64x32_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical64x32_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -2320,7 +2319,7 @@
   }
 }
 
-void SmoothVertical64x64_SSE4_1(void* const dest, ptrdiff_t stride,
+void SmoothVertical64x64_SSE4_1(void* const dest, const ptrdiff_t stride,
                                 const void* const top_row,
                                 const void* const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
diff --git a/libgav1/src/dsp/x86/intrapred_smooth_sse4.h b/libgav1/src/dsp/x86/intrapred_smooth_sse4.h
deleted file mode 100644
index 35c4a4e..0000000
--- a/libgav1/src/dsp/x86/intrapred_smooth_sse4.h
+++ /dev/null
@@ -1,10 +0,0 @@
-#ifndef LIBGAV1_SRC_DSP_X86_INTRAPRED_SMOOTH_SSE4_H_
-#define LIBGAV1_SRC_DSP_X86_INTRAPRED_SMOOTH_SSE4_H_
-namespace libgav1 {
-namespace dsp {
-
-void IntraPredSmoothInit_SSE4_1();
-
-}  // namespace dsp
-}  // namespace libgav1
-#endif  // LIBGAV1_SRC_DSP_X86_INTRAPRED_SMOOTH_SSE4_H_
diff --git a/libgav1/src/dsp/x86/intrapred_sse4.cc b/libgav1/src/dsp/x86/intrapred_sse4.cc
index bf1d939..8651437 100644
--- a/libgav1/src/dsp/x86/intrapred_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_sse4.cc
@@ -1,6 +1,5 @@
-#include "src/dsp/x86/intrapred_sse4.h"
-
-#include "src/dsp/x86/intrapred_smooth_sse4.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/intrapred.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -1417,7 +1416,8 @@
     memcpy(dst, top + offset + 3, width);
     return;
   }
-  for (int y = 0; y < height; y += 8, offset += 8) {
+  int y = 0;
+  do {
     memcpy(dst, top + offset, width);
     dst += stride;
     memcpy(dst, top + offset + 1, width);
@@ -1434,7 +1434,10 @@
     dst += stride;
     memcpy(dst, top + offset + 7, width);
     dst += stride;
-  }
+
+    offset += 8;
+    y += 8;
+  } while (y < height);
 }
 
 inline void DirectionalZone1_4xH(uint8_t* dst, ptrdiff_t stride,
@@ -1454,8 +1457,17 @@
   const __m128i offsets =
       _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001);
 
-  for (int y = 0, top_x = xstep; y < height;
-       ++y, dst += stride, top_x += xstep) {
+  // All rows from |min_corner_only_y| down are filled with memset. |max_base_x|
+  // is always greater than |height|, so clipping to 1 is enough to make the
+  // logic work.
+  const int xstep_units = std::max(xstep >> scale_bits, 1);
+  const int min_corner_only_y = std::min(max_base_x / xstep_units, height);
+
+  // Rows up to |min_corner_only_y| use the regular path, blending any samples
+  // past |max_base_x| with the top-right corner value.
+  int y = 0;
+  int top_x = xstep;
+
+  for (; y < min_corner_only_y; ++y, dst += stride, top_x += xstep) {
     const int top_base_x = top_x >> scale_bits;
 
     // Permit negative values of |top_x|.
@@ -1475,9 +1487,114 @@
     const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect);
     __m128i prod = _mm_maddubs_epi16(sampled_values, shifts);
     prod = RightShiftWithRounding_U16(prod, rounding_bits);
+    // Replace pixels from invalid range with top-right corner.
     prod = _mm_blendv_epi8(prod, final_top_val, past_max);
     Store4(dst, _mm_packus_epi16(prod, prod));
   }
+
+  // Fill in corner-only rows.
+  for (; y < height; ++y) {
+    memset(dst, top[max_base_x], /* width */ 4);
+    dst += stride;
+  }
+}
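+
+// A scalar sketch of the zone 1 math used in this file (illustrative only,
+// shown without upsampling): for output row |y| and column |x|,
+//   top_x = (y + 1) * xstep;
+//   base  = (top_x >> 6) + x;
+//   shift = (top_x & 0x3f) >> 1;  // 0..31
+//   dst[y][x] = (base < max_base_x)
+//                   ? RightShiftWithRounding(
+//                         top[base] * (32 - shift) + top[base + 1] * shift, 5)
+//                   : top[max_base_x];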
+
+// 7.11.2.4 (7) angle < 90
+inline void DirectionalZone1_Large(uint8_t* dest, ptrdiff_t stride,
+                                   const uint8_t* const top_row,
+                                   const int width, const int height,
+                                   const int xstep, const bool upsampled) {
+  const int upsample_shift = static_cast<int>(upsampled);
+  const __m128i sampler =
+      upsampled ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100)
+                : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100);
+  const int scale_bits = 6 - upsample_shift;
+  const int max_base_x = ((width + height) - 1) << upsample_shift;
+
+  const __m128i max_shift = _mm_set1_epi8(32);
+  const int rounding_bits = 5;
+  const int base_step = 1 << upsample_shift;
+  const int base_step8 = base_step << 3;
+
+  // All rows from |min_corner_only_y| down are filled with memset. |max_base_x|
+  // is always greater than |height|, so clipping to 1 is enough to make the
+  // logic work.
+  const int xstep_units = std::max(xstep >> scale_bits, 1);
+  const int min_corner_only_y = std::min(max_base_x / xstep_units, height);
+
+  // Rows up to this y-value can be computed without checking for bounds.
+  const int max_no_corner_y = std::min(
+      LeftShift((max_base_x - (base_step * width)), scale_bits) / xstep,
+      height);
+  // No need to check for exceeding |max_base_x| in the first loop.
+  int y = 0;
+  int top_x = xstep;
+  for (; y < max_no_corner_y; ++y, dest += stride, top_x += xstep) {
+    int top_base_x = top_x >> scale_bits;
+    // Permit negative values of |top_x|.
+    const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1;
+    const __m128i shift = _mm_set1_epi8(shift_val);
+    const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift);
+    const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift);
+    int x = 0;
+    do {
+      const __m128i top_vals = LoadUnaligned16(top_row + top_base_x);
+      __m128i vals = _mm_shuffle_epi8(top_vals, sampler);
+      vals = _mm_maddubs_epi16(vals, shifts);
+      vals = RightShiftWithRounding_U16(vals, rounding_bits);
+      StoreLo8(dest + x, _mm_packus_epi16(vals, vals));
+      top_base_x += base_step8;
+      x += 8;
+    } while (x < width);
+  }
+
+  // Each 16-bit value here corresponds to a position that may exceed
+  // |max_base_x|. When added to the top_base_x, it is used to mask values
+  // that pass the end of |top|. Starting from 1 to simulate "cmpge" which is
+  // not supported for packed integers.
+  const __m128i offsets =
+      _mm_set_epi32(0x00080007, 0x00060005, 0x00040003, 0x00020001);
+
+  const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x);
+  const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]);
+  const __m128i base_step8_vect = _mm_set1_epi16(base_step8);
+  for (; y < min_corner_only_y; ++y, dest += stride, top_x += xstep) {
+    int top_base_x = top_x >> scale_bits;
+
+    const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1;
+    const __m128i shift = _mm_set1_epi8(shift_val);
+    const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift);
+    const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift);
+    __m128i top_index_vect = _mm_set1_epi16(top_base_x);
+    top_index_vect = _mm_add_epi16(top_index_vect, offsets);
+
+    int x = 0;
+    const int min_corner_only_x =
+        std::min(width, ((max_base_x - top_base_x) >> upsample_shift) + 7) & ~7;
+    for (; x < min_corner_only_x;
+         x += 8, top_base_x += base_step8,
+         top_index_vect = _mm_add_epi16(top_index_vect, base_step8_vect)) {
+      const __m128i past_max = _mm_cmpgt_epi16(top_index_vect, max_base_x_vect);
+      // Assuming a buffer zone of 8 bytes at the end of top_row, this prevents
+      // reading out of bounds. If all indices are past max and we don't need to
+      // use the loaded bytes at all, |top_base_x| becomes 0. |top_base_x| will
+      // reset for the next |y|.
+      top_base_x &= ~_mm_cvtsi128_si32(past_max);
+      const __m128i top_vals = LoadUnaligned16(top_row + top_base_x);
+      __m128i vals = _mm_shuffle_epi8(top_vals, sampler);
+      vals = _mm_maddubs_epi16(vals, shifts);
+      vals = RightShiftWithRounding_U16(vals, rounding_bits);
+      vals = _mm_blendv_epi8(vals, final_top_val, past_max);
+      StoreLo8(dest + x, _mm_packus_epi16(vals, vals));
+    }
+    // Corner-only section of the row.
+    memset(dest + x, top_row[max_base_x], width - x);
+  }
+  // Fill in corner-only rows.
+  for (; y < height; ++y) {
+    memset(dest, top_row[max_base_x], width);
+    dest += stride;
+  }
 }
 
 // 7.11.2.4 (7) angle < 90
@@ -1494,6 +1611,11 @@
     DirectionalZone1_4xH(dest, stride, top_row, height, xstep, upsampled);
     return;
   }
+  if (width >= 32) {
+    DirectionalZone1_Large(dest, stride, top_row, width, height, xstep,
+                           upsampled);
+    return;
+  }
   const __m128i sampler =
       upsampled ? _mm_set_epi32(0x0F0E0D0C, 0x0B0A0908, 0x07060504, 0x03020100)
                 : _mm_set_epi32(0x08070706, 0x06050504, 0x04030302, 0x02010100);
@@ -1507,22 +1629,28 @@
 
   // No need to check for exceeding |max_base_x| in the loops.
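+  // ((xstep * height) >> scale_bits) is the largest |top_base_x| any row
+  // starts from, and base_step * width bounds how far a row advances, so the
+  // condition below keeps every index short of |max_base_x|.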
   if (((xstep * height) >> scale_bits) + base_step * width < max_base_x) {
-    for (int y = 0, top_x = xstep; y < height;
-         ++y, dest += stride, top_x += xstep) {
+    int top_x = xstep;
+    int y = 0;
+    do {
       int top_base_x = top_x >> scale_bits;
       // Permit negative values of |top_x|.
       const int shift_val = (LeftShift(top_x, upsample_shift) & 0x3F) >> 1;
       const __m128i shift = _mm_set1_epi8(shift_val);
       const __m128i opposite_shift = _mm_sub_epi8(max_shift, shift);
       const __m128i shifts = _mm_unpacklo_epi8(opposite_shift, shift);
-      for (int x = 0; x < width; x += 8, top_base_x += base_step8) {
+      int x = 0;
+      do {
         const __m128i top_vals = LoadUnaligned16(top_row + top_base_x);
         __m128i vals = _mm_shuffle_epi8(top_vals, sampler);
         vals = _mm_maddubs_epi16(vals, shifts);
         vals = RightShiftWithRounding_U16(vals, rounding_bits);
         StoreLo8(dest + x, _mm_packus_epi16(vals, vals));
-      }
-    }
+        top_base_x += base_step8;
+        x += 8;
+      } while (x < width);
+      dest += stride;
+      top_x += xstep;
+    } while (++y < height);
     return;
   }
 
@@ -1536,8 +1664,9 @@
   const __m128i max_base_x_vect = _mm_set1_epi16(max_base_x);
   const __m128i final_top_val = _mm_set1_epi16(top_row[max_base_x]);
   const __m128i base_step8_vect = _mm_set1_epi16(base_step8);
-  for (int y = 0, top_x = xstep; y < height;
-       ++y, dest += stride, top_x += xstep) {
+  int top_x = xstep;
+  int y = 0;
+  do {
     int top_base_x = top_x >> scale_bits;
 
     if (top_base_x >= max_base_x) {
@@ -1585,7 +1714,9 @@
     vals = RightShiftWithRounding_U16(vals, rounding_bits);
     vals = _mm_blendv_epi8(vals, final_top_val, past_max);
     StoreLo8(dest + x, _mm_packus_epi16(vals, vals));
-  }
+    dest += stride;
+    top_x += xstep;
+  } while (++y < height);
 }
 
 void DirectionalIntraPredictorZone1_SSE4_1(void* const dest, ptrdiff_t stride,
@@ -1690,42 +1821,70 @@
   if (width == 4 || height == 4) {
     const ptrdiff_t stride4 = stride << 2;
     if (upsampled) {
-      for (int x = 0, left_y = ystep; x < width; x += 4, left_y += ystep << 2) {
+      int left_y = ystep;
+      int x = 0;
+      do {
         uint8_t* dst_x = dst + x;
-        for (int y = 0; y < height; y += 4, dst_x += stride4) {
+        int y = 0;
+        do {
           DirectionalZone3_4x4<true>(
               dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep);
-        }
-      }
+          dst_x += stride4;
+          y += 4;
+        } while (y < height);
+        left_y += ystep << 2;
+        x += 4;
+      } while (x < width);
     } else {
-      for (int x = 0, left_y = ystep; x < width; x += 4, left_y += ystep << 2) {
+      int left_y = ystep;
+      int x = 0;
+      do {
         uint8_t* dst_x = dst + x;
-        for (int y = 0; y < height; y += 4, dst_x += stride4) {
+        int y = 0;
+        do {
           DirectionalZone3_4x4<false>(dst_x, stride, left_ptr + y, left_y,
                                       ystep);
-        }
-      }
+          dst_x += stride4;
+          y += 4;
+        } while (y < height);
+        left_y += ystep << 2;
+        x += 4;
+      } while (x < width);
     }
     return;
   }
 
   const ptrdiff_t stride8 = stride << 3;
   if (upsampled) {
-    for (int x = 0, left_y = ystep; x < width; x += 8, left_y += ystep << 3) {
+    int left_y = ystep;
+    int x = 0;
+    do {
       uint8_t* dst_x = dst + x;
-      for (int y = 0; y < height; y += 8, dst_x += stride8) {
+      int y = 0;
+      do {
         DirectionalZone3_8xH<true, 8>(
             dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep);
-      }
-    }
+        dst_x += stride8;
+        y += 8;
+      } while (y < height);
+      left_y += ystep << 3;
+      x += 8;
+    } while (x < width);
   } else {
-    for (int x = 0, left_y = ystep; x < width; x += 8, left_y += ystep << 3) {
+    int left_y = ystep;
+    int x = 0;
+    do {
       uint8_t* dst_x = dst + x;
-      for (int y = 0; y < height; y += 8, dst_x += stride8) {
+      int y = 0;
+      do {
         DirectionalZone3_8xH<false, 8>(
             dst_x, stride, left_ptr + (y << upsample_shift), left_y, ystep);
-      }
-    }
+        dst_x += stride8;
+        y += 8;
+      } while (y < height);
+      left_y += ystep << 3;
+      x += 8;
+    } while (x < width);
   }
 }
 
@@ -1988,7 +2147,7 @@
   // This loop treats each set of 4 columns in 3 stages with y-value boundaries.
   // The first stage, before the first y-loop, covers blocks that are only
   // computed from the top row. The second stage, comprising two y-loops, covers
-  // blocks that have a mixture of values computer from top or left. The final
+  // blocks that have a mixture of values computed from top or left. The final
   // stage covers blocks that are only computed from the left.
   for (int left_offset = -left_base_increment; x < min_top_only_x;
        x += 8,
diff --git a/libgav1/src/dsp/x86/intrapred_sse4.h b/libgav1/src/dsp/x86/intrapred_sse4.h
index bf6e2fe..ec5dfb0 100644
--- a/libgav1/src/dsp/x86/intrapred_sse4.h
+++ b/libgav1/src/dsp/x86/intrapred_sse4.h
@@ -3,15 +3,17 @@
 
 #include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
-#include "src/dsp/intrapred.h"
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::intra_predictors with sse4 implementations. This function
-// is not thread-safe.
+// Initializes Dsp::intra_predictors, Dsp::directional_intra_predictor_zone*,
+// Dsp::cfl_intra_predictors, Dsp::cfl_subsamplers, and
+// Dsp::filter_intra_predictor; see the defines below for specifics. These
+// functions are not thread-safe.
 void IntraPredInit_SSE4_1();
 void IntraPredCflInit_SSE4_1();
+void IntraPredSmoothInit_SSE4_1();
 
 }  // namespace dsp
 }  // namespace libgav1
@@ -120,6 +122,64 @@
   LIBGAV1_DSP_SSE4_1
 #endif
 
+#ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#endif
+
+// TODO(b/137035169): enable these once test vector mismatches are fixed.
+#if 0
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_DSP_SSE4_1
 #endif
@@ -175,6 +235,7 @@
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_DSP_SSE4_1
 #endif
+#endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_DSP_SSE4_1
diff --git a/libgav1/src/dsp/x86/inverse_transform_sse4.cc b/libgav1/src/dsp/x86/inverse_transform_sse4.cc
index 65bc2f3..81f6244 100644
--- a/libgav1/src/dsp/x86/inverse_transform_sse4.cc
+++ b/libgav1/src/dsp/x86/inverse_transform_sse4.cc
@@ -1,4 +1,5 @@
-#include "src/dsp/x86/inverse_transform_sse4.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/inverse_transform.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -10,6 +11,7 @@
 #include <cstring>
 
 #include "src/dsp/x86/common_sse4.h"
+#include "src/dsp/x86/transpose_sse4.h"
 #include "src/utils/array_2d.h"
 #include "src/utils/common.h"
 #include "src/utils/compiler_attributes.h"
@@ -22,176 +24,6 @@
 // Include the constants and utility functions inside the anonymous namespace.
 #include "src/dsp/inverse_transform.inc"
 
-// TODO(slavarnway): move to transpose_sse4.h
-LIBGAV1_ALWAYS_INLINE void Transpose16_4x4(const __m128i* in, __m128i* out) {
-  const __m128i ba = _mm_unpacklo_epi16(in[0], in[1]);
-  const __m128i dc = _mm_unpacklo_epi16(in[2], in[3]);
-  const __m128i dcba_lo = _mm_unpacklo_epi32(ba, dc);
-  const __m128i dcba_hi = _mm_unpackhi_epi32(ba, dc);
-  out[0] = dcba_lo;
-  out[1] = _mm_srli_si128(dcba_lo, 8);
-  out[2] = dcba_hi;
-  out[3] = _mm_srli_si128(dcba_hi, 8);
-}
-
-// TODO(slavarnway): move to transpose_sse4.h
-LIBGAV1_ALWAYS_INLINE void Transpose16_8x4To4x8(const __m128i* in,
-                                                __m128i* out) {
-  // Unpack 16 bit elements. Goes from:
-  // in[0]: 00 01 02 03  04 05 06 07
-  // in[1]: 10 11 12 13  14 15 16 17
-  // in[2]: 20 21 22 23  24 25 26 27
-  // in[3]: 30 31 32 33  34 35 36 37
-
-  // to:
-  // a0:    00 10 01 11  02 12 03 13
-  // a1:    20 30 21 31  22 32 23 33
-  // a4:    04 14 05 15  06 16 07 17
-  // a5:    24 34 25 35  26 36 27 37
-  const __m128i a0 = _mm_unpacklo_epi16(in[0], in[1]);
-  const __m128i a1 = _mm_unpacklo_epi16(in[2], in[3]);
-  const __m128i a4 = _mm_unpackhi_epi16(in[0], in[1]);
-  const __m128i a5 = _mm_unpackhi_epi16(in[2], in[3]);
-
-  // Unpack 32 bit elements resulting in:
-  // b0: 00 10 20 30  01 11 21 31
-  // b2: 04 14 24 34  05 15 25 35
-  // b4: 02 12 22 32  03 13 23 33
-  // b6: 06 16 26 36  07 17 27 37
-  const __m128i b0 = _mm_unpacklo_epi32(a0, a1);
-  const __m128i b2 = _mm_unpacklo_epi32(a4, a5);
-  const __m128i b4 = _mm_unpackhi_epi32(a0, a1);
-  const __m128i b6 = _mm_unpackhi_epi32(a4, a5);
-
-  // Unpack 64 bit elements resulting in:
-  // out[0]: 00 10 20 30  XX XX XX XX
-  // out[1]: 01 11 21 31  XX XX XX XX
-  // out[2]: 02 12 22 32  XX XX XX XX
-  // out[3]: 03 13 23 33  XX XX XX XX
-  // out[4]: 04 14 24 34  XX XX XX XX
-  // out[5]: 05 15 25 35  XX XX XX XX
-  // out[6]: 06 16 26 36  XX XX XX XX
-  // out[7]: 07 17 27 37  XX XX XX XX
-  const __m128i zeros = _mm_setzero_si128();
-  out[0] = _mm_unpacklo_epi64(b0, zeros);
-  out[1] = _mm_unpackhi_epi64(b0, zeros);
-  out[2] = _mm_unpacklo_epi64(b4, zeros);
-  out[3] = _mm_unpackhi_epi64(b4, zeros);
-  out[4] = _mm_unpacklo_epi64(b2, zeros);
-  out[5] = _mm_unpackhi_epi64(b2, zeros);
-  out[6] = _mm_unpacklo_epi64(b6, zeros);
-  out[7] = _mm_unpackhi_epi64(b6, zeros);
-}
-
-// TODO(slavarnway): move to transpose_sse4.h
-LIBGAV1_ALWAYS_INLINE void Transpose16_4x8To8x4(const __m128i* in,
-                                                __m128i* out) {
-  // Unpack 16 bit elements. Goes from:
-  // in[0]: 00 01 02 03  XX XX XX XX
-  // in[1]: 10 11 12 13  XX XX XX XX
-  // in[2]: 20 21 22 23  XX XX XX XX
-  // in[3]: 30 31 32 33  XX XX XX XX
-  // in[4]: 40 41 42 43  XX XX XX XX
-  // in[5]: 50 51 52 53  XX XX XX XX
-  // in[6]: 60 61 62 63  XX XX XX XX
-  // in[7]: 70 71 72 73  XX XX XX XX
-  // to:
-  // a0:    00 10 01 11  02 12 03 13
-  // a1:    20 30 21 31  22 32 23 33
-  // a2:    40 50 41 51  42 52 43 53
-  // a3:    60 70 61 71  62 72 63 73
-  const __m128i a0 = _mm_unpacklo_epi16(in[0], in[1]);
-  const __m128i a1 = _mm_unpacklo_epi16(in[2], in[3]);
-  const __m128i a2 = _mm_unpacklo_epi16(in[4], in[5]);
-  const __m128i a3 = _mm_unpacklo_epi16(in[6], in[7]);
-
-  // Unpack 32 bit elements resulting in:
-  // b0: 00 10 20 30  01 11 21 31
-  // b1: 40 50 60 70  41 51 61 71
-  // b2: 02 12 22 32  03 13 23 33
-  // b3: 42 52 62 72  43 53 63 73
-  const __m128i b0 = _mm_unpacklo_epi32(a0, a1);
-  const __m128i b1 = _mm_unpacklo_epi32(a2, a3);
-  const __m128i b2 = _mm_unpackhi_epi32(a0, a1);
-  const __m128i b3 = _mm_unpackhi_epi32(a2, a3);
-
-  // Unpack 64 bit elements resulting in:
-  // out[0]: 00 10 20 30  40 50 60 70
-  // out[1]: 01 11 21 31  41 51 61 71
-  // out[2]: 02 12 22 32  42 52 62 72
-  // out[3]: 03 13 23 33  43 53 63 73
-  out[0] = _mm_unpacklo_epi64(b0, b1);
-  out[1] = _mm_unpackhi_epi64(b0, b1);
-  out[2] = _mm_unpacklo_epi64(b2, b3);
-  out[3] = _mm_unpackhi_epi64(b2, b3);
-}
-
-// TODO(slavarnway): move to transpose_sse4.h
-LIBGAV1_ALWAYS_INLINE void Transpose16_8x8(const __m128i* in, __m128i* out) {
-  // Unpack 16 bit elements. Goes from:
-  // in[0]: 00 01 02 03  04 05 06 07
-  // in[1]: 10 11 12 13  14 15 16 17
-  // in[2]: 20 21 22 23  24 25 26 27
-  // in[3]: 30 31 32 33  34 35 36 37
-  // in[4]: 40 41 42 43  44 45 46 47
-  // in[5]: 50 51 52 53  54 55 56 57
-  // in[6]: 60 61 62 63  64 65 66 67
-  // in[7]: 70 71 72 73  74 75 76 77
-  // to:
-  // a0:    00 10 01 11  02 12 03 13
-  // a1:    20 30 21 31  22 32 23 33
-  // a2:    40 50 41 51  42 52 43 53
-  // a3:    60 70 61 71  62 72 63 73
-  // a4:    04 14 05 15  06 16 07 17
-  // a5:    24 34 25 35  26 36 27 37
-  // a6:    44 54 45 55  46 56 47 57
-  // a7:    64 74 65 75  66 76 67 77
-  const __m128i a0 = _mm_unpacklo_epi16(in[0], in[1]);
-  const __m128i a1 = _mm_unpacklo_epi16(in[2], in[3]);
-  const __m128i a2 = _mm_unpacklo_epi16(in[4], in[5]);
-  const __m128i a3 = _mm_unpacklo_epi16(in[6], in[7]);
-  const __m128i a4 = _mm_unpackhi_epi16(in[0], in[1]);
-  const __m128i a5 = _mm_unpackhi_epi16(in[2], in[3]);
-  const __m128i a6 = _mm_unpackhi_epi16(in[4], in[5]);
-  const __m128i a7 = _mm_unpackhi_epi16(in[6], in[7]);
-
-  // Unpack 32 bit elements resulting in:
-  // b0: 00 10 20 30  01 11 21 31
-  // b1: 40 50 60 70  41 51 61 71
-  // b2: 04 14 24 34  05 15 25 35
-  // b3: 44 54 64 74  45 55 65 75
-  // b4: 02 12 22 32  03 13 23 33
-  // b5: 42 52 62 72  43 53 63 73
-  // b6: 06 16 26 36  07 17 27 37
-  // b7: 46 56 66 76  47 57 67 77
-  const __m128i b0 = _mm_unpacklo_epi32(a0, a1);
-  const __m128i b1 = _mm_unpacklo_epi32(a2, a3);
-  const __m128i b2 = _mm_unpacklo_epi32(a4, a5);
-  const __m128i b3 = _mm_unpacklo_epi32(a6, a7);
-  const __m128i b4 = _mm_unpackhi_epi32(a0, a1);
-  const __m128i b5 = _mm_unpackhi_epi32(a2, a3);
-  const __m128i b6 = _mm_unpackhi_epi32(a4, a5);
-  const __m128i b7 = _mm_unpackhi_epi32(a6, a7);
-
-  // Unpack 64 bit elements resulting in:
-  // out[0]: 00 10 20 30  40 50 60 70
-  // out[1]: 01 11 21 31  41 51 61 71
-  // out[2]: 02 12 22 32  42 52 62 72
-  // out[3]: 03 13 23 33  43 53 63 73
-  // out[4]: 04 14 24 34  44 54 64 74
-  // out[5]: 05 15 25 35  45 55 65 75
-  // out[6]: 06 16 26 36  46 56 66 76
-  // out[7]: 07 17 27 37  47 57 67 77
-  out[0] = _mm_unpacklo_epi64(b0, b1);
-  out[1] = _mm_unpackhi_epi64(b0, b1);
-  out[2] = _mm_unpacklo_epi64(b4, b5);
-  out[3] = _mm_unpackhi_epi64(b4, b5);
-  out[4] = _mm_unpacklo_epi64(b2, b3);
-  out[5] = _mm_unpackhi_epi64(b2, b3);
-  out[6] = _mm_unpacklo_epi64(b6, b7);
-  out[7] = _mm_unpackhi_epi64(b6, b7);
-}
-
 template <int store_width, int store_count>
 LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* dst, int32_t stride, int32_t idx,
                                     const __m128i* s) {
@@ -385,14 +217,14 @@
     if (transpose) {
       __m128i input[8];
       LoadSrc<8, 8>(src, step, 0, input);
-      Transpose16_4x8To8x4(input, x);
+      Transpose4x8To8x4_U16(input, x);
     } else {
       LoadSrc<16, 4>(src, step, 0, x);
     }
   } else {
     LoadSrc<8, 4>(src, step, 0, x);
     if (transpose) {
-      Transpose16_4x4(x, x);
+      Transpose4x4_U16(x, x);
     }
   }
   // stage 1.
@@ -407,14 +239,14 @@
   if (stage_is_rectangular) {
     if (transpose) {
       __m128i output[8];
-      Transpose16_8x4To4x8(s, output);
+      Transpose8x4To4x8_U16(s, output);
       StoreDst<8, 8>(dst, step, 0, output);
     } else {
       StoreDst<16, 4>(dst, step, 0, s);
     }
   } else {
     if (transpose) {
-      Transpose16_4x4(s, s);
+      Transpose4x4_U16(s, s);
     }
     StoreDst<8, 4>(dst, step, 0, s);
   }
@@ -458,7 +290,7 @@
     if (transpose) {
       __m128i input[4];
       LoadSrc<16, 4>(src, step, 0, input);
-      Transpose16_8x4To4x8(input, x);
+      Transpose8x4To4x8_U16(input, x);
     } else {
       LoadSrc<8, 8>(src, step, 0, x);
     }
@@ -466,7 +298,7 @@
     if (transpose) {
       __m128i input[8];
       LoadSrc<16, 8>(src, step, 0, input);
-      Transpose16_8x8(input, x);
+      Transpose8x8_U16(input, x);
     } else {
       LoadSrc<16, 8>(src, step, 0, x);
     }
@@ -489,7 +321,7 @@
   if (stage_is_rectangular) {
     if (transpose) {
       __m128i output[4];
-      Transpose16_4x8To8x4(s, output);
+      Transpose4x8To8x4_U16(s, output);
       StoreDst<16, 4>(dst, step, 0, output);
     } else {
       StoreDst<8, 8>(dst, step, 0, s);
@@ -497,7 +329,7 @@
   } else {
     if (transpose) {
       __m128i output[8];
-      Transpose16_8x8(s, output);
+      Transpose8x8_U16(s, output);
       StoreDst<16, 8>(dst, step, 0, output);
     } else {
       StoreDst<16, 8>(dst, step, 0, s);
@@ -564,9 +396,9 @@
     if (transpose) {
       __m128i input[4];
       LoadSrc<16, 4>(src, step, 0, input);
-      Transpose16_8x4To4x8(input, x);
+      Transpose8x4To4x8_U16(input, x);
       LoadSrc<16, 4>(src, step, 8, input);
-      Transpose16_8x4To4x8(input, &x[8]);
+      Transpose8x4To4x8_U16(input, &x[8]);
     } else {
       LoadSrc<8, 16>(src, step, 0, x);
     }
@@ -575,7 +407,7 @@
       for (int idx = 0; idx < 16; idx += 8) {
         __m128i input[8];
         LoadSrc<16, 8>(src, step, idx, input);
-        Transpose16_8x8(input, &x[idx]);
+        Transpose8x8_U16(input, &x[idx]);
       }
     } else {
       LoadSrc<16, 16>(src, step, 0, x);
@@ -608,9 +440,9 @@
   if (stage_is_rectangular) {
     if (transpose) {
       __m128i output[4];
-      Transpose16_4x8To8x4(s, output);
+      Transpose4x8To8x4_U16(s, output);
       StoreDst<16, 4>(dst, step, 0, output);
-      Transpose16_4x8To8x4(&s[8], output);
+      Transpose4x8To8x4_U16(&s[8], output);
       StoreDst<16, 4>(dst, step, 8, output);
     } else {
       StoreDst<8, 16>(dst, step, 0, s);
@@ -619,7 +451,7 @@
     if (transpose) {
       for (int idx = 0; idx < 16; idx += 8) {
         __m128i output[8];
-        Transpose16_8x8(&s[idx], output);
+        Transpose8x8_U16(&s[idx], output);
         StoreDst<16, 8>(dst, step, idx, output);
       }
     } else {
@@ -730,7 +562,7 @@
     for (int idx = 0; idx < 32; idx += 8) {
       __m128i input[8];
       LoadSrc<16, 8>(src, step, idx, input);
-      Transpose16_8x8(input, &x[idx]);
+      Transpose8x8_U16(input, &x[idx]);
     }
   } else {
     LoadSrc<16, 32>(src, step, 0, x);
@@ -782,7 +614,7 @@
   if (transpose) {
     for (int idx = 0; idx < 32; idx += 8) {
       __m128i output[8];
-      Transpose16_8x8(&s[idx], output);
+      Transpose8x8_U16(&s[idx], output);
       StoreDst<16, 8>(dst, step, idx, output);
     }
   } else {
@@ -804,7 +636,7 @@
     for (int idx = 0; idx < 32; idx += 8) {
       __m128i input[8];
       LoadSrc<16, 8>(src, step, idx, input);
-      Transpose16_8x8(input, &x[idx]);
+      Transpose8x8_U16(input, &x[idx]);
     }
   } else {
     // The last 32 values of every column are always zero if the |tx_height| is
@@ -1002,7 +834,7 @@
   if (transpose) {
     for (int idx = 0; idx < 64; idx += 8) {
       __m128i output[8];
-      Transpose16_8x8(&s[idx], output);
+      Transpose8x8_U16(&s[idx], output);
       StoreDst<16, 8>(dst, step, idx, output);
     }
   } else {
@@ -1024,14 +856,14 @@
     if (transpose) {
       __m128i input[8];
       LoadSrc<8, 8>(src, step, 0, input);
-      Transpose16_4x8To8x4(input, x);
+      Transpose4x8To8x4_U16(input, x);
     } else {
       LoadSrc<16, 4>(src, step, 0, x);
     }
   } else {
     LoadSrc<8, 4>(src, step, 0, x);
     if (transpose) {
-      Transpose16_4x4(x, x);
+      Transpose4x4_U16(x, x);
     }
   }
 
@@ -1090,14 +922,14 @@
   if (stage_is_rectangular) {
     if (transpose) {
       __m128i output[8];
-      Transpose16_8x4To4x8(x, output);
+      Transpose8x4To4x8_U16(x, output);
       StoreDst<8, 8>(dst, step, 0, output);
     } else {
       StoreDst<16, 4>(dst, step, 0, x);
     }
   } else {
     if (transpose) {
-      Transpose16_4x4(x, x);
+      Transpose4x4_U16(x, x);
     }
     StoreDst<8, 4>(dst, step, 0, x);
   }
@@ -1114,7 +946,7 @@
     if (transpose) {
       __m128i input[4];
       LoadSrc<16, 4>(src, step, 0, input);
-      Transpose16_8x4To4x8(input, x);
+      Transpose8x4To4x8_U16(input, x);
     } else {
       LoadSrc<8, 8>(src, step, 0, x);
     }
@@ -1122,7 +954,7 @@
     if (transpose) {
       __m128i input[8];
       LoadSrc<16, 8>(src, step, 0, input);
-      Transpose16_8x8(input, x);
+      Transpose8x8_U16(input, x);
     } else {
       LoadSrc<16, 8>(src, step, 0, x);
     }
@@ -1178,7 +1010,7 @@
   if (stage_is_rectangular) {
     if (transpose) {
       __m128i output[4];
-      Transpose16_4x8To8x4(x, output);
+      Transpose4x8To8x4_U16(x, output);
       StoreDst<16, 4>(dst, step, 0, output);
     } else {
       StoreDst<8, 8>(dst, step, 0, x);
@@ -1186,7 +1018,7 @@
   } else {
     if (transpose) {
       __m128i output[8];
-      Transpose16_8x8(x, output);
+      Transpose8x8_U16(x, output);
       StoreDst<16, 8>(dst, step, 0, output);
     } else {
       StoreDst<16, 8>(dst, step, 0, x);
@@ -1205,9 +1037,9 @@
     if (transpose) {
       __m128i input[4];
       LoadSrc<16, 4>(src, step, 0, input);
-      Transpose16_8x4To4x8(input, x);
+      Transpose8x4To4x8_U16(input, x);
       LoadSrc<16, 4>(src, step, 8, input);
-      Transpose16_8x4To4x8(input, &x[8]);
+      Transpose8x4To4x8_U16(input, &x[8]);
     } else {
       LoadSrc<8, 16>(src, step, 0, x);
     }
@@ -1216,7 +1048,7 @@
       for (int idx = 0; idx < 16; idx += 8) {
         __m128i input[8];
         LoadSrc<16, 8>(src, step, idx, input);
-        Transpose16_8x8(input, &x[idx]);
+        Transpose8x8_U16(input, &x[idx]);
       }
     } else {
       LoadSrc<16, 16>(src, step, 0, x);
@@ -1321,9 +1153,9 @@
   if (stage_is_rectangular) {
     if (transpose) {
       __m128i output[4];
-      Transpose16_4x8To8x4(x, output);
+      Transpose4x8To8x4_U16(x, output);
       StoreDst<16, 4>(dst, step, 0, output);
-      Transpose16_4x8To8x4(&x[8], output);
+      Transpose4x8To8x4_U16(&x[8], output);
       StoreDst<16, 4>(dst, step, 8, output);
     } else {
       StoreDst<8, 16>(dst, step, 0, x);
@@ -1332,7 +1164,7 @@
     if (transpose) {
       for (int idx = 0; idx < 16; idx += 8) {
         __m128i output[8];
-        Transpose16_8x8(&x[idx], output);
+        Transpose8x8_U16(&x[idx], output);
         StoreDst<16, 8>(dst, step, idx, output);
       }
     } else {
@@ -1343,11 +1175,6 @@
 
 //------------------------------------------------------------------------------
 // Identity Transforms.
-constexpr int16_t kIdentity4Multiplier /* round(2^12 * sqrt(2)) */ = 0x16A1;
-constexpr int16_t kIdentity4MultiplierFraction /* round(2^12 * (sqrt(2) - 1))*/
-    = 0x6a1;
-constexpr int16_t kIdentity16Multiplier /* 2 * round(2^12 * sqrt(2)) */ = 11586;
-constexpr int16_t kTransformRowMultiplier /* round(2^12 / sqrt(2)) */ = 2896;
 
 template <bool is_row_shift>
 LIBGAV1_ALWAYS_INLINE void Identity4_SSE4_1(void* dest, const void* source,
@@ -1408,7 +1235,8 @@
   } else {
     for (int i = 0; i < tx_height; ++i) {
       const int row = i * tx_width;
-      for (int j = 0; j < tx_width; j += 8) {
+      int j = 0;
+      do {
         const __m128i v_src = LoadUnaligned16(&source[row + j]);
         const __m128i v_src_mult =
             _mm_mulhrs_epi16(v_src, v_multiplier_fraction);
@@ -1419,7 +1247,8 @@
         const __m128i c = _mm_cvtepu8_epi16(frame_data);
         const __m128i d = _mm_adds_epi16(c, b);
         StoreLo8(dst + j, _mm_packus_epi16(d, d));
-      }
+        j += 8;
+      } while (j < tx_width);
       dst += stride;
     }
   }
@@ -1456,7 +1285,8 @@
   } else {
     for (int i = 0; i < tx_height; ++i) {
       const int row = i * tx_width;
-      for (int j = 0; j < tx_width; j += 8) {
+      int j = 0;
+      do {
         const __m128i v_src = LoadUnaligned16(&source[row + j]);
         const __m128i v_src_round =
             _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
@@ -1470,7 +1300,8 @@
         const __m128i b = _mm_srai_epi16(a, 4);
         const __m128i c = _mm_adds_epi16(frame_data16, b);
         StoreLo8(dst + j, _mm_packus_epi16(c, c));
-      }
+        j += 8;
+      } while (j < tx_width);
       dst += stride;
     }
   }
@@ -1528,7 +1359,8 @@
   } else {
     for (int i = 0; i < tx_height; ++i) {
       const int row = i * tx_width;
-      for (int j = 0; j < tx_width; j += 8) {
+      int j = 0;
+      do {
         const __m128i v_src = LoadUnaligned16(&source[row + j]);
         const __m128i v_dst_i = _mm_adds_epi16(v_src, v_src);
         const __m128i frame_data = LoadLo8(dst + j);
@@ -1537,7 +1369,8 @@
         const __m128i c = _mm_cvtepu8_epi16(frame_data);
         const __m128i d = _mm_adds_epi16(c, b);
         StoreLo8(dst + j, _mm_packus_epi16(d, d));
-      }
+        j += 8;
+      } while (j < tx_width);
       dst += stride;
     }
   }
@@ -1599,7 +1432,8 @@
   } else {
     for (int i = 0; i < tx_height; ++i) {
       const int row = i * tx_width;
-      for (int j = 0; j < tx_width; j += 8) {
+      int j = 0;
+      do {
         const __m128i v_src = LoadUnaligned16(&source[row + j]);
         const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier);
         const __m128i frame_data = LoadLo8(dst + j);
@@ -1610,7 +1444,8 @@
         const __m128i c = _mm_cvtepu8_epi16(frame_data);
         const __m128i d = _mm_adds_epi16(c, b);
         StoreLo8(dst + j, _mm_packus_epi16(d, d));
-      }
+        j += 8;
+      } while (j < tx_width);
       dst += stride;
     }
   }
@@ -1645,7 +1480,8 @@
 
   for (int i = 0; i < tx_height; ++i) {
     const int row = i * tx_width;
-    for (int j = 0; j < tx_width; j += 8) {
+    int j = 0;
+    do {
       const __m128i v_dst_i = LoadUnaligned16(&source[row + j]);
       const __m128i frame_data = LoadLo8(dst + j);
       const __m128i a = _mm_adds_epi16(v_dst_i, v_two);
@@ -1653,7 +1489,88 @@
       const __m128i c = _mm_cvtepu8_epi16(frame_data);
       const __m128i d = _mm_adds_epi16(c, b);
       StoreLo8(dst + j, _mm_packus_epi16(d, d));
-    }
+      j += 8;
+    } while (j < tx_width);
+    dst += stride;
+  }
+}
+
+//------------------------------------------------------------------------------
+// Walsh Hadamard Transform.
+
+// Process 4 wht4 rows and columns.
+LIBGAV1_ALWAYS_INLINE void Wht4_SSE4_1(Array2DView<uint8_t> frame,
+                                       const int start_x, const int start_y,
+                                       const void* source,
+                                       const int non_zero_coeff_count) {
+  const auto* const src = static_cast<const int16_t*>(source);
+  __m128i s[4], x[4];
+
+  if (non_zero_coeff_count == 1) {
+    // Special case: only src[0] is nonzero.
+    //   src[0]  0   0   0
+    //       0   0   0   0
+    //       0   0   0   0
+    //       0   0   0   0
+    //
+    // After the row and column transforms are applied, we have:
+    //       f   h   h   h
+    //       g   i   i   i
+    //       g   i   i   i
+    //       g   i   i   i
+    // where f, g, h, i are computed as follows.
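+    // (These scalar values are what the row and column butterflies in the
+    // general path below would produce for a DC-only input, so the full
+    // transform can be skipped here.)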
+    int16_t f = (src[0] >> 2) - (src[0] >> 3);
+    const int16_t g = f >> 1;
+    f = f - (f >> 1);
+    const int16_t h = (src[0] >> 3) - (src[0] >> 4);
+    const int16_t i = (src[0] >> 4);
+    s[0] = _mm_set1_epi16(h);
+    s[0] = _mm_insert_epi16(s[0], f, 0);
+    s[1] = _mm_set1_epi16(i);
+    s[1] = _mm_insert_epi16(s[1], g, 0);
+    s[2] = s[3] = s[1];
+  } else {
+    x[0] = LoadLo8(&src[0 * 4]);
+    x[2] = LoadLo8(&src[1 * 4]);
+    x[3] = LoadLo8(&src[2 * 4]);
+    x[1] = LoadLo8(&src[3 * 4]);
+
+    // Row transforms.
+    Transpose4x4_U16(x, x);
+    s[0] = _mm_srai_epi16(x[0], 2);
+    s[2] = _mm_srai_epi16(x[1], 2);
+    s[3] = _mm_srai_epi16(x[2], 2);
+    s[1] = _mm_srai_epi16(x[3], 2);
+    s[0] = _mm_add_epi16(s[0], s[2]);
+    s[3] = _mm_sub_epi16(s[3], s[1]);
+    __m128i e = _mm_sub_epi16(s[0], s[3]);
+    e = _mm_srai_epi16(e, 1);
+    s[1] = _mm_sub_epi16(e, s[1]);
+    s[2] = _mm_sub_epi16(e, s[2]);
+    s[0] = _mm_sub_epi16(s[0], s[1]);
+    s[3] = _mm_add_epi16(s[3], s[2]);
+    Transpose4x4_U16(s, s);
+
+    // Column transforms.
+    s[0] = _mm_add_epi16(s[0], s[2]);
+    s[3] = _mm_sub_epi16(s[3], s[1]);
+    e = _mm_sub_epi16(s[0], s[3]);
+    e = _mm_srai_epi16(e, 1);
+    s[1] = _mm_sub_epi16(e, s[1]);
+    s[2] = _mm_sub_epi16(e, s[2]);
+    s[0] = _mm_sub_epi16(s[0], s[1]);
+    s[3] = _mm_add_epi16(s[3], s[2]);
+  }
+
+  // Store to frame.
+  const int stride = frame.columns();
+  uint8_t* dst = frame[start_y] + start_x;
+  for (int row = 0; row < 4; ++row) {
+    const __m128i frame_data = Load4(dst);
+    const __m128i a = _mm_cvtepu8_epi16(frame_data);
+    // Saturate to prevent overflowing int16_t.
+    const __m128i b = _mm_adds_epi16(a, s[row]);
+    Store4(dst, _mm_packus_epi16(b, b));
     dst += stride;
   }
 }
@@ -1661,21 +1578,13 @@
 //------------------------------------------------------------------------------
 // row/column transform loops
 
-constexpr uint8_t kTransformRowShift[kNumTransformSizes] = {
-    0, 0, 1, 0, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2};
-
-constexpr bool kShouldRound[kNumTransformSizes] = {
-    false, true,  false, true, false, true, false, false, true, false,
-    true,  false, false, true, false, true, false, true,  false};
-
 template <bool enable_flip_rows = false>
 LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
     const int tx_width, const int tx_height, const int16_t* source,
     TransformType tx_type) {
-  const bool flip_rows = enable_flip_rows
-                             ? ((1U << tx_type) & kTransformFlipRowsMask) != 0
-                             : false;
+  const bool flip_rows =
+      enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
   const __m128i v_eight = _mm_set1_epi16(8);
   const int stride = frame.columns();
   uint8_t* dst = frame[start_y] + start_x;
@@ -1706,11 +1615,11 @@
       dst += stride;
     }
   } else {
-    const __m128i v_eight = _mm_set1_epi16(8);
     for (int i = 0; i < tx_height; ++i) {
       const int y = start_y + i;
       const int row = flip_rows ? (tx_height - i - 1) * tx_width : i * tx_width;
-      for (int j = 0; j < tx_width; j += 16) {
+      int j = 0;
+      do {
         const int x = start_x + j;
         const __m128i residual = LoadUnaligned16(&source[row + j]);
         const __m128i residual_hi = LoadUnaligned16(&source[row + j + 8]);
@@ -1724,7 +1633,8 @@
         const __m128i e = _mm_adds_epi16(d, c);
         const __m128i e_hi = _mm_adds_epi16(d_hi, c_hi);
         StoreUnaligned16(frame[y] + x, _mm_packus_epi16(e, e_hi));
-      }
+        j += 16;
+      } while (j < tx_width);
     }
   }
 }
@@ -1734,7 +1644,8 @@
   const __m128i word_reverse_8 =
       _mm_set_epi32(0x01000302, 0x05040706, 0x09080b0a, 0x0d0c0f0e);
   if (tx_width >= 16) {
-    for (int i = 0; i < tx_width * tx_height; i += 16) {
+    int i = 0;
+    do {
       // read 16 shorts
       const __m128i v3210 = LoadUnaligned16(&source[i]);
       const __m128i v7654 = LoadUnaligned16(&source[i + 8]);
@@ -1742,7 +1653,8 @@
       const __m128i v4567 = _mm_shuffle_epi8(v7654, word_reverse_8);
       StoreUnaligned16(&source[i], v4567);
       StoreUnaligned16(&source[i + 8], v0123);
-    }
+      i += 16;
+    } while (i < tx_width * tx_height);
   } else if (tx_width == 8) {
     for (int i = 0; i < 8 * tx_height; i += 8) {
       const __m128i a = LoadUnaligned16(&source[i]);
@@ -1762,27 +1674,32 @@
 }
 
 template <int tx_width>
-LIBGAV1_ALWAYS_INLINE void ShouldRound(int16_t* source, int num_rows) {
+LIBGAV1_ALWAYS_INLINE void ApplyRounding(int16_t* source, int num_rows) {
   const __m128i v_kTransformRowMultiplier =
       _mm_set1_epi16(kTransformRowMultiplier << 3);
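+  // |kTransformRowMultiplier| is round(2^12 / sqrt(2)) (2896; previously a
+  // local constant here, now presumably provided by inverse_transform.inc).
+  // Shifting it left by 3 makes it a Q15 factor, so _mm_mulhrs_epi16() below
+  // yields approximately a rounded x / sqrt(2) for each coefficient.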
   if (tx_width == 4) {
     // Process two rows per iteration.
-    for (int i = 0; i < tx_width * num_rows; i += 8) {
+    int i = 0;
+    do {
       const __m128i a = LoadUnaligned16(&source[i]);
       const __m128i b = _mm_mulhrs_epi16(a, v_kTransformRowMultiplier);
       StoreUnaligned16(&source[i], b);
-    }
+      i += 8;
+    } while (i < tx_width * num_rows);
   } else {
-    for (int i = 0; i < num_rows; ++i) {
+    int i = 0;
+    do {
       // The last 32 values of every row are always zero if the |tx_width| is
       // 64.
       const int non_zero_width = (tx_width < 64) ? tx_width : 32;
-      for (int j = 0; j < non_zero_width; j += 8) {
+      int j = 0;
+      do {
         const __m128i a = LoadUnaligned16(&source[i * tx_width + j]);
         const __m128i b = _mm_mulhrs_epi16(a, v_kTransformRowMultiplier);
         StoreUnaligned16(&source[i * tx_width + j], b);
-      }
-    }
+        j += 8;
+      } while (j < non_zero_width);
+    } while (++i < num_rows);
   }
 }
 
@@ -1793,7 +1710,8 @@
   const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
   if (tx_width == 4) {
     // Process two rows per iteration.
-    for (int i = 0; i < tx_width * num_rows; i += 8) {
+    int i = 0;
+    do {
       // Expand to 32 bits to prevent int16_t overflows during the shift add.
       const __m128i residual = LoadUnaligned16(&source[i]);
       const __m128i a = _mm_cvtepi16_epi32(residual);
@@ -1803,9 +1721,11 @@
       const __m128i c = _mm_sra_epi32(b, v_row_shift);
       const __m128i c1 = _mm_sra_epi32(b1, v_row_shift);
       StoreUnaligned16(&source[i], _mm_packs_epi32(c, c1));
-    }
+      i += 8;
+    } while (i < tx_width * num_rows);
   } else {
-    for (int i = 0; i < num_rows; ++i) {
+    int i = 0;
+    do {
       for (int j = 0; j < tx_width; j += 8) {
         // Expand to 32 bits to prevent int16_t overflows during the shift add.
         const __m128i residual = LoadUnaligned16(&source[i * tx_width + j]);
@@ -1817,14 +1737,14 @@
         const __m128i c1 = _mm_sra_epi32(b1, v_row_shift);
         StoreUnaligned16(&source[i * tx_width + j], _mm_packs_epi32(c, c1));
       }
-    }
+    } while (++i < num_rows);
   }
 }
 
 void Dct4TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                              int8_t /*bitdepth*/, void* src_buffer,
-                              int start_x, int start_y, void* dst_frame,
-                              bool is_row, int non_zero_coeff_count) {
+                              void* src_buffer, int start_x, int start_y,
+                              void* dst_frame, bool is_row,
+                              int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -1834,7 +1754,7 @@
     const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     const bool should_round = (tx_height == 8);
     if (should_round) {
-      ShouldRound<4>(src, num_rows);
+      ApplyRounding<4>(src, num_rows);
     }
 
     if (num_rows <= 4) {
@@ -1843,10 +1763,12 @@
                                               /*transpose=*/true);
     } else {
       // Process 8 1d dct4 rows in parallel per iteration.
-      for (int i = 0; i < num_rows; i += 8) {
+      int i = 0;
+      do {
         Dct4_SSE4_1<ButterflyRotation_8, true>(&src[i * 4], &src[i * 4],
                                                /*step=*/4, /*transpose=*/true);
-      }
+        i += 8;
+      } while (i < num_rows);
     }
     if (tx_height == 16) {
       RowShift<4>(src, num_rows, 1);
@@ -1855,8 +1777,7 @@
   }
 
   assert(!is_row);
-  const bool flip_columns = ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
-  if (flip_columns) {
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<4>(src, tx_width);
   }
 
@@ -1866,18 +1787,20 @@
                                             /*transpose=*/false);
   } else {
     // Process 8 1d dct4 columns in parallel per iteration.
-    for (int i = 0; i < tx_width; i += 8) {
+    int i = 0;
+    do {
       Dct4_SSE4_1<ButterflyRotation_8, true>(&src[i], &src[i], tx_width,
                                              /*transpose=*/false);
-    }
+      i += 8;
+    } while (i < tx_width);
   }
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 4, src, tx_type);
 }
 
 void Dct8TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                              int8_t /*bitdepth*/, void* src_buffer,
-                              int start_x, int start_y, void* dst_frame,
-                              bool is_row, int non_zero_coeff_count) {
+                              void* src_buffer, int start_x, int start_y,
+                              void* dst_frame, bool is_row,
+                              int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -1886,7 +1809,7 @@
   if (is_row) {
     const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     if (kShouldRound[tx_size]) {
-      ShouldRound<8>(src, num_rows);
+      ApplyRounding<8>(src, num_rows);
     }
 
     if (num_rows <= 4) {
@@ -1895,10 +1818,12 @@
                                              /*transpose=*/true);
     } else {
       // Process 8 1d dct8 rows in parallel per iteration.
-      for (int i = 0; i < num_rows; i += 8) {
+      int i = 0;
+      do {
         Dct8_SSE4_1<ButterflyRotation_8, false>(&src[i * 8], &src[i * 8],
                                                 /*step=*/8, /*transpose=*/true);
-      }
+        i += 8;
+      } while (i < num_rows);
     }
     const uint8_t row_shift = kTransformRowShift[tx_size];
     if (row_shift > 0) {
@@ -1908,8 +1833,7 @@
   }
 
   assert(!is_row);
-  const bool flip_columns = ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
-  if (flip_columns) {
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<8>(src, tx_width);
   }
 
@@ -1919,18 +1843,20 @@
                                            /*transpose=*/false);
   } else {
     // Process 8 1d dct8 columns in parallel per iteration.
-    for (int i = 0; i < tx_width; i += 8) {
+    int i = 0;
+    do {
       Dct8_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
                                               /*transpose=*/false);
-    }
+      i += 8;
+    } while (i < tx_width);
   }
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 8, src, tx_type);
 }
 
 void Dct16TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                               int8_t /*bitdepth*/, void* src_buffer,
-                               int start_x, int start_y, void* dst_frame,
-                               bool is_row, int non_zero_coeff_count) {
+                               void* src_buffer, int start_x, int start_y,
+                               void* dst_frame, bool is_row,
+                               int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -1940,7 +1866,7 @@
     const int num_rows =
         (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
     if (kShouldRound[tx_size]) {
-      ShouldRound<16>(src, num_rows);
+      ApplyRounding<16>(src, num_rows);
     }
 
     if (num_rows <= 4) {
@@ -1948,11 +1874,13 @@
       Dct16_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 16,
                                               /*transpose=*/true);
     } else {
-      for (int i = 0; i < num_rows; i += 8) {
+      int i = 0;
+      do {
         // Process 8 1d dct16 rows in parallel per iteration.
         Dct16_SSE4_1<ButterflyRotation_8, false>(&src[i * 16], &src[i * 16], 16,
                                                  /*transpose=*/true);
-      }
+        i += 8;
+      } while (i < num_rows);
     }
     const uint8_t row_shift = kTransformRowShift[tx_size];
     // row_shift is always non zero here.
@@ -1962,8 +1890,7 @@
   }
 
   assert(!is_row);
-  const bool flip_columns = ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
-  if (flip_columns) {
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<16>(src, tx_width);
   }
 
@@ -1972,19 +1899,21 @@
     Dct16_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
                                             /*transpose=*/false);
   } else {
-    for (int i = 0; i < tx_width; i += 8) {
+    int i = 0;
+    do {
       // Process 8 1d dct16 columns in parallel per iteration.
       Dct16_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
                                                /*transpose=*/false);
-    }
+      i += 8;
+    } while (i < tx_width);
   }
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 16, src, tx_type);
 }
 
 void Dct32TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                               int8_t /*bitdepth*/, void* src_buffer,
-                               int start_x, int start_y, void* dst_frame,
-                               bool is_row, int non_zero_coeff_count) {
+                               void* src_buffer, int start_x, int start_y,
+                               void* dst_frame, bool is_row,
+                               int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -1994,12 +1923,14 @@
     const int num_rows =
         (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
     if (kShouldRound[tx_size]) {
-      ShouldRound<32>(src, num_rows);
+      ApplyRounding<32>(src, num_rows);
     }
     // Process 8 1d dct32 rows in parallel per iteration.
-    for (int i = 0; i < num_rows; i += 8) {
+    int i = 0;
+    do {
       Dct32_SSE4_1(&src[i * 32], &src[i * 32], 32, /*transpose=*/true);
-    }
+      i += 8;
+    } while (i < num_rows);
     const uint8_t row_shift = kTransformRowShift[tx_size];
     // row_shift is always non zero here.
     RowShift<32>(src, num_rows, row_shift);
@@ -2009,16 +1940,18 @@
 
   assert(!is_row);
   // Process 8 1d dct32 columns in parallel per iteration.
-  for (int i = 0; i < tx_width; i += 8) {
+  int i = 0;
+  do {
     Dct32_SSE4_1(&src[i], &src[i], tx_width, /*transpose=*/false);
-  }
+    i += 8;
+  } while (i < tx_width);
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 32, src, tx_type);
 }
 
 void Dct64TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                               int8_t /*bitdepth*/, void* src_buffer,
-                               int start_x, int start_y, void* dst_frame,
-                               bool is_row, int non_zero_coeff_count) {
+                               void* src_buffer, int start_x, int start_y,
+                               void* dst_frame, bool is_row,
+                               int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2028,12 +1961,14 @@
     const int num_rows =
         (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
     if (kShouldRound[tx_size]) {
-      ShouldRound<64>(src, num_rows);
+      ApplyRounding<64>(src, num_rows);
     }
     // Process 8 1d dct64 rows in parallel per iteration.
-    for (int i = 0; i < num_rows; i += 8) {
+    int i = 0;
+    do {
       Dct64_SSE4_1(&src[i * 64], &src[i * 64], 64, /*transpose=*/true);
-    }
+      i += 8;
+    } while (i < num_rows);
     const uint8_t row_shift = kTransformRowShift[tx_size];
     // row_shift is always non zero here.
     RowShift<64>(src, num_rows, row_shift);
@@ -2043,16 +1978,18 @@
 
   assert(!is_row);
   // Process 8 1d dct64 columns in parallel per iteration.
-  for (int i = 0; i < tx_width; i += 8) {
+  int i = 0;
+  do {
     Dct64_SSE4_1(&src[i], &src[i], tx_width, /*transpose=*/false);
-  }
+    i += 8;
+  } while (i < tx_width);
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 64, src, tx_type);
 }
 
 void Adst4TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                               int8_t /*bitdepth*/, void* src_buffer,
-                               int start_x, int start_y, void* dst_frame,
-                               bool is_row, int non_zero_coeff_count) {
+                               void* src_buffer, int start_x, int start_y,
+                               void* dst_frame, bool is_row,
+                               int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2062,14 +1999,16 @@
     const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     const bool should_round = (tx_height == 8);
     if (should_round) {
-      ShouldRound<4>(src, num_rows);
+      ApplyRounding<4>(src, num_rows);
     }
 
     // Process 4 1d adst4 rows in parallel per iteration.
-    for (int i = 0; i < num_rows; i += 4) {
+    int i = 0;
+    do {
       Adst4_SSE4_1<false>(&src[i * 4], &src[i * 4], /*step=*/4,
                           /*transpose=*/true);
-    }
+      i += 4;
+    } while (i < num_rows);
 
     if (tx_height == 16) {
       RowShift<4>(src, num_rows, 1);
@@ -2078,24 +2017,25 @@
   }
 
   assert(!is_row);
-  const bool flip_columns = ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
-  if (flip_columns) {
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<4>(src, tx_width);
   }
 
   // Process 4 1d adst4 columns in parallel per iteration.
-  for (int i = 0; i < tx_width; i += 4) {
+  int i = 0;
+  do {
     Adst4_SSE4_1<false>(&src[i], &src[i], tx_width, /*transpose=*/false);
-  }
+    i += 4;
+  } while (i < tx_width);
 
   StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
                                                    tx_width, 4, src, tx_type);
 }
 
 void Adst8TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                               int8_t /*bitdepth*/, void* src_buffer,
-                               int start_x, int start_y, void* dst_frame,
-                               bool is_row, int non_zero_coeff_count) {
+                               void* src_buffer, int start_x, int start_y,
+                               void* dst_frame, bool is_row,
+                               int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2104,7 +2044,7 @@
   if (is_row) {
     const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     if (kShouldRound[tx_size]) {
-      ShouldRound<8>(src, num_rows);
+      ApplyRounding<8>(src, num_rows);
     }
 
     if (num_rows <= 4) {
@@ -2113,11 +2053,13 @@
                                               /*transpose=*/true);
     } else {
       // Process 8 1d adst8 rows in parallel per iteration.
-      for (int i = 0; i < num_rows; i += 8) {
+      int i = 0;
+      do {
         Adst8_SSE4_1<ButterflyRotation_8, false>(&src[i * 8], &src[i * 8],
                                                  /*step=*/8,
                                                  /*transpose=*/true);
-      }
+        i += 8;
+      } while (i < num_rows);
     }
     const uint8_t row_shift = kTransformRowShift[tx_size];
     if (row_shift > 0) {
@@ -2127,8 +2069,7 @@
   }
 
   assert(!is_row);
-  const bool flip_columns = ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
-  if (flip_columns) {
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<8>(src, tx_width);
   }
 
@@ -2138,19 +2079,21 @@
                                             /*transpose=*/false);
   } else {
     // Process 8 1d adst8 columns in parallel per iteration.
-    for (int i = 0; i < tx_width; i += 8) {
+    int i = 0;
+    do {
       Adst8_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
                                                /*transpose=*/false);
-    }
+      i += 8;
+    } while (i < tx_width);
   }
   StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
                                                    tx_width, 8, src, tx_type);
 }
 
 void Adst16TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                                int8_t /*bitdepth*/, void* src_buffer,
-                                int start_x, int start_y, void* dst_frame,
-                                bool is_row, int non_zero_coeff_count) {
+                                void* src_buffer, int start_x, int start_y,
+                                void* dst_frame, bool is_row,
+                                int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2160,7 +2103,7 @@
     const int num_rows =
         (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
     if (kShouldRound[tx_size]) {
-      ShouldRound<16>(src, num_rows);
+      ApplyRounding<16>(src, num_rows);
     }
 
     if (num_rows <= 4) {
@@ -2168,11 +2111,13 @@
       Adst16_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 16,
                                                /*transpose=*/true);
     } else {
-      for (int i = 0; i < num_rows; i += 8) {
+      int i = 0;
+      do {
         // Process 8 1d adst16 rows in parallel per iteration.
         Adst16_SSE4_1<ButterflyRotation_8, false>(&src[i * 16], &src[i * 16],
                                                   16, /*transpose=*/true);
-      }
+        i += 8;
+      } while (i < num_rows);
     }
     const uint8_t row_shift = kTransformRowShift[tx_size];
     // row_shift is always non zero here.
@@ -2182,8 +2127,7 @@
   }
 
   assert(!is_row);
-  const bool flip_columns = ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
-  if (flip_columns) {
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<16>(src, tx_width);
   }
 
@@ -2192,20 +2136,22 @@
     Adst16_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
                                              /*transpose=*/false);
   } else {
-    for (int i = 0; i < tx_width; i += 8) {
+    int i = 0;
+    do {
       // Process 8 1d adst16 columns in parallel per iteration.
       Adst16_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
                                                 /*transpose=*/false);
-    }
+      i += 8;
+    } while (i < tx_width);
   }
   StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
                                                    tx_width, 16, src, tx_type);
 }
 
 void Identity4TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                                   int8_t /*bitdepth*/, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame,
-                                   bool is_row, int non_zero_coeff_count) {
+                                   void* src_buffer, int start_x, int start_y,
+                                   void* dst_frame, bool is_row,
+                                   int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2222,16 +2168,20 @@
     const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     const bool should_round = (tx_height == 8);
     if (should_round) {
-      ShouldRound<4>(src, num_rows);
+      ApplyRounding<4>(src, num_rows);
     }
     if (tx_height < 16) {
-      for (int i = 0; i < num_rows; i += 4) {
+      int i = 0;
+      do {
         Identity4_SSE4_1<false>(&src[i * 4], &src[i * 4], /*step=*/4);
-      }
+        i += 4;
+      } while (i < num_rows);
     } else {
-      for (int i = 0; i < num_rows; i += 4) {
+      int i = 0;
+      do {
         Identity4_SSE4_1<true>(&src[i * 4], &src[i * 4], /*step=*/4);
-      }
+        i += 4;
+      } while (i < num_rows);
     }
     return;
   }
@@ -2243,8 +2193,7 @@
     return;
   }
 
-  const bool flip_columns = ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
-  if (flip_columns) {
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<4>(src, tx_width);
   }
 
@@ -2253,9 +2202,9 @@
 }
 
 void Identity8TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
-                                   int8_t /*bitdepth*/, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame,
-                                   bool is_row, int non_zero_coeff_count) {
+                                   void* src_buffer, int start_x, int start_y,
+                                   void* dst_frame, bool is_row,
+                                   int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2270,7 +2219,7 @@
     }
     const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     if (kShouldRound[tx_size]) {
-      ShouldRound<8>(src, num_rows);
+      ApplyRounding<8>(src, num_rows);
     }
 
     // When combining the identity8 multiplier with the row shift, the
@@ -2280,23 +2229,26 @@
       return;
     }
     if (tx_height == 32) {
-      for (int i = 0; i < num_rows; i += 4) {
+      int i = 0;
+      do {
         Identity8Row32_SSE4_1(&src[i * 8], &src[i * 8], /*step=*/8);
-      }
+        i += 4;
+      } while (i < num_rows);
       return;
     }
 
     // Process kTransformSize8x4
     assert(tx_size == kTransformSize8x4);
-    for (int i = 0; i < num_rows; i += 4) {
+    int i = 0;
+    do {
       Identity8Row4_SSE4_1(&src[i * 8], &src[i * 8], /*step=*/8);
-    }
+      i += 4;
+    } while (i < num_rows);
     return;
   }
 
   assert(!is_row);
-  const bool flip_columns = ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
-  if (flip_columns) {
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<8>(src, tx_width);
   }
 
@@ -2305,10 +2257,9 @@
 }
 
 void Identity16TransformLoop_SSE4_1(TransformType tx_type,
-                                    TransformSize tx_size, int8_t /*bitdepth*/,
-                                    void* src_buffer, int start_x, int start_y,
-                                    void* dst_frame, bool is_row,
-                                    int non_zero_coeff_count) {
+                                    TransformSize tx_size, void* src_buffer,
+                                    int start_x, int start_y, void* dst_frame,
+                                    bool is_row, int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2318,18 +2269,19 @@
     const int num_rows =
         (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
     if (kShouldRound[tx_size]) {
-      ShouldRound<16>(src, num_rows);
+      ApplyRounding<16>(src, num_rows);
     }
-    for (int i = 0; i < num_rows; i += 4) {
+    int i = 0;
+    do {
       Identity16Row_SSE4_1(&src[i * 16], &src[i * 16], /*step=*/16,
                            kTransformRowShift[tx_size]);
-    }
+      i += 4;
+    } while (i < num_rows);
     return;
   }
 
   assert(!is_row);
-  const bool flip_columns = ((1U << tx_type) & kTransformFlipColumnsMask) != 0;
-  if (flip_columns) {
+  if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<16>(src, tx_width);
   }
   Identity16ColumnStoreToFrame_SSE4_1(frame, start_x, start_y, tx_width,
@@ -2337,10 +2289,9 @@
 }
 
 void Identity32TransformLoop_SSE4_1(TransformType /*tx_type*/,
-                                    TransformSize tx_size, int8_t /*bitdepth*/,
-                                    void* src_buffer, int start_x, int start_y,
-                                    void* dst_frame, bool is_row,
-                                    int non_zero_coeff_count) {
+                                    TransformSize tx_size, void* src_buffer,
+                                    int start_x, int start_y, void* dst_frame,
+                                    bool is_row, int non_zero_coeff_count) {
   auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2358,10 +2309,12 @@
 
     // Process kTransformSize32x16
     assert(tx_size == kTransformSize32x16);
-    ShouldRound<32>(src, num_rows);
-    for (int i = 0; i < num_rows; i += 4) {
+    ApplyRounding<32>(src, num_rows);
+    int i = 0;
+    do {
       Identity32Row16_SSE4_1(&src[i * 32], &src[i * 32], /*step=*/32);
-    }
+      i += 4;
+    } while (i < num_rows);
     return;
   }
 
@@ -2370,6 +2323,26 @@
                                /*tx_height=*/32, src);
 }
 
+void Wht4TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
+                              void* src_buffer, int start_x, int start_y,
+                              void* dst_frame, bool is_row,
+                              int non_zero_coeff_count) {
+  assert(tx_type == kTransformTypeDctDct);
+  assert(tx_size == kTransformSize4x4);
+  static_cast<void>(tx_type);
+  static_cast<void>(tx_size);
+  if (is_row) {
+    // Do both row and column transforms in the column-transform pass.
+    return;
+  }
+
+  assert(!is_row);
+  // Process 4 1d wht4 rows and columns in parallel.
+  const auto* src = static_cast<int16_t*>(src_buffer);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
+  Wht4_SSE4_1(frame, start_x, start_y, src, non_zero_coeff_count);
+}
+
 //------------------------------------------------------------------------------
 
 template <typename Residual, typename Pixel>
@@ -2403,6 +2376,10 @@
       Identity16TransformLoop_SSE4_1;
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformIdentity] =
       Identity32TransformLoop_SSE4_1;
+
+  // Maximum transform size for Wht is 4.
+  dsp->inverse_transforms[k1DTransformSize4][k1DTransformWht] =
+      Wht4TransformLoop_SSE4_1;
 }
 
 void Init8bpp() {
@@ -2459,6 +2436,10 @@
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformIdentity] =
       Identity32TransformLoop_SSE4_1;
 #endif
+#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformWht)
+  dsp->inverse_transforms[k1DTransformSize4][k1DTransformWht] =
+      Wht4TransformLoop_SSE4_1;
+#endif
 #endif
 }
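The for-to-do/while conversions in this file rely on the row count always being at least one (num_rows is either 1 or the transform height, a multiple of 4), so the loop body is guaranteed to run and the up-front bounds check of a for loop can be dropped. A minimal sketch of the pattern, with hypothetical names rather than code from this change:

#include <cstdint>

template <int kStep>
void ProcessRowsSketch(int16_t* src, int num_rows) {
  // Caller guarantees num_rows >= 1, so a do/while is safe and skips the
  // comparison a for loop would perform before the first iteration.
  int i = 0;
  do {
    // ... transform kStep rows starting at &src[i * kStep] ...
    i += kStep;
  } while (i < num_rows);
}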
 
diff --git a/libgav1/src/dsp/x86/inverse_transform_sse4.h b/libgav1/src/dsp/x86/inverse_transform_sse4.h
index dd30533..72fda5a 100644
--- a/libgav1/src/dsp/x86/inverse_transform_sse4.h
+++ b/libgav1/src/dsp/x86/inverse_transform_sse4.h
@@ -3,12 +3,12 @@
 
 #include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
-#include "src/dsp/inverse_transform.h"
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::inverse_transforms. This function is not thread-safe.
+// Initializes Dsp::inverse_transforms; see the defines below for specifics.
+// This function is not thread-safe.
 void InverseTransformInit_SSE4_1();
 
 }  // namespace dsp
@@ -65,5 +65,9 @@
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity
 #define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_DSP_SSE4_1
 #endif
+
+#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_DSP_SSE4_1
+#endif
 #endif  // LIBGAV1_ENABLE_SSE4_1
 #endif  // LIBGAV1_SRC_DSP_X86_INVERSE_TRANSFORM_SSE4_H_
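The new LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht define follows the convention of the existing entries: the header claims the function slot for SSE4.1 unless something else (a higher optimization level or a forced C baseline) already claimed it, and Init8bpp() only installs the SSE4.1 function when that claim stuck. A hypothetical reduction of the mechanism, with invented names since the real macros live in src/dsp/dsp.h and are not part of this change:

#define MY_DSP_C 0
#define MY_DSP_SSE4_1 1

// Header side: claim the slot for SSE4.1 unless a baseline is already set.
#ifndef MyFunc_Baseline
#define MyFunc_Baseline MY_DSP_SSE4_1
#endif

// Init side: register the SSE4.1 implementation only if the claim stuck.
#define MY_DSP_ENABLED_SSE4_1(func) (func##_Baseline == MY_DSP_SSE4_1)

#if MY_DSP_ENABLED_SSE4_1(MyFunc)
// dsp->slot = MyFunc_SSE4_1;
#endif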
diff --git a/libgav1/src/dsp/x86/loop_filter_sse4.cc b/libgav1/src/dsp/x86/loop_filter_sse4.cc
index d68e76d..4c5ab23 100644
--- a/libgav1/src/dsp/x86/loop_filter_sse4.cc
+++ b/libgav1/src/dsp/x86/loop_filter_sse4.cc
@@ -1,4 +1,5 @@
-#include "src/dsp/x86/loop_filter_sse4.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/loop_filter.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
diff --git a/libgav1/src/dsp/x86/loop_filter_sse4.h b/libgav1/src/dsp/x86/loop_filter_sse4.h
index 76f2846..841eeec 100644
--- a/libgav1/src/dsp/x86/loop_filter_sse4.h
+++ b/libgav1/src/dsp/x86/loop_filter_sse4.h
@@ -3,13 +3,12 @@
 
 #include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
-#include "src/dsp/loop_filter.h"
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::loop_filters with sse4 implementations. This function
-// is not thread-safe.
+// Initializes Dsp::loop_filters; see the defines below for specifics. This
+// function is not thread-safe.
 void LoopFilterInit_SSE4_1();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/x86/loop_restoration_sse4.cc b/libgav1/src/dsp/x86/loop_restoration_sse4.cc
index fa48bcc..3a08fc4 100644
--- a/libgav1/src/dsp/x86/loop_restoration_sse4.cc
+++ b/libgav1/src/dsp/x86/loop_restoration_sse4.cc
@@ -1,4 +1,5 @@
-#include "src/dsp/x86/loop_restoration_sse4.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/loop_restoration.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 #include <smmintrin.h>
@@ -9,6 +10,7 @@
 #include <cstring>
 
 #include "src/dsp/common.h"
+#include "src/dsp/constants.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/utils/common.h"
 #include "src/utils/compiler_attributes.h"
@@ -63,10 +65,9 @@
                          ptrdiff_t source_stride, ptrdiff_t dest_stride,
                          int width, int height,
                          RestorationBuffer* const buffer) {
-  const int* const inter_round_bits = buffer->inter_round_bits;
   int8_t filter[kSubPixelTaps];
   const int limit =
-      (1 << (8 + 1 + kWienerFilterBits - inter_round_bits[0])) - 1;
+      (1 << (8 + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1;
   const auto* src = static_cast<const uint8_t*>(source);
   auto* dst = static_cast<uint8_t*>(dest);
   const ptrdiff_t buffer_stride = buffer->wiener_buffer_stride;
@@ -77,7 +78,7 @@
   src -= center_tap * source_stride + center_tap;
 
   const int horizontal_rounding =
-      1 << (8 + kWienerFilterBits - inter_round_bits[0] - 1);
+      1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
   const __m128i v_horizontal_rounding =
       _mm_shufflelo_epi16(_mm_cvtsi32_si128(horizontal_rounding), 0);
   const __m128i v_limit = _mm_shufflelo_epi16(_mm_cvtsi32_si128(limit), 0);
@@ -86,13 +87,16 @@
   __m128i v_k3k2 = _mm_shufflelo_epi16(v_horizontal_filter, 0x55);
   __m128i v_k5k4 = _mm_shufflelo_epi16(v_horizontal_filter, 0xaa);
   __m128i v_k7k6 = _mm_shufflelo_epi16(v_horizontal_filter, 0xff);
-  const __m128i v_round_0 =
-      _mm_shufflelo_epi16(_mm_cvtsi32_si128(1 << (inter_round_bits[0] - 1)), 0);
-  const __m128i v_round_0_shift = _mm_cvtsi32_si128(inter_round_bits[0]);
-  const __m128i v_offset_shift = _mm_cvtsi32_si128(7 - inter_round_bits[0]);
+  const __m128i v_round_0 = _mm_shufflelo_epi16(
+      _mm_cvtsi32_si128(1 << (kInterRoundBitsHorizontal - 1)), 0);
+  const __m128i v_round_0_shift = _mm_cvtsi32_si128(kInterRoundBitsHorizontal);
+  const __m128i v_offset_shift =
+      _mm_cvtsi32_si128(7 - kInterRoundBitsHorizontal);
 
-  for (int y = 0; y < height + kSubPixelTaps - 2; ++y) {
-    for (int x = 0; x < width; x += 4) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       // Run the Wiener filter on four sets of source samples at a time:
       //   src[x + 0] ... src[x + 6]
       //   src[x + 1] ... src[x + 7]
@@ -126,29 +130,30 @@
           _mm_add_epi16(v_rounded_sum0, v_horizontal_rounding);
       // Zero out the even bytes, calculate scaled down offset correction, and
       // add to sum here to prevent signed 16 bit outranging.
-      // (src[3] * 128) >> inter_round_bits[0]
+      // (src[3] * 128) >> kInterRoundBitsHorizontal
       const __m128i v_src_3x128 =
           _mm_sll_epi16(_mm_srli_epi16(v_src_32, 8), v_offset_shift);
       const __m128i v_rounded_sum = _mm_add_epi16(v_rounded_sum1, v_src_3x128);
       const __m128i v_a = _mm_max_epi16(v_rounded_sum, _mm_setzero_si128());
       const __m128i v_b = _mm_min_epi16(v_a, v_limit);
       StoreLo8(&wiener_buffer[x], v_b);
-    }
+      x += 4;
+    } while (x < width);
     src += source_stride;
     wiener_buffer += buffer_stride;
-  }
+  } while (++y < height + kSubPixelTaps - 2);
 
   wiener_buffer = buffer->wiener_buffer;
   // vertical filtering.
   PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical, filter);
 
-  const int vertical_rounding = -(1 << (8 + inter_round_bits[1] - 1));
+  const int vertical_rounding = -(1 << (8 + kInterRoundBitsVertical - 1));
   const __m128i v_vertical_rounding =
       _mm_shuffle_epi32(_mm_cvtsi32_si128(vertical_rounding), 0);
   const __m128i v_offset_correction = _mm_set_epi16(0, 0, 0, 0, 128, 0, 0, 0);
-  const __m128i v_round_1 =
-      _mm_shuffle_epi32(_mm_cvtsi32_si128(1 << (inter_round_bits[1] - 1)), 0);
-  const __m128i v_round_1_shift = _mm_cvtsi32_si128(inter_round_bits[1]);
+  const __m128i v_round_1 = _mm_shuffle_epi32(
+      _mm_cvtsi32_si128(1 << (kInterRoundBitsVertical - 1)), 0);
+  const __m128i v_round_1_shift = _mm_cvtsi32_si128(kInterRoundBitsVertical);
   const __m128i v_vertical_filter0 = _mm_cvtepi8_epi16(LoadLo8(filter));
   const __m128i v_vertical_filter =
       _mm_add_epi16(v_vertical_filter0, v_offset_correction);
@@ -156,8 +161,10 @@
   v_k3k2 = _mm_shuffle_epi32(v_vertical_filter, 0x55);
   v_k5k4 = _mm_shuffle_epi32(v_vertical_filter, 0xaa);
   v_k7k6 = _mm_shuffle_epi32(v_vertical_filter, 0xff);
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; x += 4) {
+  y = 0;
+  do {
+    int x = 0;
+    do {
       const __m128i v_wb_0 = LoadLo8(&wiener_buffer[0 * buffer_stride + x]);
       const __m128i v_wb_1 = LoadLo8(&wiener_buffer[1 * buffer_stride + x]);
       const __m128i v_wb_2 = LoadLo8(&wiener_buffer[2 * buffer_stride + x]);
@@ -182,10 +189,11 @@
       const __m128i v_a = _mm_packs_epi32(v_rounded_sum, v_rounded_sum);
       const __m128i v_b = _mm_packus_epi16(v_a, v_a);
       Store4(&dst[x], v_b);
-    }
+      x += 4;
+    } while (x < width);
     dst += dest_stride;
     wiener_buffer += buffer_stride;
-  }
+  } while (++y < height);
 }
 
 // Section 7.17.3.
@@ -196,7 +204,7 @@
 //   a2 = 1;
 // else
 //   a2 = ((z << kSgrProjSgrBits) + (z >> 1)) / (z + 1);
-constexpr int x_by_xplus1[256] = {
+constexpr int kXByXPlus1[256] = {
     1,   128, 171, 192, 205, 213, 219, 224, 228, 230, 233, 235, 236, 238, 239,
     240, 241, 242, 243, 243, 244, 244, 245, 245, 246, 246, 247, 247, 247, 247,
     248, 248, 248, 248, 249, 249, 249, 249, 249, 250, 250, 250, 250, 250, 250,
@@ -262,10 +270,12 @@
 
   // Calculate intermediate results, including one-pixel border, for example,
   // if unit size is 64x64, we calculate 66x66 pixels.
-  for (int y = -1; y <= height; ++y) {
+  int y = -1;
+  do {
     const uint8_t* top_left = &src[(y - 1) * stride - 2];
     // Calculate the box vertical sums for each x position.
-    for (int vsx = -2; vsx <= width + 1; vsx += 4, top_left += 4) {
+    int vsx = -2;
+    do {
       const __m128i v_box0 = _mm_cvtepu8_epi32(Load4(top_left));
       const __m128i v_box1 = _mm_cvtepu8_epi32(Load4(top_left + stride));
       const __m128i v_box2 = _mm_cvtepu8_epi32(Load4(top_left + stride * 2));
@@ -278,9 +288,12 @@
       const __m128i v_b012 = _mm_add_epi32(v_b01, v_box2);
       StoreUnaligned16(&vertical_sum_of_squares[vsx], v_a012);
       StoreUnaligned16(&vertical_sums[vsx], v_b012);
-    }
+      top_left += 4;
+      vsx += 4;
+    } while (vsx <= width + 1);
 
-    for (int x = -1; x <= width; x += 4) {
+    int x = -1;
+    do {
       const __m128i v_a =
           HorizontalAddVerticalSumsRadius1(&vertical_sum_of_squares[x - 1]);
       const __m128i v_b =
@@ -297,11 +310,10 @@
       const __m128i v_z = _mm_min_epi32(
           v_255, RightShiftWithRounding_U32(_mm_mullo_epi32(v_p, v_s),
                                             kSgrProjScaleBits));
-      const __m128i v_a2 =
-          _mm_set_epi32(x_by_xplus1[_mm_extract_epi32(v_z, 3)],
-                        x_by_xplus1[_mm_extract_epi32(v_z, 2)],
-                        x_by_xplus1[_mm_extract_epi32(v_z, 1)],
-                        x_by_xplus1[_mm_extract_epi32(v_z, 0)]);
+      const __m128i v_a2 = _mm_set_epi32(kXByXPlus1[_mm_extract_epi32(v_z, 3)],
+                                         kXByXPlus1[_mm_extract_epi32(v_z, 2)],
+                                         kXByXPlus1[_mm_extract_epi32(v_z, 1)],
+                                         kXByXPlus1[_mm_extract_epi32(v_z, 0)]);
       // -----------------------
       // calc b2 and store
       // -----------------------
@@ -312,10 +324,11 @@
       StoreUnaligned16(
           &intermediate_result[1][x],
           RightShiftWithRounding_U32(v_b2, kSgrProjReciprocalBits));
-    }
+      x += 4;
+    } while (x <= width);
     intermediate_result[0] += array_stride;
     intermediate_result[1] += array_stride;
-  }
+  } while (++y <= height);
 }
 
 void BoxFilterPreProcessRadius2_SSE4_1(
@@ -332,10 +345,12 @@
 
   // Calculate intermediate results, including one-pixel border, for example,
   // if unit size is 64x64, we calculate 66x66 pixels.
-  for (int y = -1; y <= height; y += 2) {
+  int y = -1;
+  do {
     // Calculate the box vertical sums for each x position.
     const uint8_t* top_left = &src[(y - 2) * stride - 3];
-    for (int vsx = -3; vsx <= width + 2; vsx += 4, top_left += 4) {
+    int vsx = -3;
+    do {
       const __m128i v_box0 = _mm_cvtepu8_epi32(Load4(top_left));
       const __m128i v_box1 = _mm_cvtepu8_epi32(Load4(top_left + stride));
       const __m128i v_box2 = _mm_cvtepu8_epi32(Load4(top_left + stride * 2));
@@ -356,9 +371,12 @@
       const __m128i v_b01234 = _mm_add_epi32(v_b0123, v_box4);
       StoreUnaligned16(&vertical_sum_of_squares[vsx], v_a01234);
       StoreUnaligned16(&vertical_sums[vsx], v_b01234);
-    }
+      top_left += 4;
+      vsx += 4;
+    } while (vsx <= width + 2);
 
-    for (int x = -1; x <= width; x += 4) {
+    int x = -1;
+    do {
       const __m128i v_a =
           HorizontalAddVerticalSumsRadius2(&vertical_sum_of_squares[x - 2]);
       const __m128i v_b =
@@ -375,11 +393,10 @@
       const __m128i v_z = _mm_min_epi32(
           v_255, RightShiftWithRounding_U32(_mm_mullo_epi32(v_p, v_s),
                                             kSgrProjScaleBits));
-      const __m128i v_a2 =
-          _mm_set_epi32(x_by_xplus1[_mm_extract_epi32(v_z, 3)],
-                        x_by_xplus1[_mm_extract_epi32(v_z, 2)],
-                        x_by_xplus1[_mm_extract_epi32(v_z, 1)],
-                        x_by_xplus1[_mm_extract_epi32(v_z, 0)]);
+      const __m128i v_a2 = _mm_set_epi32(kXByXPlus1[_mm_extract_epi32(v_z, 3)],
+                                         kXByXPlus1[_mm_extract_epi32(v_z, 2)],
+                                         kXByXPlus1[_mm_extract_epi32(v_z, 1)],
+                                         kXByXPlus1[_mm_extract_epi32(v_z, 0)]);
       // -----------------------
       // calc b2 and store
       // -----------------------
@@ -390,10 +407,12 @@
       StoreUnaligned16(
           &intermediate_result[1][x],
           RightShiftWithRounding_U32(v_b2, kSgrProjReciprocalBits));
-    }
+      x += 4;
+    } while (x <= width);
     intermediate_result[0] += 2 * array_stride;
     intermediate_result[1] += 2 * array_stride;
-  }
+    y += 2;
+  } while (y <= height);
 }
 
 void BoxFilterPreProcess_SSE4_1(const RestorationUnitInfo& restoration_info,
@@ -543,7 +562,8 @@
         kRestorationBorder * intermediate_stride + kRestorationBorder;
 
     if (pass == 0) {
-      for (int y = 0; y < height; ++y) {
+      int y = 0;
+      do {
         const int shift = ((y & 1) != 0) ? 4 : 5;
         uint32_t* const array_start[2] = {
             buffer->box_filter_process_intermediate[0] +
@@ -554,7 +574,8 @@
             array_start[0] - intermediate_stride,
             array_start[1] - intermediate_stride};
         if ((y & 1) == 0) {  // even row
-          for (int x = 0; x < width; x += 4) {
+          int x = 0;
+          do {
             // 5 6 5
             // 0 0 0
             // 5 6 5
@@ -569,9 +590,11 @@
                 v_v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
 
             StoreUnaligned16(&filtered_output[x], v_filtered);
-          }
+            x += 4;
+          } while (x < width);
         } else {
-          for (int x = 0; x < width; x += 4) {
+          int x = 0;
+          do {
             // 0 0 0
             // 5 6 5
             // 0 0 0
@@ -586,13 +609,15 @@
                 v_v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
 
             StoreUnaligned16(&filtered_output[x], v_filtered);
-          }
+            x += 4;
+          } while (x < width);
         }
         src_ptr += stride;
         filtered_output += filtered_output_stride;
-      }
+      } while (++y < height);
     } else {
-      for (int y = 0; y < height; ++y) {
+      int y = 0;
+      do {
         const int shift = 5;
         uint32_t* const array_start[2] = {
             buffer->box_filter_process_intermediate[0] +
@@ -602,7 +627,8 @@
         uint32_t* intermediate_result2[2] = {
             array_start[0] - intermediate_stride,
             array_start[1] - intermediate_stride};
-        for (int x = 0; x < width; x += 4) {
+        int x = 0;
+        do {
           const __m128i v_A = Process3x3Block_343(&intermediate_result2[0][x],
                                                   intermediate_stride);
           const __m128i v_B = Process3x3Block_343(&intermediate_result2[1][x],
@@ -614,10 +640,11 @@
               v_v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
 
           StoreUnaligned16(&filtered_output[x], v_filtered);
-        }
+          x += 4;
+        } while (x < width);
         src_ptr += stride;
         filtered_output += filtered_output_stride;
-      }
+      } while (++y < height);
     }
   }
 }
@@ -652,8 +679,10 @@
   const __m128i v_r0_mask = _mm_cmpeq_epi32(v_r0, zero);
   const __m128i v_r1_mask = _mm_cmpeq_epi32(v_r1, zero);
 
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; x += 4) {
+  int y = 0;
+  do {
+    int x = 0;
+    do {
       const __m128i v_src = _mm_cvtepu8_epi32(Load4(src + x));
       const __m128i v_u = _mm_slli_epi32(v_src, kSgrProjRestoreBits);
       const __m128i v_v_a = _mm_mullo_epi32(v_w1, v_u);
@@ -674,13 +703,14 @@
       v_s = _mm_packs_epi32(v_s, v_s);
       v_s = _mm_packus_epi16(v_s, v_s);
       Store4(&dst[x], v_s);
-    }
+      x += 4;
+    } while (x < width);
 
     src += source_stride;
     dst += dest_stride;
     box_filter_process_output[0] += array_stride;
     box_filter_process_output[1] += array_stride;
-  }
+  } while (++y < height);
 }
 
 void Init8bpp() {
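The renamed kXByXPlus1 table is generated by the closed form quoted in the comment above it. A standalone sketch that reproduces the leading entries, assuming kSgrProjSgrBits == 8 (an assumption, but the printed values 1, 128, 171, 192, 205, ... agree with the table):

#include <cstdio>

int main() {
  for (int z = 0; z < 256; ++z) {
    // z == 0 is special-cased to 1, per the comment preceding the table.
    const int a2 = (z == 0) ? 1 : ((z << 8) + (z >> 1)) / (z + 1);
    std::printf("%d,%c", a2, (z % 15 == 14) ? '\n' : ' ');
  }
  return 0;
}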
diff --git a/libgav1/src/dsp/x86/loop_restoration_sse4.h b/libgav1/src/dsp/x86/loop_restoration_sse4.h
index 522950f..2a91938 100644
--- a/libgav1/src/dsp/x86/loop_restoration_sse4.h
+++ b/libgav1/src/dsp/x86/loop_restoration_sse4.h
@@ -3,13 +3,12 @@
 
 #include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
-#include "src/dsp/loop_restoration.h"
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::loop_restorations with sse4 implementations. This function
-// is not thread-safe.
+// Initializes Dsp::loop_restorations; see the defines below for specifics.
+// This function is not thread-safe.
 void LoopRestorationInit_SSE4_1();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/x86/obmc_sse4.cc b/libgav1/src/dsp/x86/obmc_sse4.cc
new file mode 100644
index 0000000..33aa8e4
--- /dev/null
+++ b/libgav1/src/dsp/x86/obmc_sse4.cc
@@ -0,0 +1,313 @@
+#include "src/dsp/dsp.h"
+#include "src/dsp/obmc.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <xmmintrin.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+#include "src/dsp/obmc.inc"
+
+inline void OverlapBlendFromLeft2xH_SSE4_1(
+    uint8_t* const prediction, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+  const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040);
+  const __m128i mask_val = _mm_shufflelo_epi16(Load4(kObmcMask), 0);
+  // 64 - mask
+  const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
+  const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
+  int y = height;
+  do {
+    const __m128i pred_val = Load2x2(pred, pred + prediction_stride);
+    const __m128i obmc_pred_val =
+        Load2x2(obmc_pred, obmc_pred + obmc_prediction_stride);
+
+    const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
+    const __m128i result =
+        RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6);
+    const __m128i packed_result = _mm_packus_epi16(result, result);
+    Store2(pred, packed_result);
+    pred += prediction_stride;
+    const int16_t second_row_result = _mm_extract_epi16(packed_result, 1);
+    memcpy(pred, &second_row_result, sizeof(second_row_result));
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride << 1;
+    y -= 2;
+  } while (y != 0);
+}
+
+inline void OverlapBlendFromLeft4xH_SSE4_1(
+    uint8_t* const prediction, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+  const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040);
+  const __m128i mask_val = Load4(kObmcMask + 2);
+  // 64 - mask
+  const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
+  // Duplicate first half of vector.
+  const __m128i masks =
+      _mm_shuffle_epi32(_mm_unpacklo_epi8(mask_val, obmc_mask_val), 0x44);
+  int y = height;
+  do {
+    const __m128i pred_val0 = Load4(pred);
+    const __m128i obmc_pred_val0 = Load4(obmc_pred);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    // Place the second row of each source in the second four bytes.
+    const __m128i pred_val =
+        _mm_alignr_epi8(Load4(pred), _mm_slli_si128(pred_val0, 12), 12);
+    const __m128i obmc_pred_val = _mm_alignr_epi8(
+        Load4(obmc_pred), _mm_slli_si128(obmc_pred_val0, 12), 12);
+    const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
+    const __m128i result =
+        RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6);
+    const __m128i packed_result = _mm_packus_epi16(result, result);
+    Store4(pred - prediction_stride, packed_result);
+    const int second_row_result = _mm_extract_epi32(packed_result, 1);
+    memcpy(pred, &second_row_result, sizeof(second_row_result));
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    y -= 2;
+  } while (y != 0);
+}
+
+inline void OverlapBlendFromLeft8xH_SSE4_1(
+    uint8_t* const prediction, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+  const __m128i mask_inverter = _mm_set1_epi8(64);
+  const __m128i mask_val = LoadLo8(kObmcMask + 6);
+  // 64 - mask
+  const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
+  const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
+  int y = height;
+  do {
+    const __m128i pred_val = LoadLo8(pred);
+    const __m128i obmc_pred_val = LoadLo8(obmc_pred);
+    const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
+    const __m128i result =
+        RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6);
+
+    StoreLo8(pred, _mm_packus_epi16(result, result));
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+  } while (--y != 0);
+}
+
+void OverlapBlendFromLeft_SSE4_1(void* const prediction,
+                                 const ptrdiff_t prediction_stride,
+                                 const int width, const int height,
+                                 const void* const obmc_prediction,
+                                 const ptrdiff_t obmc_prediction_stride) {
+  auto* pred = static_cast<uint8_t*>(prediction);
+  const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+
+  if (width == 2) {
+    OverlapBlendFromLeft2xH_SSE4_1(pred, prediction_stride, height, obmc_pred,
+                                   obmc_prediction_stride);
+    return;
+  }
+  if (width == 4) {
+    OverlapBlendFromLeft4xH_SSE4_1(pred, prediction_stride, height, obmc_pred,
+                                   obmc_prediction_stride);
+    return;
+  }
+  if (width == 8) {
+    OverlapBlendFromLeft8xH_SSE4_1(pred, prediction_stride, height, obmc_pred,
+                                   obmc_prediction_stride);
+    return;
+  }
+  const __m128i mask_inverter = _mm_set1_epi8(64);
+  const uint8_t* mask = kObmcMask + width - 2;
+  int x = 0;
+  do {
+    pred = static_cast<uint8_t*>(prediction) + x;
+    obmc_pred = static_cast<const uint8_t*>(obmc_prediction) + x;
+    const __m128i mask_val = LoadUnaligned16(mask + x);
+    // 64 - mask
+    const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
+    const __m128i masks_lo = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
+    const __m128i masks_hi = _mm_unpackhi_epi8(mask_val, obmc_mask_val);
+
+    int y = 0;
+    do {
+      const __m128i pred_val = LoadUnaligned16(pred);
+      const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred);
+      const __m128i terms_lo = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
+      const __m128i result_lo =
+          RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_lo, masks_lo), 6);
+      const __m128i terms_hi = _mm_unpackhi_epi8(pred_val, obmc_pred_val);
+      const __m128i result_hi =
+          RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_hi, masks_hi), 6);
+      StoreUnaligned16(pred, _mm_packus_epi16(result_lo, result_hi));
+
+      pred += prediction_stride;
+      obmc_pred += obmc_prediction_stride;
+    } while (++y < height);
+    x += 16;
+  } while (x < width);
+}
+
+inline void OverlapBlendFromTop4xH_SSE4_1(
+    uint8_t* const prediction, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+  const __m128i mask_inverter = _mm_set1_epi16(64);
+  const __m128i mask_shuffler = _mm_set_epi32(0x01010101, 0x01010101, 0, 0);
+  const __m128i mask_preinverter = _mm_set1_epi16(-256 | 1);
+
+  const uint8_t* mask = kObmcMask + height - 2;
+  const int compute_height = height - (height >> 2);
+  int y = 0;
+  do {
+    // First mask in the first half, second mask in the second half.
+    const __m128i mask_val = _mm_shuffle_epi8(
+        _mm_cvtsi32_si128(*reinterpret_cast<const uint16_t*>(mask + y)),
+        mask_shuffler);
+    const __m128i masks =
+        _mm_sub_epi8(mask_inverter, _mm_sign_epi8(mask_val, mask_preinverter));
+    const __m128i pred_val0 = Load4(pred);
+
+    const __m128i obmc_pred_val0 = Load4(obmc_pred);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    const __m128i pred_val =
+        _mm_alignr_epi8(Load4(pred), _mm_slli_si128(pred_val0, 12), 12);
+    const __m128i obmc_pred_val = _mm_alignr_epi8(
+        Load4(obmc_pred), _mm_slli_si128(obmc_pred_val0, 12), 12);
+    const __m128i terms = _mm_unpacklo_epi8(obmc_pred_val, pred_val);
+    const __m128i result =
+        RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6);
+
+    const __m128i packed_result = _mm_packus_epi16(result, result);
+    Store4(pred - prediction_stride, packed_result);
+    Store4(pred, _mm_srli_si128(packed_result, 4));
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    y += 2;
+  } while (y < compute_height);
+}
+
+inline void OverlapBlendFromTop8xH_SSE4_1(
+    uint8_t* const prediction, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  uint8_t* pred = prediction;
+  const uint8_t* obmc_pred = obmc_prediction;
+  const uint8_t* mask = kObmcMask + height - 2;
+  const __m128i mask_inverter = _mm_set1_epi8(64);
+  const int compute_height = height - (height >> 2);
+  int y = compute_height;
+  do {
+    const __m128i mask_val = _mm_set1_epi8(mask[compute_height - y]);
+    // 64 - mask
+    const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
+    const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
+    const __m128i pred_val = LoadLo8(pred);
+    const __m128i obmc_pred_val = LoadLo8(obmc_pred);
+    const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
+    const __m128i result =
+        RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6);
+
+    StoreLo8(pred, _mm_packus_epi16(result, result));
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+  } while (--y != 0);
+}
+
+void OverlapBlendFromTop_SSE4_1(void* const prediction,
+                                const ptrdiff_t prediction_stride,
+                                const int width, const int height,
+                                const void* const obmc_prediction,
+                                const ptrdiff_t obmc_prediction_stride) {
+  auto* pred = static_cast<uint8_t*>(prediction);
+  const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+
+  if (width <= 4) {
+    OverlapBlendFromTop4xH_SSE4_1(pred, prediction_stride, height, obmc_pred,
+                                  obmc_prediction_stride);
+    return;
+  }
+  if (width == 8) {
+    OverlapBlendFromTop8xH_SSE4_1(pred, prediction_stride, height, obmc_pred,
+                                  obmc_prediction_stride);
+    return;
+  }
+
+  // Stop when mask value becomes 64.
+  const int compute_height = height - (height >> 2);
+  const __m128i mask_inverter = _mm_set1_epi8(64);
+  int y = 0;
+  const uint8_t* mask = kObmcMask + height - 2;
+  do {
+    const __m128i mask_val = _mm_set1_epi8(mask[y]);
+    // 64 - mask
+    const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
+    const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
+    int x = 0;
+    do {
+      const __m128i pred_val = LoadUnaligned16(pred + x);
+      const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred + x);
+      const __m128i terms_lo = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
+      const __m128i result_lo =
+          RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_lo, masks), 6);
+      const __m128i terms_hi = _mm_unpackhi_epi8(pred_val, obmc_pred_val);
+      const __m128i result_hi =
+          RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_hi, masks), 6);
+      StoreUnaligned16(pred + x, _mm_packus_epi16(result_lo, result_hi));
+      x += 16;
+    } while (x < width);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+  } while (++y < compute_height);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+#if DSP_ENABLED_8BPP_SSE4_1(ObmcVertical)
+  dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_SSE4_1;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(ObmcHorizontal)
+  dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_SSE4_1;
+#endif
+}
+
+}  // namespace
+
+void ObmcInit_SSE4_1() { Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !LIBGAV1_ENABLE_SSE4_1
+
+namespace libgav1 {
+namespace dsp {
+
+void ObmcInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
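All of the OBMC kernels in this new file compute the same 6-bit weighted average of the existing prediction and the OBMC prediction; the SIMD work is in loading the masks and pixels into the interleaved layout that _mm_maddubs_epi16 expects. A scalar sketch of the from-left case for reference (the function name is invented; the mask pointer corresponds to kObmcMask + width - 2 as in OverlapBlendFromLeft_SSE4_1 above):

#include <cstddef>
#include <cstdint>

void OverlapBlendFromLeftScalarSketch(uint8_t* pred, ptrdiff_t pred_stride,
                                      int width, int height,
                                      const uint8_t* obmc_pred,
                                      ptrdiff_t obmc_stride,
                                      const uint8_t* mask) {
  for (int y = 0; y < height; ++y) {
    for (int x = 0; x < width; ++x) {
      const int m = mask[x];  // weight applied to the existing prediction
      // _mm_maddubs_epi16 + RightShiftWithRounding_U16(..., 6) vectorize this.
      pred[x] = static_cast<uint8_t>(
          (m * pred[x] + (64 - m) * obmc_pred[x] + 32) >> 6);
    }
    pred += pred_stride;
    obmc_pred += obmc_stride;
  }
}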
diff --git a/libgav1/src/dsp/x86/obmc_sse4.h b/libgav1/src/dsp/x86/obmc_sse4.h
new file mode 100644
index 0000000..aa00124
--- /dev/null
+++ b/libgav1/src/dsp/x86/obmc_sse4.h
@@ -0,0 +1,27 @@
+#ifndef LIBGAV1_SRC_DSP_X86_OBMC_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_OBMC_SSE4_H_
+
+#include "src/dsp/cpu.h"
+#include "src/dsp/dsp.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::obmc_blend[]. This function is not thread-safe.
+void ObmcInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+// If sse4 is enabled and the baseline isn't already set by a higher level of
+// optimization, signal that the sse4 implementation should be used.
+#if LIBGAV1_ENABLE_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_ObmcVertical
+#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_DSP_SSE4_1
+#endif
+#ifndef LIBGAV1_Dsp8bpp_ObmcHorizontal
+#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_DSP_SSE4_1
+#endif
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_OBMC_SSE4_H_
diff --git a/libgav1/src/dsp/x86/transpose_sse4.h b/libgav1/src/dsp/x86/transpose_sse4.h
index 0b750cd..b222266 100644
--- a/libgav1/src/dsp/x86/transpose_sse4.h
+++ b/libgav1/src/dsp/x86/transpose_sse4.h
@@ -2,6 +2,7 @@
 #define LIBGAV1_SRC_DSP_X86_TRANSPOSE_SSE4_H_
 
 #include "src/dsp/dsp.h"
+#include "src/utils/compiler_attributes.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 #include <emmintrin.h>
@@ -9,7 +10,7 @@
 namespace libgav1 {
 namespace dsp {
 
-inline __m128i Transpose4x4_U8(const __m128i* const in) {
+LIBGAV1_ALWAYS_INLINE __m128i Transpose4x4_U8(const __m128i* const in) {
   // Unpack 16 bit elements. Goes from:
   // in[0]: 00 01 02 03
   // in[1]: 10 11 12 13
@@ -26,7 +27,125 @@
   return _mm_unpacklo_epi16(a0, a1);
 }
 
-inline void Transpose8x8_U16(const __m128i* const in, __m128i* const out) {
+LIBGAV1_ALWAYS_INLINE void Transpose4x4_U16(const __m128i* in, __m128i* out) {
+  // Unpack 16 bit elements. Goes from:
+  // in[0]: 00 01 02 03  XX XX XX XX
+  // in[1]: 10 11 12 13  XX XX XX XX
+  // in[2]: 20 21 22 23  XX XX XX XX
+  // in[3]: 30 31 32 33  XX XX XX XX
+  // to:
+  // ba:    00 10 01 11  02 12 03 13
+  // dc:    20 30 21 31  22 32 23 33
+  const __m128i ba = _mm_unpacklo_epi16(in[0], in[1]);
+  const __m128i dc = _mm_unpacklo_epi16(in[2], in[3]);
+  // Unpack 32 bit elements resulting in:
+  // dcba_lo: 00 10 20 30  01 11 21 31
+  // dcba_hi: 02 12 22 32  03 13 23 33
+  const __m128i dcba_lo = _mm_unpacklo_epi32(ba, dc);
+  const __m128i dcba_hi = _mm_unpackhi_epi32(ba, dc);
+  // Assign or shift right by 8 bytes resulting in:
+  // out[0]: 00 10 20 30  01 11 21 31
+  // out[1]: 01 11 21 31  XX XX XX XX
+  // out[2]: 02 12 22 32  03 13 23 33
+  // out[3]: 03 13 23 33  XX XX XX XX
+  out[0] = dcba_lo;
+  out[1] = _mm_srli_si128(dcba_lo, 8);
+  out[2] = dcba_hi;
+  out[3] = _mm_srli_si128(dcba_hi, 8);
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4_U16(const __m128i* in,
+                                                 __m128i* out) {
+  // Unpack 16 bit elements. Goes from:
+  // in[0]: 00 01 02 03  XX XX XX XX
+  // in[1]: 10 11 12 13  XX XX XX XX
+  // in[2]: 20 21 22 23  XX XX XX XX
+  // in[3]: 30 31 32 33  XX XX XX XX
+  // in[4]: 40 41 42 43  XX XX XX XX
+  // in[5]: 50 51 52 53  XX XX XX XX
+  // in[6]: 60 61 62 63  XX XX XX XX
+  // in[7]: 70 71 72 73  XX XX XX XX
+  // to:
+  // a0:    00 10 01 11  02 12 03 13
+  // a1:    20 30 21 31  22 32 23 33
+  // a2:    40 50 41 51  42 52 43 53
+  // a3:    60 70 61 71  62 72 63 73
+  const __m128i a0 = _mm_unpacklo_epi16(in[0], in[1]);
+  const __m128i a1 = _mm_unpacklo_epi16(in[2], in[3]);
+  const __m128i a2 = _mm_unpacklo_epi16(in[4], in[5]);
+  const __m128i a3 = _mm_unpacklo_epi16(in[6], in[7]);
+
+  // Unpack 32 bit elements resulting in:
+  // b0: 00 10 20 30  01 11 21 31
+  // b1: 40 50 60 70  41 51 61 71
+  // b2: 02 12 22 32  03 13 23 33
+  // b3: 42 52 62 72  43 53 63 73
+  const __m128i b0 = _mm_unpacklo_epi32(a0, a1);
+  const __m128i b1 = _mm_unpacklo_epi32(a2, a3);
+  const __m128i b2 = _mm_unpackhi_epi32(a0, a1);
+  const __m128i b3 = _mm_unpackhi_epi32(a2, a3);
+
+  // Unpack 64 bit elements resulting in:
+  // out[0]: 00 10 20 30  40 50 60 70
+  // out[1]: 01 11 21 31  41 51 61 71
+  // out[2]: 02 12 22 32  42 52 62 72
+  // out[3]: 03 13 23 33  43 53 63 73
+  out[0] = _mm_unpacklo_epi64(b0, b1);
+  out[1] = _mm_unpackhi_epi64(b0, b1);
+  out[2] = _mm_unpacklo_epi64(b2, b3);
+  out[3] = _mm_unpackhi_epi64(b2, b3);
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose8x4To4x8_U16(const __m128i* in,
+                                                 __m128i* out) {
+  // Unpack 16 bit elements. Goes from:
+  // in[0]: 00 01 02 03  04 05 06 07
+  // in[1]: 10 11 12 13  14 15 16 17
+  // in[2]: 20 21 22 23  24 25 26 27
+  // in[3]: 30 31 32 33  34 35 36 37
+
+  // to:
+  // a0:    00 10 01 11  02 12 03 13
+  // a1:    20 30 21 31  22 32 23 33
+  // a4:    04 14 05 15  06 16 07 17
+  // a5:    24 34 25 35  26 36 27 37
+  const __m128i a0 = _mm_unpacklo_epi16(in[0], in[1]);
+  const __m128i a1 = _mm_unpacklo_epi16(in[2], in[3]);
+  const __m128i a4 = _mm_unpackhi_epi16(in[0], in[1]);
+  const __m128i a5 = _mm_unpackhi_epi16(in[2], in[3]);
+
+  // Unpack 32 bit elements resulting in:
+  // b0: 00 10 20 30  01 11 21 31
+  // b2: 04 14 24 34  05 15 25 35
+  // b4: 02 12 22 32  03 13 23 33
+  // b6: 06 16 26 36  07 17 27 37
+  const __m128i b0 = _mm_unpacklo_epi32(a0, a1);
+  const __m128i b2 = _mm_unpacklo_epi32(a4, a5);
+  const __m128i b4 = _mm_unpackhi_epi32(a0, a1);
+  const __m128i b6 = _mm_unpackhi_epi32(a4, a5);
+
+  // Unpack 64 bit elements resulting in:
+  // out[0]: 00 10 20 30  XX XX XX XX
+  // out[1]: 01 11 21 31  XX XX XX XX
+  // out[2]: 02 12 22 32  XX XX XX XX
+  // out[3]: 03 13 23 33  XX XX XX XX
+  // out[4]: 04 14 24 34  XX XX XX XX
+  // out[5]: 05 15 25 35  XX XX XX XX
+  // out[6]: 06 16 26 36  XX XX XX XX
+  // out[7]: 07 17 27 37  XX XX XX XX
+  const __m128i zeros = _mm_setzero_si128();
+  out[0] = _mm_unpacklo_epi64(b0, zeros);
+  out[1] = _mm_unpackhi_epi64(b0, zeros);
+  out[2] = _mm_unpacklo_epi64(b4, zeros);
+  out[3] = _mm_unpackhi_epi64(b4, zeros);
+  out[4] = _mm_unpacklo_epi64(b2, zeros);
+  out[5] = _mm_unpackhi_epi64(b2, zeros);
+  out[6] = _mm_unpacklo_epi64(b6, zeros);
+  out[7] = _mm_unpackhi_epi64(b6, zeros);
+}
+
+LIBGAV1_ALWAYS_INLINE void Transpose8x8_U16(const __m128i* const in,
+                                            __m128i* const out) {
   // Unpack 16 bit elements. Goes from:
   // in[0]: 00 01 02 03  04 05 06 07
   // in[1]: 10 11 12 13  14 15 16 17
diff --git a/libgav1/src/loop_filter_mask.cc b/libgav1/src/loop_filter_mask.cc
index 9304288..aaae2c7 100644
--- a/libgav1/src/loop_filter_mask.cc
+++ b/libgav1/src/loop_filter_mask.cc
@@ -10,6 +10,9 @@
 
 namespace libgav1 {
 
+// static.
+constexpr BitMaskSet LoopFilterMask::kPredictionModeDeltasMask;
+
 bool LoopFilterMask::Reset(int width, int height) {
   num_64x64_blocks_per_row_ = DivideBy64(width + 63);
   num_64x64_blocks_per_column_ = DivideBy64(height + 63);
@@ -31,10 +34,10 @@
   return true;
 }
 
-bool LoopFilterMask::Build(
+void LoopFilterMask::Build(
     const ObuSequenceHeader& sequence_header,
     const ObuFrameHeader& frame_header, int tile_group_start,
-    int tile_group_end, BlockParametersHolder* const block_parameters_holder,
+    int tile_group_end, const BlockParametersHolder& block_parameters_holder,
     const Array2D<TransformSize>& inter_transform_sizes) {
   for (int tile_number = tile_group_start; tile_number <= tile_group_end;
        ++tile_number) {
@@ -59,128 +62,124 @@
           (plane == kPlaneY) ? 0 : sequence_header.color_config.subsampling_x;
       const int8_t subsampling_y =
           (plane == kPlaneY) ? 0 : sequence_header.color_config.subsampling_y;
-      const int plane_width =
-          RightShiftWithRounding(frame_header.width, subsampling_x);
-      const int plane_height =
-          RightShiftWithRounding(frame_header.height, subsampling_y);
       const int vertical_step = 1 << subsampling_y;
       const int horizontal_step = 1 << subsampling_x;
 
-      // Build bit masks for vertical edges.
-      for (int row4x4 = GetDeblockPosition(row4x4_start, subsampling_y);
-           row4x4 < row4x4_end &&
-           MultiplyBy4(row4x4 >> subsampling_y) < plane_height;
-           row4x4 += vertical_step) {
-        if (column4x4_start == 0) break;  // Do not filter frame boundary.
-        const int column4x4 =
-            GetDeblockPosition(column4x4_start, subsampling_x);
-        const BlockParameters& bp =
-            *block_parameters_holder->Find(row4x4, column4x4);
-        const uint8_t vertical_level =
-            bp.deblock_filter_level[plane][kLoopFilterTypeVertical];
-        const BlockParameters& bp_left =
-            *block_parameters_holder->Find(row4x4, column4x4 - horizontal_step);
-        const uint8_t left_level =
-            bp_left.deblock_filter_level[plane][kLoopFilterTypeVertical];
-        const int unit_id = DivideBy16(row4x4) * num_64x64_blocks_per_row_ +
-                            DivideBy16(column4x4);
-        const int row = row4x4 % kNum4x4InLoopFilterMaskUnit;
-        const int column = column4x4 % kNum4x4InLoopFilterMaskUnit;
-        const int shift = LoopFilterMask::GetShift(row, column);
-        const int index = LoopFilterMask::GetIndex(row);
-        const auto mask = static_cast<uint64_t>(1) << shift;
-        // Tile boundary must be coding block boundary. So we don't have to
-        // check (!left_skip || !skip || is_vertical_border).
-        if (vertical_level != 0 || left_level != 0) {
-          const TransformSize tx_size = GetTransformSize(
-              frame_header.segmentation.lossless[bp.segment_id], bp.size,
-              static_cast<Plane>(plane),
-              inter_transform_sizes[row4x4][column4x4], subsampling_x,
-              subsampling_y);
-          const TransformSize left_tx_size = GetTransformSize(
-              frame_header.segmentation.lossless[bp_left.segment_id],
-              bp_left.size, static_cast<Plane>(plane),
-              inter_transform_sizes[row4x4][column4x4 - horizontal_step],
-              subsampling_x, subsampling_y);
-          // 0: 4x4, 1: 8x8, 2: 16x16.
-          const int transform_size_id =
-              std::min({kTransformWidthLog2[tx_size] - 2,
-                        kTransformWidthLog2[left_tx_size] - 2, 2});
-          SetLeft(mask, unit_id, plane, transform_size_id, index);
-          const uint8_t current_level =
-              (vertical_level == 0) ? left_level : vertical_level;
-          SetLevel(current_level, unit_id, plane, kLoopFilterTypeVertical,
-                   LoopFilterMask::GetLevelOffset(row, column));
+      // Build bit masks for vertical edges (except the frame boundary).
+      if (column4x4_start != 0) {
+        const int plane_height =
+            RightShiftWithRounding(frame_header.height, subsampling_y);
+        const int row4x4_limit =
+            std::min(row4x4_end, DivideBy4(plane_height + 3) << subsampling_y);
+        const int vertical_level_index =
+            kDeblockFilterLevelIndex[plane][kLoopFilterTypeVertical];
+        for (int row4x4 = GetDeblockPosition(row4x4_start, subsampling_y);
+             row4x4 < row4x4_limit; row4x4 += vertical_step) {
+          const int column4x4 =
+              GetDeblockPosition(column4x4_start, subsampling_x);
+          const BlockParameters& bp =
+              *block_parameters_holder.Find(row4x4, column4x4);
+          const uint8_t vertical_level =
+              bp.deblock_filter_level[vertical_level_index];
+          const BlockParameters& bp_left = *block_parameters_holder.Find(
+              row4x4, column4x4 - horizontal_step);
+          const uint8_t left_level =
+              bp_left.deblock_filter_level[vertical_level_index];
+          const int unit_id = DivideBy16(row4x4) * num_64x64_blocks_per_row_ +
+                              DivideBy16(column4x4);
+          const int row = row4x4 % kNum4x4InLoopFilterMaskUnit;
+          const int column = column4x4 % kNum4x4InLoopFilterMaskUnit;
+          const int shift = LoopFilterMask::GetShift(row, column);
+          const int index = LoopFilterMask::GetIndex(row);
+          const auto mask = static_cast<uint64_t>(1) << shift;
+          // Tile boundary must be coding block boundary. So we don't have to
+          // check (!left_skip || !skip || is_vertical_border).
+          if (vertical_level != 0 || left_level != 0) {
+            assert(inter_transform_sizes[row4x4] != nullptr);
+            const TransformSize tx_size =
+                (plane == kPlaneY) ? inter_transform_sizes[row4x4][column4x4]
+                                   : bp.uv_transform_size;
+            const TransformSize left_tx_size =
+                (plane == kPlaneY)
+                    ? inter_transform_sizes[row4x4][column4x4 - horizontal_step]
+                    : bp_left.uv_transform_size;
+            const LoopFilterTransformSizeId transform_size_id =
+                GetTransformSizeIdWidth(tx_size, left_tx_size);
+            SetLeft(mask, unit_id, plane, transform_size_id, index);
+            const uint8_t current_level =
+                (vertical_level == 0) ? left_level : vertical_level;
+            SetLevel(current_level, unit_id, plane, kLoopFilterTypeVertical,
+                     LoopFilterMask::GetLevelOffset(row, column));
+          }
         }
       }
 
-      // Build bit masks for horizontal edges.
-      for (int column4x4 = GetDeblockPosition(column4x4_start, subsampling_x);
-           column4x4 < column4x4_end &&
-           MultiplyBy4(column4x4 >> subsampling_x) < plane_width;
-           column4x4 += horizontal_step) {
-        if (row4x4_start == 0) break;  // Do not filter frame boundary.
-        const int row4x4 = GetDeblockPosition(row4x4_start, subsampling_y);
-        const BlockParameters& bp =
-            *block_parameters_holder->Find(row4x4, column4x4);
-        const uint8_t horizontal_level =
-            bp.deblock_filter_level[plane][kLoopFilterTypeHorizontal];
-        const BlockParameters& bp_top =
-            *block_parameters_holder->Find(row4x4 - vertical_step, column4x4);
-        const uint8_t top_level =
-            bp_top.deblock_filter_level[plane][kLoopFilterTypeHorizontal];
-        const int unit_id = DivideBy16(row4x4) * num_64x64_blocks_per_row_ +
-                            DivideBy16(column4x4);
-        const int row = row4x4 % kNum4x4InLoopFilterMaskUnit;
-        const int column = column4x4 % kNum4x4InLoopFilterMaskUnit;
-        const int shift = LoopFilterMask::GetShift(row, column);
-        const int index = LoopFilterMask::GetIndex(row);
-        const auto mask = static_cast<uint64_t>(1) << shift;
-        // Tile boundary must be coding block boundary. So we don't have to
-        // check (!top_skip || !skip || is_horizontal_border).
-        if (horizontal_level != 0 || top_level != 0) {
-          const TransformSize tx_size = GetTransformSize(
-              frame_header.segmentation.lossless[bp.segment_id], bp.size,
-              static_cast<Plane>(plane),
-              inter_transform_sizes[row4x4][column4x4], subsampling_x,
-              subsampling_y);
-          const TransformSize top_tx_size = GetTransformSize(
-              frame_header.segmentation.lossless[bp_top.segment_id],
-              bp_top.size, static_cast<Plane>(plane),
-              inter_transform_sizes[row4x4 - vertical_step][column4x4],
-              subsampling_x, subsampling_y);
-          // 0: 4x4, 1: 8x8, 2: 16x16.
-          const int transform_size_id =
-              std::min({kTransformHeightLog2[tx_size] - 2,
-                        kTransformHeightLog2[top_tx_size] - 2, 2});
-          SetTop(mask, unit_id, plane, transform_size_id, index);
-          const uint8_t current_level =
-              (horizontal_level == 0) ? top_level : horizontal_level;
-          SetLevel(current_level, unit_id, plane, kLoopFilterTypeHorizontal,
-                   LoopFilterMask::GetLevelOffset(row, column));
+      // Build bit masks for horizontal edges (except the frame boundary).
+      if (row4x4_start != 0) {
+        const int plane_width =
+            RightShiftWithRounding(frame_header.width, subsampling_x);
+        const int column4x4_limit = std::min(
+            column4x4_end, DivideBy4(plane_width + 3) << subsampling_x);
+        const int horizontal_level_index =
+            kDeblockFilterLevelIndex[plane][kLoopFilterTypeHorizontal];
+        for (int column4x4 = GetDeblockPosition(column4x4_start, subsampling_x);
+             column4x4 < column4x4_limit; column4x4 += horizontal_step) {
+          const int row4x4 = GetDeblockPosition(row4x4_start, subsampling_y);
+          const BlockParameters& bp =
+              *block_parameters_holder.Find(row4x4, column4x4);
+          const uint8_t horizontal_level =
+              bp.deblock_filter_level[horizontal_level_index];
+          const BlockParameters& bp_top =
+              *block_parameters_holder.Find(row4x4 - vertical_step, column4x4);
+          const uint8_t top_level =
+              bp_top.deblock_filter_level[horizontal_level_index];
+          const int unit_id = DivideBy16(row4x4) * num_64x64_blocks_per_row_ +
+                              DivideBy16(column4x4);
+          const int row = row4x4 % kNum4x4InLoopFilterMaskUnit;
+          const int column = column4x4 % kNum4x4InLoopFilterMaskUnit;
+          const int shift = LoopFilterMask::GetShift(row, column);
+          const int index = LoopFilterMask::GetIndex(row);
+          const auto mask = static_cast<uint64_t>(1) << shift;
+          // Tile boundary must be coding block boundary. So we don't have to
+          // check (!top_skip || !skip || is_horizontal_border).
+          if (horizontal_level != 0 || top_level != 0) {
+            assert(inter_transform_sizes[row4x4] != nullptr);
+            const TransformSize tx_size =
+                (plane == kPlaneY) ? inter_transform_sizes[row4x4][column4x4]
+                                   : bp.uv_transform_size;
+            const TransformSize top_tx_size =
+                (plane == kPlaneY)
+                    ? inter_transform_sizes[row4x4 - vertical_step][column4x4]
+                    : bp_top.uv_transform_size;
+            const LoopFilterTransformSizeId transform_size_id =
+                static_cast<LoopFilterTransformSizeId>(
+                    std::min({kTransformHeightLog2[tx_size] - 2,
+                              kTransformHeightLog2[top_tx_size] - 2, 2}));
+            SetTop(mask, unit_id, plane, transform_size_id, index);
+            const uint8_t current_level =
+                (horizontal_level == 0) ? top_level : horizontal_level;
+            SetLevel(current_level, unit_id, plane, kLoopFilterTypeHorizontal,
+                     LoopFilterMask::GetLevelOffset(row, column));
+          }
         }
       }
     }
   }
-  // Check bit masks.
-  for (int i = 0; i < num_64x64_blocks_; ++i) {
-    if (!IsValid(i)) return false;
-  }
-
-  return true;
+  assert(IsValid());
 }
 
-// Loop filter masks at different transform sizes should be mutually exclusive.
-bool LoopFilterMask::IsValid(int mask_id) const {
-  for (int plane = 0; plane < kMaxPlanes; ++plane) {
-    for (int i = 0; i < kNumTransformSizesLoopFilter; ++i) {
-      for (int j = i + 1; j < kNumTransformSizesLoopFilter; ++j) {
-        for (int k = 0; k < kNumLoopFilterMasks; ++k) {
-          if ((loop_filter_masks_[mask_id].left[plane][i][k] &
-               loop_filter_masks_[mask_id].left[plane][j][k]) != 0 ||
-              (loop_filter_masks_[mask_id].top[plane][i][k] &
-               loop_filter_masks_[mask_id].top[plane][j][k]) != 0) {
-            return false;
+bool LoopFilterMask::IsValid() const {
+  for (int mask_id = 0; mask_id < num_64x64_blocks_; ++mask_id) {
+    for (int plane = 0; plane < kMaxPlanes; ++plane) {
+      for (int i = 0; i < kNumLoopFilterTransformSizeIds; ++i) {
+        for (int j = i + 1; j < kNumLoopFilterTransformSizeIds; ++j) {
+          for (int k = 0; k < kNumLoopFilterMasks; ++k) {
+            if ((loop_filter_masks_[mask_id].left[plane][i][k] &
+                 loop_filter_masks_[mask_id].left[plane][j][k]) != 0 ||
+                (loop_filter_masks_[mask_id].top[plane][i][k] &
+                 loop_filter_masks_[mask_id].top[plane][j][k]) != 0) {
+              return false;
+            }
           }
         }
       }
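The rewritten loops above index the per-block deblock_filter_level[] array with a single flattened index obtained from kDeblockFilterLevelIndex[plane][filter type] instead of the old two-dimensional [plane][type] access. The table itself is defined outside this hunk; a sketch consistent with the AV1 convention of four filter levels (Y vertical, Y horizontal, U, V), offered as an assumption rather than the actual definition:

// Hypothetical reconstruction; both chroma planes use one level for both
// edge directions.
constexpr int kDeblockFilterLevelIndexSketch[kMaxPlanes][kNumLoopFilterTypes] = {
    {0, 1},  // kPlaneY: vertical, horizontal
    {2, 2},  // kPlaneU
    {3, 3},  // kPlaneV
};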
diff --git a/libgav1/src/loop_filter_mask.h b/libgav1/src/loop_filter_mask.h
index eca4938..472c95f 100644
--- a/libgav1/src/loop_filter_mask.h
+++ b/libgav1/src/loop_filter_mask.h
@@ -10,6 +10,7 @@
 #include "src/dsp/dsp.h"
 #include "src/obu_parser.h"
 #include "src/utils/array_2d.h"
+#include "src/utils/bit_mask_set.h"
 #include "src/utils/block_parameters_holder.h"
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
@@ -25,9 +26,10 @@
   // 4x4 blocks. It requires kNumLoopFilterMasks = 4 uint64_t to represent them.
   struct Data : public Allocable {
     uint8_t level[kMaxPlanes][kNumLoopFilterTypes][kNum4x4In64x64];
-    uint64_t left[kMaxPlanes][kNumTransformSizesLoopFilter]
+    uint64_t left[kMaxPlanes][kNumLoopFilterTransformSizeIds]
                  [kNumLoopFilterMasks];
-    uint64_t top[kMaxPlanes][kNumTransformSizesLoopFilter][kNumLoopFilterMasks];
+    uint64_t top[kMaxPlanes][kNumLoopFilterTransformSizeIds]
+                [kNumLoopFilterMasks];
   };
 
   LoopFilterMask() = default;
@@ -50,9 +52,10 @@
   // Before this function call, bit masks of transform edges other than those
   // on tile boundaries are built together with tile decoding, in
   // Tile::BuildBitMask().
-  bool Build(const ObuSequenceHeader& sequence_header,
+  void Build(const ObuSequenceHeader& sequence_header,
              const ObuFrameHeader& frame_header, int tile_group_start,
-             int tile_group_end, BlockParametersHolder* block_parameters_holder,
+             int tile_group_end,
+             const BlockParametersHolder& block_parameters_holder,
              const Array2D<TransformSize>& inter_transform_sizes);
 
   uint8_t GetLevel(int mask_id, int plane, LoopFilterType type,
@@ -60,24 +63,26 @@
     return loop_filter_masks_[mask_id].level[plane][type][offset];
   }
 
-  uint64_t GetLeft(int mask_id, int plane, int tx_size_id, int index) const {
+  uint64_t GetLeft(int mask_id, int plane, LoopFilterTransformSizeId tx_size_id,
+                   int index) const {
     return loop_filter_masks_[mask_id].left[plane][tx_size_id][index];
   }
 
-  uint64_t GetTop(int mask_id, int plane, int tx_size_id, int index) const {
+  uint64_t GetTop(int mask_id, int plane, LoopFilterTransformSizeId tx_size_id,
+                  int index) const {
     return loop_filter_masks_[mask_id].top[plane][tx_size_id][index];
   }
 
   int num_64x64_blocks_per_row() const { return num_64x64_blocks_per_row_; }
 
-  void SetLeft(uint64_t new_mask, int mask_id, int plane, int transform_size_id,
-               int index) {
+  void SetLeft(uint64_t new_mask, int mask_id, int plane,
+               LoopFilterTransformSizeId transform_size_id, int index) {
     loop_filter_masks_[mask_id].left[plane][transform_size_id][index] |=
         new_mask;
   }
 
-  void SetTop(uint64_t new_mask, int mask_id, int plane, int transform_size_id,
-              int index) {
+  void SetTop(uint64_t new_mask, int mask_id, int plane,
+              LoopFilterTransformSizeId transform_size_id, int index) {
     loop_filter_masks_[mask_id].top[plane][transform_size_id][index] |=
         new_mask;
   }
@@ -99,50 +104,68 @@
     return (row4x4 << 4) | column4x4;
   }
 
-  // 7.14.5.
-  static uint8_t GetDeblockFilterLevel(const ObuFrameHeader& frame_header,
-                                       const BlockParameters& bp, Plane plane,
-                                       int pass,
-                                       const int8_t delta_lf[kFrameLfCount]) {
-    const int filter_level_delta = (plane == kPlaneY) ? pass : plane + 1;
-    const int delta = frame_header.delta_lf.multi ? delta_lf[filter_level_delta]
-                                                  : delta_lf[0];
-    // TODO(chengchen): Could we reduce number of clips?
-    int level =
-        Clip3(frame_header.loop_filter.level[filter_level_delta] + delta, 0,
-              kMaxLoopFilterValue);
-    const auto feature = static_cast<SegmentFeature>(
-        kSegmentFeatureLoopFilterYVertical + filter_level_delta);
-    if (frame_header.segmentation.FeatureActive(bp.segment_id, feature)) {
-      level = Clip3(
-          level +
-              frame_header.segmentation.feature_data[bp.segment_id][feature],
-          0, kMaxLoopFilterValue);
-    }
-    if (frame_header.loop_filter.delta_enabled) {
-      const int shift = level >> 5;
-      if (bp.reference_frame[0] == kReferenceFrameIntra) {
-        level += LeftShift(
-            frame_header.loop_filter.ref_deltas[kReferenceFrameIntra], shift);
-      } else {
-        const int mode_id = kPredictionModeDeltasLookup[bp.y_mode];
-        level += LeftShift(
-            frame_header.loop_filter.ref_deltas[bp.reference_frame[0]] +
-                frame_header.loop_filter.mode_deltas[mode_id],
-            shift);
-      }
-      level = Clip3(level, 0, kMaxLoopFilterValue);
-    }
-    return level;
+  static constexpr int GetModeId(PredictionMode mode) {
+    return static_cast<int>(kPredictionModeDeltasMask.Contains(mode));
   }
 
-  bool IsValid(int mask_id) const;
+  // 7.14.5.
+  static void ComputeDeblockFilterLevels(
+      const ObuFrameHeader& frame_header, int segment_id, int level_index,
+      const int8_t delta_lf[kFrameLfCount],
+      uint8_t deblock_filter_levels[kNumReferenceFrameTypes][2]) {
+    const int delta = delta_lf[frame_header.delta_lf.multi ? level_index : 0];
+    uint8_t level = Clip3(frame_header.loop_filter.level[level_index] + delta,
+                          0, kMaxLoopFilterValue);
+    const auto feature = static_cast<SegmentFeature>(
+        kSegmentFeatureLoopFilterYVertical + level_index);
+    level = Clip3(
+        level + frame_header.segmentation.feature_data[segment_id][feature], 0,
+        kMaxLoopFilterValue);
+    if (!frame_header.loop_filter.delta_enabled) {
+      static_assert(sizeof(deblock_filter_levels[0][0]) == 1, "");
+      memset(deblock_filter_levels, level, kNumReferenceFrameTypes * 2);
+      return;
+    }
+    assert(frame_header.loop_filter.delta_enabled);
+    const int shift = level >> 5;
+    deblock_filter_levels[kReferenceFrameIntra][0] = Clip3(
+        level +
+            LeftShift(frame_header.loop_filter.ref_deltas[kReferenceFrameIntra],
+                      shift),
+        0, kMaxLoopFilterValue);
+    // deblock_filter_levels[kReferenceFrameIntra][1] is never used. So it does
+    // not have to be populated.
+    for (int reference_frame = kReferenceFrameIntra + 1;
+         reference_frame < kNumReferenceFrameTypes; ++reference_frame) {
+      for (int mode_id = 0; mode_id < 2; ++mode_id) {
+        deblock_filter_levels[reference_frame][mode_id] = Clip3(
+            level +
+                LeftShift(frame_header.loop_filter.ref_deltas[reference_frame] +
+                              frame_header.loop_filter.mode_deltas[mode_id],
+                          shift),
+            0, kMaxLoopFilterValue);
+      }
+    }
+  }
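
A hedged worked example of the arithmetic above (illustrative values only, assuming kMaxLoopFilterValue is 63 and that LeftShift(x, s) is x << s): with loop_filter.level[level_index] = 32, delta = 4 and a zero segmentation adjustment, level = Clip3(36, 0, 63) = 36 and shift = 36 >> 5 = 1; a reference frame with ref_deltas equal to 1 and mode_deltas equal to 0 then gets deblock_filter_levels[reference_frame][mode_id] = Clip3(36 + (1 << 1), 0, 63) = 38.
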
 
  private:
   std::unique_ptr<Data[]> loop_filter_masks_;
   int num_64x64_blocks_ = -1;
   int num_64x64_blocks_per_row_;
   int num_64x64_blocks_per_column_;
+
+  // Mask used to determine the index for mode_deltas lookup.
+  static constexpr BitMaskSet kPredictionModeDeltasMask{
+      BitMaskSet(kPredictionModeNearestMv, kPredictionModeNearMv,
+                 kPredictionModeNewMv, kPredictionModeNearestNearestMv,
+                 kPredictionModeNearNearMv, kPredictionModeNearestNewMv,
+                 kPredictionModeNewNearestMv, kPredictionModeNearNewMv,
+                 kPredictionModeNewNearMv, kPredictionModeNewNewMv)};
+
+  // Validates that the loop filter masks at different transform sizes are
+  // mutually exclusive. Only used in an assert. This function will not be used
+  // in optimized builds.
+  bool IsValid() const;
 };
 
 }  // namespace libgav1
diff --git a/libgav1/src/loop_restoration_info.cc b/libgav1/src/loop_restoration_info.cc
index e2b570e..cfeebed 100644
--- a/libgav1/src/loop_restoration_info.cc
+++ b/libgav1/src/loop_restoration_info.cc
@@ -6,6 +6,7 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
+#include <new>
 
 #include "src/utils/common.h"
 #include "src/utils/logging.h"
@@ -24,37 +25,42 @@
 
 bool LoopRestorationInfo::Allocate() {
   const int num_planes = is_monochrome_ ? kMaxPlanesMonochrome : kMaxPlanes;
-  loop_restoration_info_.reserve(num_planes);
-  if (loop_restoration_info_.capacity() < static_cast<size_t>(num_planes)) {
-    return false;
-  }
-  loop_restoration_info_.resize(num_planes);
+  int total_num_units = 0;
   for (int plane = kPlaneY; plane < num_planes; ++plane) {
     if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
       plane_needs_filtering_[plane] = false;
       continue;
     }
     plane_needs_filtering_[plane] = true;
-    const int tile_width = plane == kPlaneY
-                               ? width_
-                               : RightShiftWithRounding(width_, subsampling_x_);
-    const int tile_height =
-        plane == kPlaneY ? height_
-                         : RightShiftWithRounding(height_, subsampling_y_);
-    num_horizontal_units_[plane] = std::max(
-        1, (tile_width + DivideBy2(loop_restoration_.unit_size[plane])) /
-               loop_restoration_.unit_size[plane]);
-    num_vertical_units_[plane] = std::max(
-        1, (tile_height + DivideBy2(loop_restoration_.unit_size[plane])) /
-               loop_restoration_.unit_size[plane]);
+    const int width = (plane == kPlaneY)
+                          ? width_
+                          : RightShiftWithRounding(width_, subsampling_x_);
+    const int height = (plane == kPlaneY)
+                           ? height_
+                           : RightShiftWithRounding(height_, subsampling_y_);
+    num_horizontal_units_[plane] =
+        std::max(1, (width + DivideBy2(loop_restoration_.unit_size[plane])) /
+                        loop_restoration_.unit_size[plane]);
+    num_vertical_units_[plane] =
+        std::max(1, (height + DivideBy2(loop_restoration_.unit_size[plane])) /
+                        loop_restoration_.unit_size[plane]);
     num_units_[plane] =
         num_horizontal_units_[plane] * num_vertical_units_[plane];
-    loop_restoration_info_[plane].reserve(num_units_[plane]);
-    if (loop_restoration_info_[plane].capacity() <
-        static_cast<size_t>(num_units_[plane])) {
-      return false;
+    total_num_units += num_units_[plane];
+  }
+  // Allocate the RestorationUnitInfo arrays for all planes in a single heap
+  // allocation and divide up the buffer into arrays of the right sizes.
+  loop_restoration_info_buffer_.reset(new (std::nothrow)
+                                          RestorationUnitInfo[total_num_units]);
+  if (loop_restoration_info_buffer_ == nullptr) return false;
+  RestorationUnitInfo* loop_restoration_info =
+      loop_restoration_info_buffer_.get();
+  for (int plane = kPlaneY; plane < num_planes; ++plane) {
+    if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
+      continue;
     }
-    loop_restoration_info_[plane].resize(num_units_[plane]);
+    loop_restoration_info_[plane] = loop_restoration_info;
+    loop_restoration_info += num_units_[plane];
   }
   return true;
 }
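
The pattern used above (one heap allocation carved into per-plane arrays) can be shown in isolation. The sketch below uses hypothetical names and a fixed three-plane layout purely for illustration; it is not the libgav1 API.

    #include <memory>
    #include <new>

    struct UnitInfo {};  // Stand-in for RestorationUnitInfo.

    // Carves a single new[] allocation into three contiguous sub-arrays.
    bool AllocatePlaneArrays(const int counts[3], UnitInfo* planes[3],
                             std::unique_ptr<UnitInfo[]>* buffer) {
      int total = 0;
      for (int p = 0; p < 3; ++p) total += counts[p];
      buffer->reset(new (std::nothrow) UnitInfo[total]);
      if (*buffer == nullptr) return false;  // Single point of failure.
      UnitInfo* next = buffer->get();
      for (int p = 0; p < 3; ++p) {
        planes[p] = next;  // Each plane gets a contiguous slice.
        next += counts[p];
      }
      return true;
    }
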
@@ -101,9 +107,9 @@
     std::array<RestorationUnitInfo, kMaxPlanes>* const reference_unit_info) {
   LoopRestorationType unit_restoration_type = kLoopRestorationTypeNone;
   if (loop_restoration_.type[plane] == kLoopRestorationTypeSwitchable) {
-    unit_restoration_type = kBitstreamRestorationTypeMap[reader->ReadSymbol(
-        symbol_decoder_context->restoration_type_cdf,
-        kRestorationTypeSymbolCount)];
+    unit_restoration_type = kBitstreamRestorationTypeMap
+        [reader->ReadSymbol<kRestorationTypeSymbolCount>(
+            symbol_decoder_context->restoration_type_cdf)];
   } else if (loop_restoration_.type[plane] == kLoopRestorationTypeWiener) {
     const bool use_wiener =
         reader->ReadSymbol(symbol_decoder_context->use_wiener_cdf);
diff --git a/libgav1/src/loop_restoration_info.h b/libgav1/src/loop_restoration_info.h
index 5792e69..61a1cfd 100644
--- a/libgav1/src/loop_restoration_info.h
+++ b/libgav1/src/loop_restoration_info.h
@@ -72,7 +72,11 @@
   int num_units(Plane plane) const { return num_units_[plane]; }
 
  private:
-  std::vector<std::vector<RestorationUnitInfo>> loop_restoration_info_;
+  // If plane_needs_filtering_[plane] is true, loop_restoration_info_[plane]
+  // points to an array of num_units_[plane] elements.
+  RestorationUnitInfo* loop_restoration_info_[kMaxPlanes];
+  // Owns the memory that loop_restoration_info_[plane] points to.
+  std::unique_ptr<RestorationUnitInfo[]> loop_restoration_info_buffer_;
   bool plane_needs_filtering_[kMaxPlanes];
   const LoopRestoration& loop_restoration_;
   uint32_t width_;
diff --git a/libgav1/src/motion_vector.cc b/libgav1/src/motion_vector.cc
index 2992896..f807da8 100644
--- a/libgav1/src/motion_vector.cc
+++ b/libgav1/src/motion_vector.cc
@@ -6,17 +6,18 @@
 #include <cstdlib>
 #include <memory>
 
+#include "src/utils/bit_mask_set.h"
 #include "src/utils/common.h"
 #include "src/utils/logging.h"
 
 namespace libgav1 {
 namespace {
 
-const int kMvBorder = 128;
-const int kProjectionMvClamp = 16383;
-const int kProjectionMvMaxVerticalOffset = 0;
-const int kProjectionMvMaxHorizontalOffset = 8;
-const int kInvalidMvValue = -32768;
+constexpr int kMvBorder = 128;
+constexpr int kProjectionMvClamp = 16383;
+constexpr int kProjectionMvMaxVerticalOffset = 0;
+constexpr int kProjectionMvMaxHorizontalOffset = 8;
+constexpr int kInvalidMvValue = -32768;
 
 // Applies the sign of |sign_value| to |value| (and does so without a branch).
 int ApplySign(int value, int sign_value) {
@@ -121,12 +122,12 @@
   }
 }
 
-inline bool HasNewMv(PredictionMode mode) {
-  return mode == kPredictionModeNewMv || mode == kPredictionModeNewNewMv ||
-         mode == kPredictionModeNearNewMv || mode == kPredictionModeNewNearMv ||
-         mode == kPredictionModeNearestNewMv ||
-         mode == kPredictionModeNewNearestMv;
-}
+constexpr BitMaskSet kPredictionModeNewMvMask(kPredictionModeNewMv,
+                                              kPredictionModeNewNewMv,
+                                              kPredictionModeNearNewMv,
+                                              kPredictionModeNewNearMv,
+                                              kPredictionModeNearestNewMv,
+                                              kPredictionModeNewNearestMv);
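
src/utils/bit_mask_set.h itself is not part of this diff, so the following is only a rough sketch of how such a constexpr membership test can work, under the assumption that the enum values involved are smaller than 32. With a representation like this, kPredictionModeNewMvMask.Contains(candidate_mode) compiles down to a shift and a mask.

    #include <cstdint>

    class BitMaskSetSketch {
     public:
      template <typename... Enums>
      explicit constexpr BitMaskSetSketch(Enums... values)
          : mask_(Combine(values...)) {}

      constexpr bool Contains(int value) const {
        return ((mask_ >> value) & 1u) != 0;
      }

     private:
      static constexpr uint32_t Combine() { return 0; }
      template <typename Enum, typename... Enums>
      static constexpr uint32_t Combine(Enum value, Enums... rest) {
        return (1u << static_cast<int>(value)) | Combine(rest...);
      }

      uint32_t mask_;
    };
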
 
 // 7.10.2.8.
 void SearchStackSingle(const Tile::Block& block, int row, int column, int index,
@@ -148,7 +149,7 @@
     candidate_mv = bp.mv[index];
   }
   LowerMvPrecision(block, candidate_mv.mv);
-  *found_new_mv |= HasNewMv(candidate_mode);
+  *found_new_mv |= kPredictionModeNewMvMask.Contains(candidate_mode);
   *found_match = true;
   for (int i = 0; i < *num_mv_found; ++i) {
     if (ref_mv_stack[i].mv[0] == candidate_mv) {
@@ -183,7 +184,7 @@
     }
     LowerMvPrecision(block, candidate_mv[i].mv);
   }
-  *found_new_mv |= HasNewMv(candidate_mode);
+  *found_new_mv |= kPredictionModeNewMvMask.Contains(candidate_mode);
   *found_match = true;
   for (int i = 0; i < *num_mv_found; ++i) {
     if (ref_mv_stack[i].mv[0] == candidate_mv[0] &&
@@ -717,6 +718,21 @@
           sequence_header.enable_order_hint, sequence_header.order_hint_bits) *
       dst_sign;
   if (std::abs(reference_to_current_with_sign) > kMaxFrameDistance) return true;
+  // The index 0 entries of these two arrays are never used.
+  int reference_offsets[kNumReferenceFrameTypes];
+  bool skip_reference[kNumReferenceFrameTypes];
+  for (int source_reference_type = kReferenceFrameLast;
+       source_reference_type <= kNumInterReferenceFrameTypes;
+       ++source_reference_type) {
+    const int reference_offset = GetRelativeDistance(
+        current_frame.order_hint(source),
+        source_frame->order_hint(
+            static_cast<ReferenceFrameType>(source_reference_type)),
+        sequence_header.enable_order_hint, sequence_header.order_hint_bits);
+    skip_reference[source_reference_type] =
+        std::abs(reference_offset) > kMaxFrameDistance || reference_offset <= 0;
+    reference_offsets[source_reference_type] = reference_offset;
+  }
   // The column range has to be offset by kProjectionMvMaxHorizontalOffset since
   // coordinates in that range could end up being position_x8 because of
   // projection.
@@ -729,15 +745,11 @@
     for (int x8 = adjusted_x8_start; x8 < adjusted_x8_end; ++x8) {
       const ReferenceFrameType source_reference =
           *source_frame->motion_field_reference_frame(y8, x8);
-      if (source_reference <= kReferenceFrameIntra) continue;
-      const int reference_offset = GetRelativeDistance(
-          current_frame.order_hint(source),
-          source_frame->order_hint(source_reference),
-          sequence_header.enable_order_hint, sequence_header.order_hint_bits);
-      if (std::abs(reference_offset) > kMaxFrameDistance ||
-          reference_offset <= 0) {
+      if (source_reference <= kReferenceFrameIntra ||
+          skip_reference[source_reference]) {
         continue;
       }
+      const int reference_offset = reference_offsets[source_reference];
       const MotionVector& mv = *source_frame->motion_field_mv(y8, x8);
       MotionVector projection_mv;
       GetMvProjectionNoClamp(mv, reference_to_current_with_sign,
diff --git a/libgav1/src/obu_parser.cc b/libgav1/src/obu_parser.cc
index d7873c2..644f1a5 100644
--- a/libgav1/src/obu_parser.cc
+++ b/libgav1/src/obu_parser.cc
@@ -10,6 +10,7 @@
 #include "src/buffer_pool.h"
 #include "src/decoder_impl.h"
 #include "src/motion_vector.h"
+#include "src/utils/common.h"
 #include "src/utils/logging.h"
 
 namespace libgav1 {
@@ -22,15 +23,20 @@
 // 5.9.16.
 // Find the smallest value of k such that block_size << k is greater than or
 // equal to target.
-inline int TileLog2(int block_size, int target) {
+//
+// NOTE: TileLog2(block_size, target) is equal to
+//   CeilLog2(ceil((double)target / block_size))
+// where the division is a floating-point number division. (This equality holds
+// even when |target| is equal to 0.) In the special case of block_size == 1,
+// TileLog2(1, target) is equal to CeilLog2(target).
+int TileLog2(int block_size, int target) {
   int k = 0;
   for (; (block_size << k) < target; ++k) {
   }
   return k;
 }
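
As a quick hedged check of the NOTE above: TileLog2(64, 130) is 2, because 64 << 1 = 128 < 130 while 64 << 2 = 256 >= 130, and CeilLog2(ceil(130 / 64.0)) = CeilLog2(3) = 2 agrees (assuming CeilLog2(n) returns the smallest k with (1 << k) >= n).
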
 
-inline void ParseBitStreamLevel(BitStreamLevel* const level,
-                                uint8_t level_bits) {
+void ParseBitStreamLevel(BitStreamLevel* const level, uint8_t level_bits) {
   level->major = kMinimumMajorBitstreamLevel + (level_bits >> 2);
   level->minor = level_bits & 3;
 }
@@ -396,6 +402,7 @@
   sequence_header.film_grain_params_present = static_cast<bool>(scratch);
   // TODO(wtc): Compare new sequence header with old sequence header.
   sequence_header_ = sequence_header;
+  has_sequence_header_ = true;
   // Section 6.4.1: It is a requirement of bitstream conformance that if
   // OperatingPointIdc is equal to 0, then obu_extension_flag is equal to 0 for
   // all OBUs that follow this sequence header until the next sequence header.
@@ -489,6 +496,15 @@
   } else {
     frame_header_.superres_scale_denominator = kSuperResScaleNumerator;
   }
+  assert(frame_header_.width != 0);
+  assert(frame_header_.height != 0);
+  // Check if multiplying upscaled_width by height would overflow.
+  assert(frame_header_.upscaled_width >= frame_header_.width);
+  if (frame_header_.upscaled_width > INT32_MAX / frame_header_.height) {
+    LIBGAV1_DLOG(ERROR, "Frame dimensions too big: width=%d height=%d.",
+                 frame_header_.width, frame_header_.height);
+    return false;
+  }
   frame_header_.columns4x4 = ((frame_header_.width + 7) >> 3) << 1;
   frame_header_.rows4x4 = ((frame_header_.height + 7) >> 3) << 1;
   return true;
@@ -1510,9 +1526,9 @@
   const int sb_max_tile_area = kMaxTileArea >> MultiplyBy2(sb_size);
   const int minlog2_tile_columns = TileLog2(sb_max_tile_width, sb_columns);
   const int maxlog2_tile_columns =
-      TileLog2(1, std::min(sb_columns, static_cast<int>(kMaxTileColumns)));
+      CeilLog2(std::min(sb_columns, static_cast<int>(kMaxTileColumns)));
   const int maxlog2_tile_rows =
-      TileLog2(1, std::min(sb_rows, static_cast<int>(kMaxTileRows)));
+      CeilLog2(std::min(sb_rows, static_cast<int>(kMaxTileRows)));
   const int min_log2_tiles = std::max(
       minlog2_tile_columns, TileLog2(sb_max_tile_area, sb_rows * sb_columns));
   int64_t scratch;
@@ -1534,7 +1550,11 @@
     if (sb_tile_width <= 0) return false;
     int i = 0;
     for (int sb_start = 0; sb_start < sb_columns; sb_start += sb_tile_width) {
-      if (i > kMaxTileColumns) return false;
+      if (i >= kMaxTileColumns) {
+        LIBGAV1_DLOG(ERROR,
+                     "tile_columns would be greater than kMaxTileColumns.");
+        return false;
+      }
       tile_info->tile_column_start[i++] = sb_start << sb_shift;
     }
     tile_info->tile_column_start[i] = frame_header_.columns4x4;
@@ -1557,7 +1577,10 @@
     if (sb_tile_height <= 0) return false;
     i = 0;
     for (int sb_start = 0; sb_start < sb_rows; sb_start += sb_tile_height) {
-      if (i > kMaxTileRows) return false;
+      if (i >= kMaxTileRows) {
+        LIBGAV1_DLOG(ERROR, "tile_rows would be greater than kMaxTileRows.");
+        return false;
+      }
       tile_info->tile_row_start[i++] = sb_start << sb_shift;
     }
     tile_info->tile_row_start[i] = frame_header_.rows4x4;
@@ -1565,13 +1588,19 @@
   } else {
     int widest_tile_sb = 1;
     int i = 0;
-    for (int sb_start = 0; sb_start < sb_columns && i < kMaxTileColumns; ++i) {
+    for (int sb_start = 0; sb_start < sb_columns; ++i) {
+      if (i >= kMaxTileColumns) {
+        LIBGAV1_DLOG(ERROR,
+                     "tile_columns would be greater than kMaxTileColumns.");
+        return false;
+      }
       tile_info->tile_column_start[i] = sb_start << sb_shift;
       const int max_width =
           std::min(sb_columns - sb_start, static_cast<int>(sb_max_tile_width));
       int sb_size;
       if (!bit_reader_->DecodeUniform(max_width, &sb_size)) {
         LIBGAV1_DLOG(ERROR, "Not enough bits.");
+        return false;
       }
       ++sb_size;
       widest_tile_sb = std::max(sb_size, widest_tile_sb);
@@ -1579,7 +1608,7 @@
     }
     tile_info->tile_column_start[i] = frame_header_.columns4x4;
     tile_info->tile_columns = i;
-    tile_info->tile_columns_log2 = TileLog2(1, tile_info->tile_columns);
+    tile_info->tile_columns_log2 = CeilLog2(tile_info->tile_columns);
 
     int max_tile_area_sb = sb_rows * sb_columns;
     if (min_log2_tiles > 0) max_tile_area_sb >>= min_log2_tiles + 1;
@@ -1587,26 +1616,32 @@
         std::max(max_tile_area_sb / widest_tile_sb, 1);
 
     i = 0;
-    for (int sb_start = 0; sb_start < sb_rows && i < kMaxTileRows; ++i) {
+    for (int sb_start = 0; sb_start < sb_rows; ++i) {
+      if (i >= kMaxTileRows) {
+        LIBGAV1_DLOG(ERROR, "tile_rows would be greater than kMaxTileRows.");
+        return false;
+      }
       tile_info->tile_row_start[i] = sb_start << sb_shift;
       const int max_height = std::min(sb_rows - sb_start, max_tile_height_sb);
       int sb_size;
       if (!bit_reader_->DecodeUniform(max_height, &sb_size)) {
         LIBGAV1_DLOG(ERROR, "Not enough bits.");
+        return false;
       }
       ++sb_size;
       sb_start += sb_size;
     }
     tile_info->tile_row_start[i] = frame_header_.rows4x4;
     tile_info->tile_rows = i;
-    tile_info->tile_rows_log2 = TileLog2(1, tile_info->tile_rows);
+    tile_info->tile_rows_log2 = CeilLog2(tile_info->tile_rows);
   }
   tile_info->tile_count = tile_info->tile_rows * tile_info->tile_columns;
   tile_info->context_update_id = 0;
-  if ((tile_info->tile_columns | tile_info->tile_rows) > 1) {
-    OBU_READ_LITERAL_OR_FAIL(tile_info->tile_columns_log2 +
-                             tile_info->tile_rows_log2);
-    tile_info->context_update_id = scratch;
+  const int tile_bits =
+      tile_info->tile_columns_log2 + tile_info->tile_rows_log2;
+  if (tile_bits != 0) {
+    OBU_READ_LITERAL_OR_FAIL(tile_bits);
+    tile_info->context_update_id = static_cast<int16_t>(scratch);
     if (tile_info->context_update_id >= tile_info->tile_count) {
       LIBGAV1_DLOG(ERROR, "Invalid context_update_tile_id (%d) >= %d.",
                    tile_info->context_update_id, tile_info->tile_count);
@@ -1814,8 +1849,9 @@
           continue;
         }
         const int index = sequence_header_.operating_point_idc[i];
-        if (index == 0 || (InTemporalLayer(index, temporal_ids_.back()) &&
-                           InSpatialLayer(index, spatial_ids_.back()))) {
+        if (index == 0 ||
+            (InTemporalLayer(index, obu_headers_.back().temporal_id) &&
+             InSpatialLayer(index, obu_headers_.back().spatial_id))) {
           OBU_READ_LITERAL_OR_FAIL(
               sequence_header_.decoder_model_info.buffer_removal_time_length);
           frame_header_.buffer_removal_time[i] = static_cast<uint32_t>(scratch);
@@ -1991,6 +2027,9 @@
 }
 
 bool ObuParser::ParseFrameHeader() {
+  // Section 6.8.1: It is a requirement of bitstream conformance that a
+  // sequence header OBU has been received before a frame header OBU.
+  if (!has_sequence_header_) return false;
   if (!ParseFrameParameters()) return false;
   if (frame_header_.show_existing_frame) return true;
   bool status = ParseTileInfoSyntax() && ParseQuantizerParameters() &&
@@ -2000,6 +2039,7 @@
       frame_header_.segmentation);
   status =
       ParseQuantizerIndexDeltaParameters() && ParseLoopFilterDeltaParameters();
+  if (!status) return false;
   ComputeSegmentLosslessAndQIndex();
   status = ParseLoopFilterParameters();
   if (!status) return false;
@@ -2029,12 +2069,6 @@
   size -= 2;
   const auto type = static_cast<MetadataType>(scratch);
   switch (type) {
-    case kMetadataTypePrivateData:
-      for (size_t i = 0; i < size; ++i) {
-        OBU_READ_LITERAL_OR_FAIL(8);
-        metadata_.private_data.push_back(scratch);
-      }
-      break;
     case kMetadataTypeHdrContentLightLevel:
       OBU_READ_LITERAL_OR_FAIL(16);
       metadata_.max_cll = scratch;
@@ -2069,11 +2103,11 @@
   if (tile_group.start != next_tile_group_start_ ||
       tile_group.start > tile_group.end ||
       tile_group.end >= frame_header_.tile_info.tile_count) {
-    LIBGAV1_DLOG(
-        ERROR,
-        "Invalid tile group start %d (expected %d), end %d, tile_count %d.",
-        tile_group.start, next_tile_group_start_, tile_group.end,
-        frame_header_.tile_info.tile_count);
+    LIBGAV1_DLOG(ERROR,
+                 "Invalid tile group start %d or end %d: expected tile group "
+                 "start %d, tile_count %d.",
+                 tile_group.start, tile_group.end, next_tile_group_start_,
+                 frame_header_.tile_info.tile_count);
     return false;
   }
   next_tile_group_start_ = tile_group.end + 1;
@@ -2093,7 +2127,10 @@
   const size_t start_offset = bit_reader_->byte_offset();
   const int tile_bits =
       tile_info->tile_columns_log2 + tile_info->tile_rows_log2;
-  tile_groups_.emplace_back();
+  if (!tile_groups_.emplace_back()) {
+    LIBGAV1_DLOG(ERROR, "Could not add an element to tile_groups_.");
+    return false;
+  }
   auto& tile_group = tile_groups_.back();
   if (tile_bits == 0) {
     tile_group.start = 0;
@@ -2116,7 +2153,7 @@
     SetTileDataOffset(size, 1, bytes_consumed_so_far);
     return true;
   }
-  if (types_.back() == kObuFrame) {
+  if (obu_headers_.back().type == kObuFrame) {
     // 6.10.1: If obu_type is equal to OBU_FRAME, it is a requirement of
     // bitstream conformance that the value of tile_start_and_end_present_flag
     // is equal to 0.
@@ -2139,13 +2176,14 @@
 }
 
 bool ObuParser::ParseHeader() {
+  ObuHeader obu_header;
   int64_t scratch = bit_reader_->ReadBit();
   if (scratch != 0) {
     LIBGAV1_DLOG(ERROR, "forbidden_bit is not zero.");
     return false;
   }
   OBU_READ_LITERAL_OR_FAIL(4);
-  types_.push_back(static_cast<libgav1::ObuType>(scratch));
+  obu_header.type = static_cast<libgav1::ObuType>(scratch);
   OBU_READ_BIT_OR_FAIL;
   const auto extension_flag = static_cast<bool>(scratch);
   OBU_READ_BIT_OR_FAIL;
@@ -2161,7 +2199,7 @@
     LIBGAV1_DLOG(ERROR, "obu_reserved_1bit is not zero.");
     return false;
   }
-  has_extension_.push_back(extension_flag);
+  obu_header.has_extension = extension_flag;
   if (extension_flag) {
     if (extension_disallowed_) {
       LIBGAV1_DLOG(ERROR,
@@ -2169,19 +2207,19 @@
       return false;
     }
     OBU_READ_LITERAL_OR_FAIL(3);
-    temporal_ids_.push_back(scratch);
+    obu_header.temporal_id = scratch;
     OBU_READ_LITERAL_OR_FAIL(2);
-    spatial_ids_.push_back(scratch);
+    obu_header.spatial_id = scratch;
     OBU_READ_LITERAL_OR_FAIL(3);  // reserved.
     if (scratch != 0) {
       LIBGAV1_DLOG(ERROR, "extension_header_reserved_3bits is not zero.");
       return false;
     }
   } else {
-    temporal_ids_.push_back(0);
-    spatial_ids_.push_back(0);
+    obu_header.temporal_id = 0;
+    obu_header.spatial_id = 0;
   }
-  return true;
+  return obu_headers_.push_back(obu_header);
 }
 
 #undef OBU_READ_UVLC_OR_FAIL
@@ -2203,10 +2241,7 @@
   size_t size = size_;
 
   // Clear everything except the sequence header.
-  types_.clear();
-  has_extension_.clear();
-  temporal_ids_.clear();
-  spatial_ids_.clear();
+  obu_headers_.clear();
   frame_header_ = {};
   tile_groups_.clear();
   next_tile_group_start_ = 0;
@@ -2238,18 +2273,17 @@
       return false;
     }
 
-    if (types_.back() != kObuSequenceHeader &&
-        types_.back() != kObuTemporalDelimiter &&
+    const ObuHeader& obu_header = obu_headers_.back();
+    const ObuType obu_type = obu_header.type;
+    if (obu_type != kObuSequenceHeader && obu_type != kObuTemporalDelimiter &&
+        has_sequence_header_ &&
         sequence_header_.operating_point_idc[kOperatingPoint] != 0 &&
-        has_extension_.back() &&
+        obu_header.has_extension &&
         (!InTemporalLayer(sequence_header_.operating_point_idc[kOperatingPoint],
-                          temporal_ids_.back()) ||
+                          obu_header.temporal_id) ||
          !InSpatialLayer(sequence_header_.operating_point_idc[kOperatingPoint],
-                         spatial_ids_.back()))) {
-      types_.pop_back();
-      spatial_ids_.pop_back();
-      temporal_ids_.pop_back();
-      has_extension_.pop_back();
+                         obu_header.spatial_id))) {
+      obu_headers_.pop_back();
       bit_reader_->SkipBytes(obu_size);
       data += bit_reader_->byte_offset();
       size -= bit_reader_->byte_offset();
@@ -2258,7 +2292,7 @@
 
     const size_t obu_start_position = bit_reader_->bit_offset();
     bool obu_skipped = false;
-    switch (types_.back()) {
+    switch (obu_type) {
       case kObuTemporalDelimiter:
         break;
       case kObuSequenceHeader:
@@ -2348,6 +2382,9 @@
         parsed_one_full_frame =
             (tile_groups_.back().end == frame_header_.tile_info.tile_count - 1);
         break;
+      case kObuTileList:
+        LIBGAV1_DLOG(ERROR, "Decoding of tile list OBUs is not supported.");
+        return false;
       case kObuPadding:
       // TODO(b/120903866): Fix ParseMetadata() and then invoke that for the
       // kObuMetadata case.
@@ -2356,11 +2393,14 @@
         obu_skipped = true;
         break;
       default:
-        LIBGAV1_DLOG(ERROR, "Unknown OBU type: %d.", types_.back());
-        return false;
+        // Skip reserved OBUs. Section 6.2.2: Reserved units are for future use
+        // and shall be ignored by AV1 decoders.
+        bit_reader_->SkipBytes(obu_size);
+        obu_skipped = true;
+        break;
     }
-    if (obu_size > 0 && !obu_skipped && types_.back() != kObuFrame &&
-        types_.back() != kObuTileGroup) {
+    if (obu_size > 0 && !obu_skipped && obu_type != kObuFrame &&
+        obu_type != kObuTileGroup) {
       const size_t parsed_obu_size_in_bits =
           bit_reader_->bit_offset() - obu_start_position;
       if (obu_size * 8 < parsed_obu_size_in_bits) {
@@ -2368,14 +2408,14 @@
             ERROR,
             "Parsed OBU size (%zu bits) is greater than expected OBU size "
             "(%zu bytes) obu_type: %d.",
-            parsed_obu_size_in_bits, obu_size, types_.back());
+            parsed_obu_size_in_bits, obu_size, obu_type);
         return false;
       }
       if (!bit_reader_->VerifyAndSkipTrailingBits(obu_size * 8 -
                                                   parsed_obu_size_in_bits)) {
         LIBGAV1_DLOG(ERROR,
                      "Error when verifying trailing bits for obu type: %d",
-                     types_.back());
+                     obu_type);
         return false;
       }
     }
@@ -2386,12 +2426,16 @@
       LIBGAV1_DLOG(ERROR,
                    "OBU size (%zu) and consumed size (%zu) does not match for "
                    "obu_type: %d.",
-                   obu_size, consumed_obu_size, types_.back());
+                   obu_size, consumed_obu_size, obu_type);
       return false;
     }
     data += bytes_consumed;
     size -= bytes_consumed;
   }
+  if (!parsed_one_full_frame && seen_frame_header) {
+    LIBGAV1_DLOG(ERROR, "The last tile group in the frame was not received.");
+    return false;
+  }
   data_ = data;
   size_ = size;
   return true;
diff --git a/libgav1/src/obu_parser.h b/libgav1/src/obu_parser.h
index 0e1c0fa..9d972b1 100644
--- a/libgav1/src/obu_parser.h
+++ b/libgav1/src/obu_parser.h
@@ -5,7 +5,6 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
-#include <vector>
 
 #include "src/decoder_buffer.h"
 #include "src/dsp/common.h"
@@ -15,6 +14,7 @@
 #include "src/utils/constants.h"
 #include "src/utils/raw_bit_reader.h"
 #include "src/utils/segmentation.h"
+#include "src/utils/vector.h"
 
 namespace libgav1 {
 
@@ -43,6 +43,13 @@
   kPrimaryReferenceNone = 7
 };  // anonymous enum
 
+struct ObuHeader {
+  ObuType type;
+  bool has_extension;
+  int8_t temporal_id;
+  int8_t spatial_id;
+};
+
 enum BitstreamProfile : uint8_t {
   kProfile0,
   kProfile1,
@@ -50,9 +57,17 @@
   kMaxProfiles
 };
 
+// In the bitstream the level is encoded in five bits: the first three bits
+// encode |major| - 2 and the last two bits encode |minor|.
+//
+// If the mapped level (major.minor) is in the tables in Annex A.3, there are
+// bitstream conformance requirements on the maximum or minimum values of
+// several variables. The encoded value of 31 (which corresponds to the mapped
+// level 9.3) is the "maximum parameters" level and imposes no level-based
+// constraints on the bitstream.
 struct BitStreamLevel {
-  uint8_t major;
-  uint8_t minor;
+  uint8_t major;  // Range: 2-9.
+  uint8_t minor;  // Range: 0-3.
 };
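
For illustration, following ParseBitStreamLevel() in obu_parser.cc above (with kMinimumMajorBitstreamLevel taken to be 2, as the "|major| - 2" wording implies): level_bits = 9 (0b01001) gives major = 2 + (9 >> 2) = 4 and minor = 9 & 3 = 1, i.e. level 4.1, while level_bits = 31 (0b11111) gives 9.3, the "maximum parameters" level.
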
 
 enum ColorPrimaries : uint8_t {
@@ -262,6 +277,21 @@
   int32_t params[6];
 
   // Represent two shearing operations. Computed from |params| by SetupShear().
+  //
+  // The least significant six (= kWarpParamRoundingBits) bits are all zeros.
+  // (This means alpha, beta, gamma, and delta could be represented by a 10-bit
+  // signed integer.) The minimum value is INT16_MIN (= -32768) and the maximum
+  // value is 32704 = 0x7fc0, the largest int16_t value whose least significant
+  // six bits are all zeros.
+  //
+  // Valid warp parameters (as validated by SetupShear()) have smaller ranges.
+  // Their absolute values are less than 2^14 (= 16384). (This follows from
+  // the warpValid check at the end of Section 7.11.3.6.)
+  //
+  // NOTE: Section 7.11.3.6 of the spec allows a maximum value of 32768, which
+  // is outside the range of int16_t. When cast to int16_t, 32768 becomes
+  // -32768. This potential int16_t overflow does not matter because either
+  // 32768 or -32768 causes SetupShear() to return false.
   int16_t alpha;
   int16_t beta;
   int16_t gamma;
@@ -279,7 +309,7 @@
   int tile_rows_log2;
   int tile_rows;
   int tile_row_start[kMaxTileRows + 1];
-  uint8_t context_update_id;
+  int16_t context_update_id;
   uint8_t tile_size_bytes;
 };
 
@@ -368,13 +398,12 @@
 };
 
 enum MetadataType : uint8_t {
-  kMetadataTypePrivateData,
-  kMetadataTypeHdrContentLightLevel,
+  // 0 is reserved for AOM use.
+  kMetadataTypeHdrContentLightLevel = 1,
   kMetadataTypeHdrMasteringDisplayColorVolume
 };
 
 struct ObuMetadata {
-  std::vector<uint8_t> private_data;
   // Maximum content light level.
   uint16_t max_cll;
   // Maximum frame-average light level.
@@ -424,17 +453,16 @@
   bool ParseOneFrame();
 
   // Getters. Only valid if ParseOneFrame() completes successfully.
-  const std::vector<ObuType>& types() const { return types_; }
-  const std::vector<int8_t>& temporal_ids() const { return temporal_ids_; }
-  const std::vector<int8_t>& spatial_ids() const { return spatial_ids_; }
+  const Vector<ObuHeader>& obu_headers() const { return obu_headers_; }
   const ObuSequenceHeader& sequence_header() const { return sequence_header_; }
   const ObuFrameHeader& frame_header() const { return frame_header_; }
   const ObuMetadata& metadata() const { return metadata_; }
-  const std::vector<ObuTileGroup>& tile_groups() const { return tile_groups_; }
+  const Vector<ObuTileGroup>& tile_groups() const { return tile_groups_; }
 
   // Setters.
   void set_sequence_header(const ObuSequenceHeader& sequence_header) {
     sequence_header_ = sequence_header;
+    has_sequence_header_ = true;
   }
 
  private:
@@ -512,16 +540,15 @@
   size_t size_;
 
   // OBU elements. Only valid if ParseOneFrame() completes successfully.
-  std::vector<ObuType> types_;
-  std::vector<bool> has_extension_;
-  std::vector<int8_t> temporal_ids_;
-  std::vector<int8_t> spatial_ids_;
+  Vector<ObuHeader> obu_headers_;
   ObuSequenceHeader sequence_header_ = {};
   ObuFrameHeader frame_header_ = {};
   ObuMetadata metadata_ = {};
-  std::vector<ObuTileGroup> tile_groups_;
+  Vector<ObuTileGroup> tile_groups_;
   // The expected |start| value of the next ObuTileGroup.
   int next_tile_group_start_ = 0;
+  // If true, the sequence_header_ field is valid.
+  bool has_sequence_header_ = false;
   // If true, the obu_extension_flag syntax element in the OBU header must be
   // 0. Set to true when parsing a sequence header if OperatingPointIdc is 0.
   bool extension_disallowed_ = false;
diff --git a/libgav1/src/post_filter.cc b/libgav1/src/post_filter.cc
index bb29520..107a04c 100644
--- a/libgav1/src/post_filter.cc
+++ b/libgav1/src/post_filter.cc
@@ -1,15 +1,16 @@
 #include "src/post_filter.h"
 
 #include <algorithm>
+#include <atomic>
 #include <cassert>
-#include <condition_variable>  // NOLINT (unapproved c++11 header)
 #include <cstddef>
 #include <cstdint>
 #include <cstring>
-#include <mutex>  // NOLINT (unapproved c++11 header)
+#include <memory>
 
 #include "src/dsp/constants.h"
 #include "src/utils/array_2d.h"
+#include "src/utils/blocking_counter.h"
 #include "src/utils/logging.h"
 #include "src/utils/memory.h"
 #include "src/utils/types.h"
@@ -22,6 +23,37 @@
     {{7, 0, 2, 4, 5, 6, 6, 6}, {0, 1, 2, 3, 4, 5, 6, 7}}};
 
 template <typename Pixel>
+void ExtendFrame(uint8_t* const frame_start, const int width, const int height,
+                 ptrdiff_t stride, const int left, const int right,
+                 const int top, const int bottom) {
+  auto* const start = reinterpret_cast<Pixel*>(frame_start);
+  const Pixel* src = start;
+  Pixel* dst = start - left;
+  stride /= sizeof(Pixel);
+  // Copy to left and right borders.
+  for (int y = 0; y < height; ++y) {
+    Memset(dst, src[0], left);
+    Memset(dst + (left + width), src[width - 1], right);
+    src += stride;
+    dst += stride;
+  }
+  // Copy to top borders.
+  src = start - left;
+  dst = start - left - top * stride;
+  for (int y = 0; y < top; ++y) {
+    memcpy(dst, src, sizeof(Pixel) * stride);
+    dst += stride;
+  }
+  // Copy to bottom borders.
+  dst = start - left + height * stride;
+  src = dst - stride;
+  for (int y = 0; y < bottom; ++y) {
+    memcpy(dst, src, sizeof(Pixel) * stride);
+    dst += stride;
+  }
+}
+
+template <typename Pixel>
 void CopyPlane(const uint8_t* source, int source_stride, const int width,
                const int height, uint8_t* dest, int dest_stride) {
   auto* dst = reinterpret_cast<Pixel*>(dest);
@@ -35,29 +67,6 @@
   }
 }
 
-void CopyYuvBufferToSource(const YuvBuffer* const filtered_frame,
-                           YuvBuffer* const source_frame) {
-  const int num_planes =
-      filtered_frame->is_monochrome() ? kMaxPlanesMonochrome : kMaxPlanes;
-  for (int plane = kPlaneY; plane < num_planes; ++plane) {
-    if (filtered_frame->bitdepth() == 8) {
-      CopyPlane<uint8_t>(
-          filtered_frame->data(plane), filtered_frame->stride(plane),
-          filtered_frame->displayed_width(plane),
-          filtered_frame->displayed_height(plane), source_frame->data(plane),
-          source_frame->stride(plane));
-#if LIBGAV1_MAX_BITDEPTH >= 10
-    } else {
-      CopyPlane<uint16_t>(
-          filtered_frame->data(plane), filtered_frame->stride(plane),
-          filtered_frame->displayed_width(plane),
-          filtered_frame->displayed_height(plane), source_frame->data(plane),
-          source_frame->stride(plane));
-#endif
-    }
-  }
-}
-
 template <int bitdepth, typename Pixel>
 void ComputeSuperRes(const uint8_t* source, uint32_t source_stride,
                      const int upscaled_width, const int height,
@@ -89,49 +98,14 @@
 
 }  // namespace
 
-template <typename Pixel>
-void ExtendFrame(uint8_t* const frame_start, const int width, const int height,
-                 ptrdiff_t stride, const int left, const int right,
-                 const int top, const int bottom) {
-  auto* const start = reinterpret_cast<Pixel*>(frame_start);
-  const Pixel* src = start;
-  Pixel* dst = start - left;
-  stride /= sizeof(Pixel);
-  // Copy to left and right borders.
-  for (int y = 0; y < height; ++y) {
-    Memset(dst, src[0], left);
-    Memset(dst + (left + width), src[width - 1], right);
-    src += stride;
-    dst += stride;
-  }
-  // Copy to top borders.
-  src = start - left;
-  dst = start - left - top * stride;
-  for (int y = 0; y < top; ++y) {
-    memcpy(dst, src, sizeof(Pixel) * stride);
-    dst += stride;
-  }
-  // Copy to bottom borders.
-  dst = start - left + height * stride;
-  src = dst - stride;
-  for (int y = 0; y < bottom; ++y) {
-    memcpy(dst, src, sizeof(Pixel) * stride);
-    dst += stride;
-  }
-}
-
-// Matching definition of static data members.
-constexpr int PostFilter::kRestorationWindowWidth;
+// Static data member definitions.
+constexpr int PostFilter::kCdefLargeValue;
 
 bool PostFilter::ApplyFiltering() {
   if (DoDeblock() && !ApplyDeblockFilter()) return false;
   if (DoCdef() && !ApplyCdef()) return false;
   if (DoSuperRes() && !ApplySuperRes()) return false;
   if (DoRestoration() && !ApplyLoopRestoration()) return false;
-  if (DoCdef() && !DoRestoration()) {
-    // Copy (upscaled) cdef filtered frame to output source buffer.
-    CopyYuvBufferToSource(cdef_buffer_, source_buffer_);
-  }
   // Extend the frame boundary for inter-frame convolution and referencing.
   for (int plane = kPlaneY; plane < planes_; ++plane) {
     const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
@@ -177,18 +151,39 @@
   }
 }
 
+void PostFilter::DeblockFilterWorker(const DeblockFilterJob* jobs, int num_jobs,
+                                     std::atomic<int>* job_counter,
+                                     DeblockFilter deblock_filter) {
+  int job_index;
+  while ((job_index = job_counter->fetch_add(1, std::memory_order_relaxed)) <
+         num_jobs) {
+    const DeblockFilterJob& job = jobs[job_index];
+    for (int column4x4 = 0, column_unit = 0;
+         column4x4 < frame_header_.columns4x4;
+         column4x4 += kNum4x4InLoopFilterMaskUnit, ++column_unit) {
+      const int unit_id = GetDeblockUnitId(job.row_unit, column_unit);
+      (this->*deblock_filter)(static_cast<Plane>(job.plane), job.row4x4,
+                              column4x4, unit_id);
+    }
+  }
+}
+
 bool PostFilter::ApplyDeblockFilterThreaded() {
-  std::mutex mutex;
-  int pending_jobs = 0;  // Guarded by |mutex|.
-  std::condition_variable pending_jobs_zero_condvar;
   const int jobs_per_plane = DivideBy16(frame_header_.rows4x4 + 15);
   const int num_workers = thread_pool_->num_threads();
-  const int jobs_for_threadpool_per_plane =
-      jobs_per_plane * num_workers / (num_workers + 1);
-  assert(jobs_for_threadpool_per_plane < jobs_per_plane);
-  std::vector<std::function<void()>> current_thread_jobs;
-  current_thread_jobs.reserve(planes_ *
-                              (jobs_per_plane - jobs_for_threadpool_per_plane));
+  int planes[kMaxPlanes];
+  planes[0] = kPlaneY;
+  int num_planes = 1;
+  for (int plane = kPlaneU; plane < planes_; ++plane) {
+    if (frame_header_.loop_filter.level[plane + 1] != 0) {
+      planes[num_planes++] = plane;
+    }
+  }
+  const int num_jobs = num_planes * jobs_per_plane;
+  std::unique_ptr<DeblockFilterJob[]> jobs_unique_ptr(
+      new (std::nothrow) DeblockFilterJob[num_jobs]);
+  if (jobs_unique_ptr == nullptr) return false;
+  DeblockFilterJob* jobs = jobs_unique_ptr.get();
   // The vertical filters are not dependent on each other. So simply schedule
   // them for all possible rows.
   //
@@ -200,62 +195,34 @@
   //
   // The only synchronization involved is to know when each directional
   // filter is complete for the entire frame.
-  for (auto deblock_filter : {&PostFilter::VerticalDeblockFilter,
-                              &PostFilter::HorizontalDeblockFilter}) {
-    for (int plane = kPlaneY; plane < planes_; ++plane) {
-      if (plane != kPlaneY && frame_header_.loop_filter.level[plane + 1] == 0) {
-        continue;
-      }
-      {
-        std::lock_guard<std::mutex> lock(mutex);
-        pending_jobs += jobs_for_threadpool_per_plane;
-      }
+  for (DeblockFilter deblock_filter : {&PostFilter::VerticalDeblockFilter,
+                                       &PostFilter::HorizontalDeblockFilter}) {
+    int job_index = 0;
+    for (int i = 0; i < num_planes; ++i) {
+      const int plane = planes[i];
       for (int row4x4 = 0, row_unit = 0; row4x4 < frame_header_.rows4x4;
            row4x4 += kNum4x4InLoopFilterMaskUnit, ++row_unit) {
-        if (row_unit < jobs_for_threadpool_per_plane) {
-          thread_pool_->Schedule([this, plane, row4x4, row_unit, deblock_filter,
-                                  &mutex, &pending_jobs,
-                                  &pending_jobs_zero_condvar]() {
-            for (int column4x4 = 0, column_unit = 0;
-                 column4x4 < frame_header_.columns4x4;
-                 column4x4 += kNum4x4InLoopFilterMaskUnit, ++column_unit) {
-              const int unit_id = GetDeblockUnitId(row_unit, column_unit);
-              (this->*deblock_filter)(static_cast<Plane>(plane), row4x4,
-                                      column4x4, unit_id);
-            }
-            std::lock_guard<std::mutex> lock(mutex);
-            if (--pending_jobs == 0) {
-              // TODO(jzern): the mutex doesn't need to be locked to signal the
-              // condition.
-              pending_jobs_zero_condvar.notify_one();
-            }
-          });
-        } else {
-          current_thread_jobs.push_back(
-              [this, plane, row4x4, row_unit, deblock_filter]() {
-                for (int column4x4 = 0, column_unit = 0;
-                     column4x4 < frame_header_.columns4x4;
-                     column4x4 += kNum4x4InLoopFilterMaskUnit, ++column_unit) {
-                  const int unit_id = GetDeblockUnitId(row_unit, column_unit);
-                  (this->*deblock_filter)(static_cast<Plane>(plane), row4x4,
-                                          column4x4, unit_id);
-                }
-              });
-        }
+        assert(job_index < num_jobs);
+        DeblockFilterJob& job = jobs[job_index++];
+        job.plane = plane;
+        job.row4x4 = row4x4;
+        job.row_unit = row_unit;
       }
     }
-    // Run the jobs for current thread.
-    for (const auto& job : current_thread_jobs) {
-      job();
+    assert(job_index == num_jobs);
+    std::atomic<int> job_counter(0);
+    BlockingCounter pending_workers(num_workers);
+    for (int i = 0; i < num_workers; ++i) {
+      thread_pool_->Schedule([this, jobs, num_jobs, &job_counter,
+                              deblock_filter, &pending_workers]() {
+        DeblockFilterWorker(jobs, num_jobs, &job_counter, deblock_filter);
+        pending_workers.Decrement();
+      });
     }
-    current_thread_jobs.clear();
+    // Run the jobs on the current thread.
+    DeblockFilterWorker(jobs, num_jobs, &job_counter, deblock_filter);
     // Wait for the threadpool jobs to finish.
-    {
-      std::unique_lock<std::mutex> lock(mutex);
-      while (pending_jobs != 0) {
-        pending_jobs_zero_condvar.wait(lock);
-      }
-    }
+    pending_workers.Wait();
   }
   return true;
 }
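
The scheduling pattern used above (workers pull job indices from a shared atomic counter, and a blocking counter is used to wait for the pool to drain) can be sketched standalone. Everything below uses hypothetical names and std::thread instead of the libgav1 ThreadPool, purely for illustration.

    #include <atomic>
    #include <condition_variable>
    #include <mutex>
    #include <thread>
    #include <vector>

    // Minimal stand-in for a blocking counter: Wait() returns once Decrement()
    // has been called |count| times.
    class SimpleBlockingCounter {
     public:
      explicit SimpleBlockingCounter(int count) : count_(count) {}
      void Decrement() {
        std::lock_guard<std::mutex> lock(mutex_);
        if (--count_ == 0) condition_.notify_one();
      }
      void Wait() {
        std::unique_lock<std::mutex> lock(mutex_);
        condition_.wait(lock, [this] { return count_ == 0; });
      }

     private:
      std::mutex mutex_;
      std::condition_variable condition_;
      int count_;
    };

    void RunJobs(const int num_jobs, const int num_threads) {
      std::atomic<int> job_counter(0);
      SimpleBlockingCounter pending_workers(num_threads);
      // Each worker repeatedly claims the next unprocessed job index.
      const auto worker = [&job_counter, num_jobs]() {
        int job_index;
        while ((job_index = job_counter.fetch_add(
                    1, std::memory_order_relaxed)) < num_jobs) {
          // Process jobs[job_index] here.
        }
      };
      std::vector<std::thread> threads;
      for (int i = 0; i < num_threads; ++i) {
        threads.emplace_back([&worker, &pending_workers]() {
          worker();
          pending_workers.Decrement();
        });
      }
      worker();                // The calling thread helps with the same jobs.
      pending_workers.Wait();  // Wait until the pool threads have finished.
      for (auto& thread : threads) thread.join();
    }
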
@@ -300,132 +267,352 @@
   return true;
 }
 
-bool PostFilter::ApplyCdef() {
-  if (!cdef_filtered_buffer_.Realloc(
-          bitdepth_, planes_ == kMaxPlanesMonochrome, upscaled_width_, height_,
-          subsampling_x_, subsampling_y_, kBorderPixels,
-          /*byte_alignment=*/0, nullptr, nullptr, nullptr)) {
-    return false;
+void PostFilter::ComputeDeblockFilterLevels(
+    const int8_t delta_lf[kFrameLfCount],
+    uint8_t deblock_filter_levels[kMaxSegments][kFrameLfCount]
+                                 [kNumReferenceFrameTypes][2]) const {
+  if (!DoDeblock()) return;
+  for (int segment_id = 0;
+       segment_id < (frame_header_.segmentation.enabled ? kMaxSegments : 1);
+       ++segment_id) {
+    int level_index = 0;
+    for (; level_index < 2; ++level_index) {
+      LoopFilterMask::ComputeDeblockFilterLevels(
+          frame_header_, segment_id, level_index, delta_lf,
+          deblock_filter_levels[segment_id][level_index]);
+    }
+    for (; level_index < kFrameLfCount; ++level_index) {
+      if (frame_header_.loop_filter.level[level_index] != 0) {
+        LoopFilterMask::ComputeDeblockFilterLevels(
+            frame_header_, segment_id, level_index, delta_lf,
+            deblock_filter_levels[segment_id][level_index]);
+      }
+    }
   }
-  cdef_buffer_ = &cdef_filtered_buffer_;
+}
 
-  const size_t pixel_size =
-      (bitdepth_ == 8) ? sizeof(uint8_t) : sizeof(uint16_t);
+uint8_t* PostFilter::GetCdefBufferAndStride(
+    const int start_x, const int start_y, const int plane,
+    const int subsampling_x, const int subsampling_y,
+    const int window_buffer_plane_size, const int vertical_shift,
+    const int horizontal_shift, int* cdef_stride) {
+  if (!DoRestoration() && thread_pool_ != nullptr) {
+    // Write the output to threaded_window_buffer_.
+    *cdef_stride = window_buffer_width_ * pixel_size_;
+    const int column_window = start_x % (window_buffer_width_ >> subsampling_x);
+    const int row_window = start_y % (window_buffer_height_ >> subsampling_y);
+    return threaded_window_buffer_ + plane * window_buffer_plane_size +
+           row_window * (*cdef_stride) + column_window * pixel_size_;
+  }
+  // Write the output to cdef_buffer_.
+  *cdef_stride = cdef_buffer_->stride(plane);
+  // When restoration is not present, cdef is applied in place by writing the
+  // output to the top-left corner. In this case, cdef_buffer_ ==
+  // source_buffer_.
+  const ptrdiff_t buffer_offset =
+      DoRestoration()
+          ? 0
+          : vertical_shift * (*cdef_stride) + horizontal_shift * pixel_size_;
+  return cdef_buffer_->data(plane) + start_y * (*cdef_stride) +
+         start_x * pixel_size_ + buffer_offset;
+}
+
+template <typename Pixel>
+void PostFilter::ApplyCdefForOneUnit(uint16_t* cdef_block, const int index,
+                                     const int block_width4x4,
+                                     const int block_height4x4,
+                                     const int row4x4_start,
+                                     const int column4x4_start) {
   const int coeff_shift = bitdepth_ - 8;
-  const int step_64x64 = kNum4x4BlocksWide[kBlock64x64];
   const int step = kNum4x4BlocksWide[kBlock8x8];
-  // Apply cdef on each 8x8 Y block and
-  // (8 >> subsampling_x)x(8 >> subsampling_y) UV block.
-  for (int row = 0; row < frame_header_.rows4x4; row += step_64x64) {
-    for (int column = 0; column < frame_header_.columns4x4;
-         column += step_64x64) {
-      const int index = cdef_index_[DivideBy16(row)][DivideBy16(column)];
-      const int block_width4x4 =
-          std::min(step_64x64, frame_header_.columns4x4 - column);
-      const int block_height4x4 =
-          std::min(step_64x64, frame_header_.rows4x4 - row);
-      if (index == -1) {
-        for (int plane = kPlaneY; plane < planes_; ++plane) {
-          const int subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-          const int subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-          const int start_x = MultiplyBy4(column) >> subsampling_x;
-          const int start_y = MultiplyBy4(row) >> subsampling_y;
-          const ptrdiff_t cdef_stride = cdef_buffer_->stride(plane);
-          uint8_t* const cdef_buffer = cdef_buffer_->data(plane) +
-                                       start_y * cdef_stride +
-                                       start_x * pixel_size;
-          const int src_stride = source_buffer_->stride(plane);
-          uint8_t* const src_buffer = source_buffer_->data(plane) +
-                                      start_y * src_stride +
-                                      start_x * pixel_size;
-          const int block_width = MultiplyBy4(block_width4x4) >> subsampling_x;
-          const int block_height =
-              MultiplyBy4(block_height4x4) >> subsampling_y;
+  const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
+  const int vertical_shift = -kCdefBorder;
+  const int window_buffer_plane_size =
+      window_buffer_width_ * window_buffer_height_ * pixel_size_;
+
+  if (index == -1) {
+    for (int plane = kPlaneY; plane < planes_; ++plane) {
+      const int subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
+      const int subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
+      const int start_x = MultiplyBy4(column4x4_start) >> subsampling_x;
+      const int start_y = MultiplyBy4(row4x4_start) >> subsampling_y;
+      int cdef_stride;
+      uint8_t* const cdef_buffer = GetCdefBufferAndStride(
+          start_x, start_y, plane, subsampling_x, subsampling_y,
+          window_buffer_plane_size, vertical_shift, horizontal_shift,
+          &cdef_stride);
+      const int src_stride = source_buffer_->stride(plane);
+      uint8_t* const src_buffer = source_buffer_->data(plane) +
+                                  start_y * src_stride + start_x * pixel_size_;
+      const int block_width = MultiplyBy4(block_width4x4) >> subsampling_x;
+      const int block_height = MultiplyBy4(block_height4x4) >> subsampling_y;
+      for (int y = 0; y < block_height; ++y) {
+        memcpy(cdef_buffer + y * cdef_stride, src_buffer + y * src_stride,
+               block_width * pixel_size_);
+      }
+    }
+    return;
+  }
+
+  PrepareCdefBlock<Pixel>(source_buffer_, planes_, subsampling_x_,
+                          subsampling_y_, frame_header_.width,
+                          frame_header_.height, block_width4x4, block_height4x4,
+                          row4x4_start, column4x4_start, cdef_block,
+                          kRestorationProcessingUnitSizeWithBorders);
+
+  for (int row4x4 = row4x4_start; row4x4 < row4x4_start + block_height4x4;
+       row4x4 += step) {
+    for (int column4x4 = column4x4_start;
+         column4x4 < column4x4_start + block_width4x4; column4x4 += step) {
+      const bool skip =
+          block_parameters_.Find(row4x4, column4x4) != nullptr &&
+          block_parameters_.Find(row4x4 + 1, column4x4) != nullptr &&
+          block_parameters_.Find(row4x4, column4x4 + 1) != nullptr &&
+          block_parameters_.Find(row4x4 + 1, column4x4 + 1) != nullptr &&
+          block_parameters_.Find(row4x4, column4x4)->skip &&
+          block_parameters_.Find(row4x4 + 1, column4x4)->skip &&
+          block_parameters_.Find(row4x4, column4x4 + 1)->skip &&
+          block_parameters_.Find(row4x4 + 1, column4x4 + 1)->skip;
+      int damping = frame_header_.cdef.damping + coeff_shift;
+      int direction_y;
+      int direction;
+      int variance;
+      uint8_t primary_strength;
+      uint8_t secondary_strength;
+
+      for (int plane = kPlaneY; plane < planes_; ++plane) {
+        const int subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
+        const int subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
+        const int start_x = MultiplyBy4(column4x4) >> subsampling_x;
+        const int start_y = MultiplyBy4(row4x4) >> subsampling_y;
+        const int block_width = 8 >> subsampling_x;
+        const int block_height = 8 >> subsampling_y;
+        int cdef_stride;
+        uint8_t* const cdef_buffer = GetCdefBufferAndStride(
+            start_x, start_y, plane, subsampling_x, subsampling_y,
+            window_buffer_plane_size, vertical_shift, horizontal_shift,
+            &cdef_stride);
+        const int src_stride = source_buffer_->stride(plane);
+        uint8_t* const src_buffer = source_buffer_->data(plane) +
+                                    start_y * src_stride +
+                                    start_x * pixel_size_;
+
+        if (skip) {  // No cdef filtering.
           for (int y = 0; y < block_height; ++y) {
             memcpy(cdef_buffer + y * cdef_stride, src_buffer + y * src_stride,
-                   block_width * pixel_size);
+                   block_width * pixel_size_);
           }
+          continue;
         }
+
+        if (plane == kPlaneY) {
+          dsp_.cdef_direction(src_buffer, src_stride, &direction_y, &variance);
+          primary_strength = frame_header_.cdef.y_primary_strength[index]
+                             << coeff_shift;
+          secondary_strength = frame_header_.cdef.y_secondary_strength[index]
+                               << coeff_shift;
+          direction = (primary_strength == 0) ? 0 : direction_y;
+          const int variance_strength =
+              ((variance >> 6) != 0) ? std::min(FloorLog2(variance >> 6), 12)
+                                     : 0;
+          primary_strength =
+              (variance != 0)
+                  ? (primary_strength * (4 + variance_strength) + 8) >> 4
+                  : 0;
+        } else {
+          primary_strength = frame_header_.cdef.uv_primary_strength[index]
+                             << coeff_shift;
+          secondary_strength = frame_header_.cdef.uv_secondary_strength[index]
+                               << coeff_shift;
+          direction = (primary_strength == 0)
+                          ? 0
+                          : kCdefUvDirection[subsampling_x_][subsampling_y_]
+                                            [direction_y];
+          damping = frame_header_.cdef.damping + coeff_shift - 1;
+        }
+
+        if ((primary_strength | secondary_strength) == 0) {
+          for (int y = 0; y < block_height; ++y) {
+            memcpy(cdef_buffer + y * cdef_stride, src_buffer + y * src_stride,
+                   block_width * pixel_size_);
+          }
+          continue;
+        }
+        uint16_t* cdef_src =
+            cdef_block + plane * kRestorationProcessingUnitSizeWithBorders *
+                             kRestorationProcessingUnitSizeWithBorders;
+        cdef_src += kCdefBorder * kRestorationProcessingUnitSizeWithBorders +
+                    kCdefBorder;
+        cdef_src += (MultiplyBy4(row4x4 - row4x4_start) >> subsampling_y) *
+                        kRestorationProcessingUnitSizeWithBorders +
+                    (MultiplyBy4(column4x4 - column4x4_start) >> subsampling_x);
+        dsp_.cdef_filter(cdef_src, kRestorationProcessingUnitSizeWithBorders,
+                         frame_header_.rows4x4, frame_header_.columns4x4,
+                         start_x, start_y, subsampling_x, subsampling_y,
+                         primary_strength, secondary_strength, damping,
+                         direction, cdef_buffer, cdef_stride);
+      }
+    }
+  }
+}
+
+template <typename Pixel>
+void PostFilter::ApplyCdefForOneRowInWindow(const int row4x4,
+                                            const int column4x4_start) {
+  const int step_64x64 = 16;  // = 64/4.
+  uint16_t cdef_block[kRestorationProcessingUnitSizeWithBorders *
+                      kRestorationProcessingUnitSizeWithBorders * 3];
+
+  for (int column4x4_64x64 = 0;
+       column4x4_64x64 < std::min(DivideBy4(window_buffer_width_),
+                                  frame_header_.columns4x4 - column4x4_start);
+       column4x4_64x64 += step_64x64) {
+    const int column4x4 = column4x4_start + column4x4_64x64;
+    const int index = cdef_index_[DivideBy16(row4x4)][DivideBy16(column4x4)];
+    const int block_width4x4 =
+        std::min(step_64x64, frame_header_.columns4x4 - column4x4);
+    const int block_height4x4 =
+        std::min(step_64x64, frame_header_.rows4x4 - row4x4);
+
+    ApplyCdefForOneUnit<Pixel>(cdef_block, index, block_width4x4,
+                               block_height4x4, row4x4, column4x4);
+  }
+}
+
+// Each thread processes one row inside the window.
+// Y, U, V planes are processed together inside one thread.
+template <typename Pixel>
+bool PostFilter::ApplyCdefThreaded() {
+  assert((window_buffer_height_ & 63) == 0);
+  const int num_workers = thread_pool_->num_threads();
+  const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
+  const int vertical_shift = -kCdefBorder;
+  const int window_buffer_plane_size =
+      window_buffer_width_ * window_buffer_height_ * pixel_size_;
+  const int window_buffer_height4x4 = DivideBy4(window_buffer_height_);
+  const int step_64x64 = 16;  // = 64/4.
+  for (int row4x4 = 0; row4x4 < frame_header_.rows4x4;
+       row4x4 += window_buffer_height4x4) {
+    const int actual_window_height4x4 =
+        std::min(window_buffer_height4x4, frame_header_.rows4x4 - row4x4);
+    const int vertical_units_per_window =
+        DivideBy16(actual_window_height4x4 + 15);
+    for (int column4x4 = 0; column4x4 < frame_header_.columns4x4;
+         column4x4 += DivideBy4(window_buffer_width_)) {
+      const int jobs_for_threadpool =
+          vertical_units_per_window * num_workers / (num_workers + 1);
+      BlockingCounter pending_jobs(jobs_for_threadpool);
+      int job_count = 0;
+      for (int row64x64 = 0; row64x64 < actual_window_height4x4;
+           row64x64 += step_64x64) {
+        if (job_count < jobs_for_threadpool) {
+          thread_pool_->Schedule(
+              [this, row4x4, column4x4, row64x64, &pending_jobs]() {
+                ApplyCdefForOneRowInWindow<Pixel>(row4x4 + row64x64, column4x4);
+                pending_jobs.Decrement();
+              });
+        } else {
+          ApplyCdefForOneRowInWindow<Pixel>(row4x4 + row64x64, column4x4);
+        }
+        ++job_count;
+      }
+      pending_jobs.Wait();
+      if (DoRestoration()) continue;
+
+      // Copy |threaded_window_buffer_| to cdef_buffer_ (== source_buffer_).
+      assert(cdef_buffer_ == source_buffer_);
+      for (int plane = kPlaneY; plane < planes_; ++plane) {
+        const int cdef_stride = cdef_buffer_->stride(plane);
+        const ptrdiff_t buffer_offset =
+            vertical_shift * cdef_stride + horizontal_shift * pixel_size_;
+        const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
+        const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
+        const int plane_row = MultiplyBy4(row4x4) >> subsampling_y;
+        const int plane_column = MultiplyBy4(column4x4) >> subsampling_x;
+        int copy_width = std::min(frame_header_.columns4x4 - column4x4,
+                                  DivideBy4(window_buffer_width_));
+        copy_width = MultiplyBy4(copy_width) >> subsampling_x;
+        int copy_height =
+            std::min(frame_header_.rows4x4 - row4x4, window_buffer_height4x4);
+        copy_height = MultiplyBy4(copy_height) >> subsampling_y;
+        CopyPlane<Pixel>(
+            threaded_window_buffer_ + plane * window_buffer_plane_size,
+            window_buffer_width_ * pixel_size_, copy_width, copy_height,
+            cdef_buffer_->data(plane) + plane_row * cdef_stride +
+                plane_column * pixel_size_ + buffer_offset,
+            cdef_stride);
+      }
+    }
+  }
+  if (!DoRestoration()) {
+    for (int plane = kPlaneY; plane < planes_; ++plane) {
+      if (!cdef_buffer_->ShiftBuffer(plane, horizontal_shift, vertical_shift)) {
+        LIBGAV1_DLOG(ERROR,
+                     "Error shifting frame buffer head pointer at plane: %d",
+                     plane);
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+bool PostFilter::ApplyCdef() {
+  if (!DoRestoration()) {
+    cdef_buffer_ = source_buffer_;
+  } else {
+    if (!cdef_filtered_buffer_.Realloc(
+            bitdepth_, planes_ == kMaxPlanesMonochrome, upscaled_width_,
+            height_, subsampling_x_, subsampling_y_, kBorderPixels,
+            /*byte_alignment=*/0, nullptr, nullptr, nullptr)) {
+      return false;
+    }
+    cdef_buffer_ = &cdef_filtered_buffer_;
+  }
+
+  if (thread_pool_ != nullptr) {
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    if (bitdepth_ >= 10) {
+      return ApplyCdefThreaded<uint16_t>();
+    }
+#endif
+    return ApplyCdefThreaded<uint8_t>();
+  }
+
+  const int step_64x64 = 16;  // = 64/4.
+  // Apply cdef on each 8x8 Y block and
+  // (8 >> subsampling_x)x(8 >> subsampling_y) UV block.
+  for (int row4x4 = 0; row4x4 < frame_header_.rows4x4; row4x4 += step_64x64) {
+    for (int column4x4 = 0; column4x4 < frame_header_.columns4x4;
+         column4x4 += step_64x64) {
+      const int index = cdef_index_[DivideBy16(row4x4)][DivideBy16(column4x4)];
+      const int block_width4x4 =
+          std::min(step_64x64, frame_header_.columns4x4 - column4x4);
+      const int block_height4x4 =
+          std::min(step_64x64, frame_header_.rows4x4 - row4x4);
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+      if (bitdepth_ >= 10) {
+        ApplyCdefForOneUnit<uint16_t>(cdef_block_, index, block_width4x4,
+                                      block_height4x4, row4x4, column4x4);
         continue;
       }
-
-      for (int row4x4 = row; row4x4 < row + block_height4x4; row4x4 += step) {
-        for (int column4x4 = column; column4x4 < column + block_width4x4;
-             column4x4 += step) {
-          const bool skip =
-              block_parameters_->Find(row4x4, column4x4) != nullptr &&
-              block_parameters_->Find(row4x4 + 1, column4x4) != nullptr &&
-              block_parameters_->Find(row4x4, column4x4 + 1) != nullptr &&
-              block_parameters_->Find(row4x4 + 1, column4x4 + 1) != nullptr &&
-              block_parameters_->Find(row4x4, column4x4)->skip &&
-              block_parameters_->Find(row4x4 + 1, column4x4)->skip &&
-              block_parameters_->Find(row4x4, column4x4 + 1)->skip &&
-              block_parameters_->Find(row4x4 + 1, column4x4 + 1)->skip;
-          int damping = frame_header_.cdef.damping + coeff_shift;
-          int direction_y;
-          int direction;
-          int variance;
-          uint8_t primary_strength;
-          uint8_t secondary_strength;
-
-          for (int plane = kPlaneY; plane < planes_; ++plane) {
-            const int subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-            const int subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-            const int start_x = MultiplyBy4(column4x4) >> subsampling_x;
-            const int start_y = MultiplyBy4(row4x4) >> subsampling_y;
-            const int block_width = 8 >> subsampling_x;
-            const int block_height = 8 >> subsampling_y;
-            const ptrdiff_t cdef_stride = cdef_buffer_->stride(plane);
-            uint8_t* const cdef_buffer = cdef_buffer_->data(plane) +
-                                         start_y * cdef_stride +
-                                         start_x * pixel_size;
-            const int src_stride = source_buffer_->stride(plane);
-            uint8_t* const src_buffer = source_buffer_->data(plane) +
-                                        start_y * src_stride +
-                                        start_x * pixel_size;
-            for (int y = 0; y < block_height; ++y) {
-              memcpy(cdef_buffer + y * cdef_stride, src_buffer + y * src_stride,
-                     block_width * pixel_size);
-            }
-            if (skip) continue;  // No cdef filtering.
-
-            if (plane == kPlaneY) {
-              dsp_.cdef_direction(cdef_buffer, cdef_stride, &direction_y,
-                                  &variance);
-              primary_strength = frame_header_.cdef.y_primary_strength[index]
-                                 << coeff_shift;
-              secondary_strength =
-                  frame_header_.cdef.y_secondary_strength[index] << coeff_shift;
-              direction = (primary_strength == 0) ? 0 : direction_y;
-              const int variance_strength =
-                  ((variance >> 6) != 0)
-                      ? std::min(FloorLog2(variance >> 6), 12)
-                      : 0;
-              primary_strength =
-                  (variance != 0)
-                      ? (primary_strength * (4 + variance_strength) + 8) >> 4
-                      : 0;
-            } else {
-              primary_strength = frame_header_.cdef.uv_primary_strength[index]
-                                 << coeff_shift;
-              secondary_strength =
-                  frame_header_.cdef.uv_secondary_strength[index]
-                  << coeff_shift;
-              direction = (primary_strength == 0)
-                              ? 0
-                              : kCdefUvDirection[subsampling_x_][subsampling_y_]
-                                                [direction_y];
-              damping = frame_header_.cdef.damping + coeff_shift - 1;
-            }
-            // TODO(chengchen): Possible early termination if all parameters are
-            // 0.
-            dsp_.cdef_filter(src_buffer, src_stride, frame_header_.rows4x4,
-                             frame_header_.columns4x4, start_x, start_y,
-                             subsampling_x, subsampling_y, primary_strength,
-                             secondary_strength, damping, direction,
-                             cdef_buffer, cdef_stride);
-          }
-        }
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+      ApplyCdefForOneUnit<uint8_t>(cdef_block_, index, block_width4x4,
+                                   block_height4x4, row4x4, column4x4);
+    }
+  }
+  if (!DoRestoration()) {
+    const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
+    const int vertical_shift = -kCdefBorder;
+    for (int plane = kPlaneY; plane < planes_; ++plane) {
+      if (!source_buffer_->ShiftBuffer(plane, horizontal_shift,
+                                       vertical_shift)) {
+        LIBGAV1_DLOG(ERROR,
+                     "Error shifting frame buffer head pointer at plane: %d",
+                     plane);
+        return false;
       }
     }
   }
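
As a worked illustration of the Y-plane strength scaling in ApplyCdefForOneUnit above, the following standalone sketch (an editor's illustration, not part of the patch) mirrors the expression; it assumes coeff_shift equals bitdepth - 8, which is how the decoder derives it elsewhere, and defines a local FloorLog2 so the snippet is self-contained.

#include <algorithm>

// floor(log2(n)) for n > 0; a local stand-in for the libgav1 helper.
int FloorLog2(int n) {
  int result = -1;
  while (n != 0) {
    n >>= 1;
    ++result;
  }
  return result;
}

// Mirrors the kPlaneY branch above: the header strength is scaled up by
// coeff_shift and then modulated by the block variance reported by
// dsp_.cdef_direction().
int CdefYPrimaryStrength(int header_strength, int coeff_shift, int variance) {
  const int primary_strength = header_strength << coeff_shift;
  const int variance_strength =
      ((variance >> 6) != 0) ? std::min(FloorLog2(variance >> 6), 12) : 0;
  return (variance != 0)
             ? (primary_strength * (4 + variance_strength) + 8) >> 4
             : 0;
}

For example, a 10-bit stream (coeff_shift = 2) with y_primary_strength = 3 and variance = 2048 gives 3 << 2 = 12 and FloorLog2(2048 >> 6) = 5, so the applied primary strength is (12 * (4 + 5) + 8) >> 4 = 7.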
@@ -492,16 +679,14 @@
   // Extend original frame, copy to borders.
   for (int plane = kPlaneY; plane < planes_; ++plane) {
     const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-    const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
     uint8_t* const frame_start = input_buffer->data(plane);
     const int plane_width =
         RightShiftWithRounding(upscaled_width_, subsampling_x);
-    const int border_height = kBorderPixels >> subsampling_y;
-    const int border_width = kBorderPixels >> subsampling_x;
-    ExtendFrameBoundary(frame_start, plane_width,
-                        input_buffer->displayed_height(plane),
-                        input_buffer->stride(plane), border_width, border_width,
-                        border_height, border_height);
+    ExtendFrameBoundary(
+        frame_start, plane_width, input_buffer->displayed_height(plane),
+        input_buffer->stride(plane), input_buffer->left_border(plane),
+        input_buffer->right_border(plane), input_buffer->top_border(plane),
+        input_buffer->bottom_border(plane));
   }
 }
 
@@ -536,9 +721,7 @@
     const int x, const int y, const int row, const int unit_row,
     const int current_process_unit_height, const int process_unit_width,
     const int window_width, const int plane_unit_size,
-    const int num_horizontal_units, std::mutex* const mutex,
-    int* const pending_jobs,
-    std::condition_variable* const pending_jobs_zero_condvar) {
+    const int num_horizontal_units) {
   for (int column = 0; column < window_width; column += process_unit_width) {
     const int unit_x = x + column;
     const int unit_column =
@@ -555,17 +738,7 @@
         cdef_buffer, cdef_buffer_stride, deblock_buffer, deblock_buffer_stride,
         plane, plane_height, unit_id, type, x, y, row, column,
         current_process_unit_width, current_process_unit_height,
-        process_unit_width, kRestorationWindowWidth);
-  }
-  if (mutex != nullptr) {
-    assert(pending_jobs != nullptr);
-    assert(pending_jobs_zero_condvar != nullptr);
-    std::lock_guard<std::mutex> lock(*mutex);
-    if (--*pending_jobs == 0) {
-      // TODO(jzern): the mutex doesn't need to be locked to signal the
-      // condition.
-      pending_jobs_zero_condvar->notify_one();
-    }
+        process_unit_width, window_buffer_width_);
   }
 }
 
@@ -581,15 +754,14 @@
   const int unit_x = x + column;
   const int unit_y = y + row;
   uint8_t* cdef_unit_buffer =
-      cdef_buffer + unit_y * cdef_buffer_stride + unit_x * sizeof(Pixel);
+      cdef_buffer + unit_y * cdef_buffer_stride + unit_x * pixel_size_;
   Array2DView<Pixel> loop_restored_window(
-      restoration_window_height_, kRestorationWindowWidth,
-      reinterpret_cast<Pixel*>(threaded_loop_restoration_buffer_));
+      window_buffer_height_, window_buffer_width_,
+      reinterpret_cast<Pixel*>(threaded_window_buffer_));
   if (type == kLoopRestorationTypeNone) {
     Pixel* dest = &loop_restored_window[row][column];
     for (int k = 0; k < current_process_unit_height; ++k) {
-      memcpy(dest, cdef_unit_buffer,
-             current_process_unit_width * sizeof(Pixel));
+      memcpy(dest, cdef_unit_buffer, current_process_unit_width * pixel_size_);
       dest += window_width;
       cdef_unit_buffer += cdef_buffer_stride;
     }
@@ -605,7 +777,7 @@
                                sizeof(Pixel) +
                            ((sizeof(Pixel) == 1) ? 6 : 0)];
   const ptrdiff_t block_buffer_stride =
-      kRestorationProcessingUnitSizeWithBorders * sizeof(Pixel);
+      kRestorationProcessingUnitSizeWithBorders * pixel_size_;
   IntermediateBuffers intermediate_buffers;
 
   RestorationBuffer restoration_buffer = {
@@ -616,10 +788,9 @@
        intermediate_buffers.box_filter.intermediate_b},
       kRestorationProcessingUnitSizeWithBorders + kRestorationPadding,
       intermediate_buffers.wiener,
-      kMaxSuperBlockSizeInPixels,
-      {(bitdepth_ == 12) ? 5 : 3, (bitdepth_ == 12) ? 9 : 11}};
+      kMaxSuperBlockSizeInPixels};
   uint8_t* deblock_unit_buffer =
-      deblock_buffer + unit_y * deblock_buffer_stride + unit_x * sizeof(Pixel);
+      deblock_buffer + unit_y * deblock_buffer_stride + unit_x * pixel_size_;
   assert(type == kLoopRestorationTypeSgrProj ||
          type == kLoopRestorationTypeWiener);
   const dsp::LoopRestorationFunc restoration_func =
@@ -631,23 +802,23 @@
       unit_y + current_process_unit_height >= plane_height);
   restoration_func(reinterpret_cast<const uint8_t*>(
                        block_buffer + kRestorationBorder * block_buffer_stride +
-                       kRestorationBorder * sizeof(Pixel)),
+                       kRestorationBorder * pixel_size_),
                    &loop_restored_window[row][column],
                    restoration_info_->loop_restoration_info(
                        static_cast<Plane>(plane), unit_id),
-                   block_buffer_stride, window_width * sizeof(Pixel),
+                   block_buffer_stride, window_width * pixel_size_,
                    current_process_unit_width, current_process_unit_height,
                    &restoration_buffer);
 }
 
 // Multi-thread version of loop restoration, based on a moving window of size
-// |kRestorationWindowWidth|x|restoration_window_height_|. Inside the moving
-// window, we create a filtering job for each row and each filtering job is
-// submitted to the thread pool. Each free thread takes one job from the thread
-// pool and completes filtering until all jobs are finished. This approach
-// requires an extra buffer (|threaded_loop_restoration_buffer_|) to hold the
-// filtering output, whose size is the size of the window. It also needs block
-// buffers (i.e., |block_buffer| and |intermediate_buffers| in
+// |window_buffer_width_|x|window_buffer_height_|. Inside the moving window, we
+// create a filtering job for each row and each filtering job is submitted to
+// the thread pool. Each free thread takes one job from the thread pool and
+// completes filtering until all jobs are finished. This approach requires an
+// extra buffer (|threaded_window_buffer_|) to hold the filtering output, whose
+// size is the size of the window. It also needs block buffers (i.e.,
+// |block_buffer| and |intermediate_buffers| in
 // ApplyLoopRestorationForOneUnit()) to store intermediate results in loop
 // restoration for each thread. After all units inside the window are filtered,
 // the output is written to the frame buffer.
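
The comment above describes the window/job split that this change also applies to the threaded cdef path. A condensed sketch of the pattern (an illustration only; ThreadPool and BlockingCounter are the utilities already used in this file, while FilterWindow and filter_row are hypothetical names for the per-window and per-row work):

#include <functional>

void FilterWindow(ThreadPool* thread_pool, int num_row_jobs,
                  const std::function<void(int)>& filter_row) {
  const int num_workers = thread_pool->num_threads();
  // Keep roughly one out of every (num_workers + 1) jobs on the calling
  // thread so it contributes work instead of only waiting.
  const int jobs_for_threadpool =
      num_row_jobs * num_workers / (num_workers + 1);
  BlockingCounter pending_jobs(jobs_for_threadpool);
  for (int job = 0; job < num_row_jobs; ++job) {
    if (job < jobs_for_threadpool) {
      thread_pool->Schedule([job, &filter_row, &pending_jobs]() {
        filter_row(job);  // filter one row of processing units
        pending_jobs.Decrement();
      });
    } else {
      filter_row(job);  // the remaining rows run on the calling thread
    }
  }
  pending_jobs.Wait();
  // Every row in the window is now filtered; the caller copies
  // |threaded_window_buffer_| back to the frame buffer.
}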
@@ -663,9 +834,7 @@
       kRestorationProcessingUnitSize >> subsampling_y_,
       kRestorationProcessingUnitSize >> subsampling_y_};
 
-  std::mutex mutex;
-  std::condition_variable pending_jobs_zero_condvar;
-  const int horizontal_shift = -source_buffer_->alignment() / sizeof(Pixel);
+  const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
   const int vertical_shift = -kRestorationBorder;
   for (int plane = kPlaneY; plane < planes_; ++plane) {
     if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
@@ -696,7 +865,7 @@
         RightShiftWithRounding(upscaled_width_, subsampling_x);
     const int plane_height = RightShiftWithRounding(height_, subsampling_y);
     const ptrdiff_t src_unit_buffer_offset =
-        vertical_shift * src_stride + horizontal_shift * sizeof(Pixel);
+        vertical_shift * src_stride + horizontal_shift * pixel_size_;
     ExtendFrameBoundary(cdef_buffer, plane_width, plane_height,
                         cdef_buffer_stride, kRestorationBorder,
                         kRestorationBorder, kRestorationBorder,
@@ -709,10 +878,10 @@
     }
 
     const int num_workers = thread_pool_->num_threads();
-    for (int y = 0; y < plane_height; y += restoration_window_height_) {
-      const int actual_window_height = std::min(
-          restoration_window_height_ - ((y == 0) ? unit_height_offset : 0),
-          plane_height - y);
+    for (int y = 0; y < plane_height; y += window_buffer_height_) {
+      const int actual_window_height =
+          std::min(window_buffer_height_ - ((y == 0) ? unit_height_offset : 0),
+                   plane_height - y);
       int vertical_units_per_window =
           (actual_window_height + plane_process_unit_height[plane] - 1) /
           plane_process_unit_height[plane];
@@ -731,14 +900,13 @@
                 plane_process_unit_height[plane] +
             1;
       }
-      for (int x = 0; x < plane_width; x += kRestorationWindowWidth) {
+      for (int x = 0; x < plane_width; x += window_buffer_width_) {
         const int actual_window_width =
-            std::min(kRestorationWindowWidth, plane_width - x);
-        int pending_jobs = vertical_units_per_window;  // Guarded by |mutex|.
+            std::min(window_buffer_width_, plane_width - x);
         const int jobs_for_threadpool =
-            pending_jobs * num_workers / (num_workers + 1);
-        assert(jobs_for_threadpool < pending_jobs);
-        pending_jobs = jobs_for_threadpool;
+            vertical_units_per_window * num_workers / (num_workers + 1);
+        assert(jobs_for_threadpool < vertical_units_per_window);
+        BlockingCounter pending_jobs(jobs_for_threadpool);
         int job_count = 0;
         int current_process_unit_height;
         for (int row = 0; row < actual_window_height;
@@ -757,20 +925,19 @@
 
           if (job_count < jobs_for_threadpool) {
             thread_pool_->Schedule(
-                [this, &mutex, cdef_buffer, cdef_buffer_stride, deblock_buffer,
+                [this, cdef_buffer, cdef_buffer_stride, deblock_buffer,
                  deblock_buffer_stride, process_unit_width,
                  current_process_unit_height, actual_window_width,
                  plane_unit_size, num_horizontal_units, x, y, row, unit_row,
-                 plane_height, plane_width, plane, &pending_jobs,
-                 &pending_jobs_zero_condvar]() {
+                 plane_height, plane_width, plane, &pending_jobs]() {
                   ApplyLoopRestorationForOneRowInWindow<Pixel>(
                       cdef_buffer, cdef_buffer_stride, deblock_buffer,
                       deblock_buffer_stride, static_cast<Plane>(plane),
                       plane_height, plane_width, x, y, row, unit_row,
                       current_process_unit_height, process_unit_width,
                       actual_window_width, plane_unit_size,
-                      num_horizontal_units, &mutex, &pending_jobs,
-                      &pending_jobs_zero_condvar);
+                      num_horizontal_units);
+                  pending_jobs.Decrement();
                 });
           } else {
             ApplyLoopRestorationForOneRowInWindow<Pixel>(
@@ -778,22 +945,17 @@
                 deblock_buffer_stride, static_cast<Plane>(plane), plane_height,
                 plane_width, x, y, row, unit_row, current_process_unit_height,
                 process_unit_width, actual_window_width, plane_unit_size,
-                num_horizontal_units, nullptr, nullptr, nullptr);
+                num_horizontal_units);
           }
           ++job_count;
         }
-        {
-          // Wait for all jobs of current window to finish.
-          std::unique_lock<std::mutex> lock(mutex);
-          while (pending_jobs > 0) {
-            pending_jobs_zero_condvar.wait(lock);
-          }
-        }
-        // Copy |threaded_loop_restoration_buffer_| to output frame.
-        CopyPlane<Pixel>(threaded_loop_restoration_buffer_,
-                         kRestorationWindowWidth * sizeof(Pixel),
+        // Wait for all jobs of current window to finish.
+        pending_jobs.Wait();
+        // Copy |threaded_window_buffer_| to output frame.
+        CopyPlane<Pixel>(threaded_window_buffer_,
+                         window_buffer_width_ * pixel_size_,
                          actual_window_width, actual_window_height,
-                         src_buffer + y * src_stride + x * sizeof(Pixel) +
+                         src_buffer + y * src_stride + x * pixel_size_ +
                              src_unit_buffer_offset,
                          src_stride);
       }
@@ -811,7 +973,7 @@
 
 bool PostFilter::ApplyLoopRestoration() {
   if (thread_pool_ != nullptr) {
-    assert(threaded_loop_restoration_buffer_ != nullptr);
+    assert(threaded_window_buffer_ != nullptr);
 #if LIBGAV1_MAX_BITDEPTH >= 10
     if (bitdepth_ >= 10) {
       return ApplyLoopRestorationThreaded<uint16_t>();
@@ -821,10 +983,8 @@
   }
 
   if (!DoCdef()) cdef_buffer_ = source_buffer_;
-  const size_t pixel_size =
-      (bitdepth_ == 8) ? sizeof(uint8_t) : sizeof(uint16_t);
   const ptrdiff_t block_buffer_stride =
-      kRestorationProcessingUnitSizeWithBorders * pixel_size;
+      kRestorationProcessingUnitSizeWithBorders * pixel_size_;
   const int plane_process_unit_width[kMaxPlanes] = {
       kRestorationProcessingUnitSize,
       kRestorationProcessingUnitSize >> subsampling_x_,
@@ -833,8 +993,16 @@
       kRestorationProcessingUnitSize,
       kRestorationProcessingUnitSize >> subsampling_y_,
       kRestorationProcessingUnitSize >> subsampling_y_};
-  RestorationBuffer restoration_buffer;
-  PrepareRestorationBuffer(&restoration_buffer);
+  IntermediateBuffers intermediate_buffers;
+  RestorationBuffer restoration_buffer = {
+      {intermediate_buffers.box_filter.output[0],
+       intermediate_buffers.box_filter.output[1]},
+      plane_process_unit_width[kPlaneY],
+      {intermediate_buffers.box_filter.intermediate_a,
+       intermediate_buffers.box_filter.intermediate_b},
+      kRestorationProcessingUnitSizeWithBorders + kRestorationPadding,
+      intermediate_buffers.wiener,
+      kMaxSuperBlockSizeInPixels};
 
   for (int plane = kPlaneY; plane < planes_; ++plane) {
     if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
@@ -887,10 +1055,10 @@
     }
 
     int loop_restored_rows = 0;
-    const int horizontal_shift = -source_buffer_->alignment() / pixel_size;
+    const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
     const int vertical_shift = -kRestorationBorder;
     const ptrdiff_t src_unit_buffer_offset =
-        vertical_shift * src_stride + horizontal_shift * pixel_size;
+        vertical_shift * src_stride + horizontal_shift * pixel_size_;
     for (int unit_row = 0; unit_row < num_vertical_units; ++unit_row) {
       int current_unit_height = plane_unit_size;
       // Note [1]: we need to identify the entire restoration area. So the
@@ -914,11 +1082,11 @@
                 ->loop_restoration_info(static_cast<Plane>(plane), unit_id)
                 .type;
         uint8_t* src_unit_buffer =
-            src_buffer + unit_column * plane_unit_size * pixel_size;
+            src_buffer + unit_column * plane_unit_size * pixel_size_;
         uint8_t* cdef_unit_buffer =
-            cdef_buffer + unit_column * plane_unit_size * pixel_size;
+            cdef_buffer + unit_column * plane_unit_size * pixel_size_;
         uint8_t* deblock_unit_buffer =
-            deblock_buffer + unit_column * plane_unit_size * pixel_size;
+            deblock_buffer + unit_column * plane_unit_size * pixel_size_;
 
         // Take care of the last column. The max width of last column unit
         // could be 3/2 unit_size.
@@ -930,7 +1098,7 @@
         if (type == kLoopRestorationTypeNone) {
           for (int y = 0; y < current_unit_height; ++y) {
             memcpy(src_unit_buffer + src_unit_buffer_offset, cdef_unit_buffer,
-                   current_unit_width * pixel_size);
+                   current_unit_width * pixel_size_);
             src_unit_buffer += src_stride;
             cdef_unit_buffer += cdef_buffer_stride;
           }
@@ -971,14 +1139,14 @@
             // To address this, we store the restored pixels not onto the start
             // of current block on the source frame buffer, say point A,
             // but to its top by three pixels and to the left by
-            // alignment/pixel_size pixels, say point B, such that
+            // alignment/pixel_size_ pixels, say point B, such that
            // the next processing unit can fetch a 3-pixel border of
            // unrestored values. And we need to adjust the input frame buffer
            // pointer to its left and top corner, point B.
             uint8_t* const cdef_process_unit_buffer =
-                cdef_unit_buffer + column * pixel_size;
+                cdef_unit_buffer + column * pixel_size_;
             uint8_t* const deblock_process_unit_buffer =
-                deblock_unit_buffer + column * pixel_size;
+                deblock_unit_buffer + column * pixel_size_;
             const bool frame_top_border = unit_row + row == 0;
             const bool frame_bottom_border =
                 (unit_row == num_vertical_units - 1) &&
@@ -1001,8 +1169,8 @@
             restoration_func(
                 reinterpret_cast<const uint8_t*>(
                     block_buffer_ + kRestorationBorder * block_buffer_stride +
-                    kRestorationBorder * pixel_size),
-                src_unit_buffer + column * pixel_size + src_unit_buffer_offset,
+                    kRestorationBorder * pixel_size_),
+                src_unit_buffer + column * pixel_size_ + src_unit_buffer_offset,
                 restoration_info_->loop_restoration_info(
                     static_cast<Plane>(plane), unit_id),
                 block_buffer_stride, src_stride, processing_unit_width,
@@ -1043,9 +1211,7 @@
   const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
   const int row_step = 1 << subsampling_y;
   const int column_step = 1 << subsampling_x;
-  const size_t pixel_size =
-      (bitdepth_ == 8) ? sizeof(uint8_t) : sizeof(uint16_t);
-  const size_t src_step = 4 * pixel_size;
+  const size_t src_step = 4 * pixel_size_;
   const ptrdiff_t row_stride = MultiplyBy4(source_buffer_->stride(plane));
   const ptrdiff_t src_stride = source_buffer_->stride(plane);
   uint8_t* src = SetBufferOffset(source_buffer_, plane, row4x4_start,
@@ -1083,18 +1249,21 @@
     const int index = GetIndex(row);
     const int shift = GetShift(row, column);
     const int level_offset = LoopFilterMask::GetLevelOffset(row, column);
-    // TODO(chengchen): replace 0, 1, 2 to meaningful enum names.
-    // mask of current row. mask4x4 represents the vertical filter length for
+    // Mask of current row. mask4x4 means the vertical filter length for
     // the current horizontal edge is 4, and we need to apply 3-tap filtering.
     // Similarly, mask8x8 and mask16x16 indicate filter lengths of 8 and 16.
     uint64_t mask4x4 =
-        (masks_->GetTop(unit_id, plane, 0 /*Tx4x4*/, index) >> shift) &
+        (masks_->GetTop(unit_id, plane, kLoopFilterTransformSizeId4x4, index) >>
+         shift) &
         single_row_mask;
     uint64_t mask8x8 =
-        (masks_->GetTop(unit_id, plane, 1 /*Tx8x8*/, index) >> shift) &
+        (masks_->GetTop(unit_id, plane, kLoopFilterTransformSizeId8x8, index) >>
+         shift) &
         single_row_mask;
     uint64_t mask16x16 =
-        (masks_->GetTop(unit_id, plane, 2 /*Tx16x16*/, index) >> shift) &
+        (masks_->GetTop(unit_id, plane, kLoopFilterTransformSizeId16x16,
+                        index) >>
+         shift) &
         single_row_mask;
     // mask4x4, mask8x8, mask16x16 are mutually exclusive.
     assert((mask4x4 & mask8x8) == 0 && (mask4x4 & mask16x16) == 0 &&
@@ -1133,7 +1302,6 @@
           if ((mask16x16 & two_block_mask) == two_block_mask) {
             edge_count = 2;
             // Apply filtering for two edges.
-            // TODO(chengchen): actual implementation to come.
             filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
                         hev_thresh_0);
             filter_func(src_row + src_step, src_stride, outer_thresh_1,
@@ -1152,7 +1320,6 @@
           if ((mask8x8 & two_block_mask) == two_block_mask) {
             edge_count = 2;
             // Apply filtering for two edges.
-            // TODO(chengchen): actual implementation to come.
             filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
                         hev_thresh_0);
             filter_func(src_row + src_step, src_stride, outer_thresh_1,
@@ -1170,7 +1337,6 @@
           if ((mask4x4 & two_block_mask) == two_block_mask) {
             edge_count = 2;
             // Apply filtering for two edges.
-            // TODO(chengchen): actual implementation to come.
             filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
                         hev_thresh_0);
             filter_func(src_row + src_step, src_stride, outer_thresh_1,
@@ -1189,7 +1355,7 @@
       mask16x16 >>= step;
       mask >>= step;
       column_offset += step;
-      src_row += MultiplyBy4(edge_count) * pixel_size;
+      src_row += MultiplyBy4(edge_count) * pixel_size_;
     }
     src += row_stride;
   }
@@ -1245,13 +1411,19 @@
     // the current vertical edge is 4, and we need to apply 3-tap filtering.
     // Similarly, mask8x8 and mask16x16 indicate filter lengths of 8 and 16.
     uint64_t mask4x4_0 =
-        (masks_->GetLeft(unit_id, plane, 0 /*Tx4x4*/, index) >> shift) &
+        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId4x4,
+                         index) >>
+         shift) &
         single_row_mask;
     uint64_t mask8x8_0 =
-        (masks_->GetLeft(unit_id, plane, 1 /*Tx8x8*/, index) >> shift) &
+        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId8x8,
+                         index) >>
+         shift) &
         single_row_mask;
     uint64_t mask16x16_0 =
-        (masks_->GetLeft(unit_id, plane, 2 /*Tx16x16*/, index) >> shift) &
+        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId16x16,
+                         index) >>
+         shift) &
         single_row_mask;
     // mask4x4, mask8x8, mask16x16 are mutually exclusive.
     assert((mask4x4_0 & mask8x8_0) == 0 && (mask4x4_0 & mask16x16_0) == 0 &&
@@ -1260,15 +1432,18 @@
     // the corresponding SIMD function to apply filtering for two vertical
     // edges together.
     uint64_t mask4x4_1 =
-        (masks_->GetLeft(unit_id, plane, 0 /*Tx4x4*/, index_next) >>
+        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId4x4,
+                         index_next) >>
          shift_next_row) &
         single_row_mask;
     uint64_t mask8x8_1 =
-        (masks_->GetLeft(unit_id, plane, 1 /*Tx8x8*/, index_next) >>
+        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId8x8,
+                         index_next) >>
          shift_next_row) &
         single_row_mask;
     uint64_t mask16x16_1 =
-        (masks_->GetLeft(unit_id, plane, 2 /*Tx16x16*/, index_next) >>
+        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId16x16,
+                         index_next) >>
          shift_next_row) &
         single_row_mask;
     // mask4x4, mask8x8, mask16x16 are mutually exclusive.
@@ -1309,8 +1484,6 @@
           const dsp::LoopFilterFunc filter_func = dsp_.loop_filters[size][type];
           if ((mask16x16_0 & mask16x16_1 & 1) != 0) {
             // Apply dual vertical edge filtering.
-            // TODO(chengchen): actual implementation to come, probably with
-            // functions to support filtering 4 edges.
             filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
                         hev_thresh_0);
             filter_func(src_row_next, src_stride, outer_thresh_1,
@@ -1330,8 +1503,6 @@
                                                : dsp::kLoopFilterSize6;
           const dsp::LoopFilterFunc filter_func = dsp_.loop_filters[size][type];
           if ((mask8x8_0 & mask8x8_1 & 1) != 0) {
-            // TODO(chengchen): actual implementation to come, probably with
-            // functions to support filtering 4 edges.
             filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
                         hev_thresh_0);
             filter_func(src_row_next, src_stride, outer_thresh_1,
@@ -1349,8 +1520,6 @@
           const dsp::LoopFilterSize size = dsp::kLoopFilterSize4;
           const dsp::LoopFilterFunc filter_func = dsp_.loop_filters[size][type];
           if ((mask4x4_0 & mask4x4_1 & 1) != 0) {
-            // TODO(chengchen): actual implementation to come, probably with
-            // functions to support filtering 4 edges.
             filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
                         hev_thresh_0);
             filter_func(src_row_next, src_stride, outer_thresh_1,
@@ -1404,22 +1573,4 @@
   *hev_thresh = hev_thresh_[level];
 }
 
-void PostFilter::PrepareRestorationBuffer(
-    RestorationBuffer* restoration_buffer) {
-  restoration_buffer->box_filter_process_output[0] =
-      intermediate_buffers_.box_filter.output[0];
-  restoration_buffer->box_filter_process_output[1] =
-      intermediate_buffers_.box_filter.output[1];
-  restoration_buffer->box_filter_process_intermediate[0] =
-      intermediate_buffers_.box_filter.intermediate_a;
-  restoration_buffer->box_filter_process_intermediate[1] =
-      intermediate_buffers_.box_filter.intermediate_b;
-  restoration_buffer->box_filter_process_intermediate_stride =
-      kRestorationProcessingUnitSizeWithBorders + kRestorationPadding;
-  restoration_buffer->wiener_buffer = intermediate_buffers_.wiener;
-  restoration_buffer->wiener_buffer_stride = kMaxSuperBlockSizeInPixels;
-  restoration_buffer->inter_round_bits[0] = bitdepth_ == 12 ? 5 : 3;
-  restoration_buffer->inter_round_bits[1] = bitdepth_ == 12 ? 9 : 11;
-}
-
 }  // namespace libgav1
diff --git a/libgav1/src/post_filter.h b/libgav1/src/post_filter.h
index 7120405..b6c9518 100644
--- a/libgav1/src/post_filter.h
+++ b/libgav1/src/post_filter.h
@@ -3,11 +3,11 @@
 
 #include <algorithm>
 #include <array>
-#include <condition_variable>  // NOLINT (unapproved c++11 header)
+#include <atomic>
 #include <cstddef>
 #include <cstdint>
 #include <cstring>
-#include <mutex>  // NOLINT (unapproved c++11 header)
+#include <type_traits>
 
 #include "src/dsp/common.h"
 #include "src/dsp/dsp.h"
@@ -18,6 +18,7 @@
 #include "src/utils/block_parameters_holder.h"
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
+#include "src/utils/memory.h"
 #include "src/utils/threadpool.h"
 #include "src/yuv_buffer.h"
 
@@ -37,8 +38,7 @@
 // and restoration together to only filter frame buffer once.
 class PostFilter {
  public:
-  // Window width used for multi-threaded loop restoration.
-  static constexpr int kRestorationWindowWidth = 512;
+  static constexpr int kCdefLargeValue = 30000;
 
   // This class does not take ownership of the masks/restoration_info, but it
   // may change their values.
@@ -49,8 +49,7 @@
              BlockParametersHolder* block_parameters,
              YuvBuffer* const source_buffer, const dsp::Dsp* dsp,
              ThreadPool* const thread_pool,
-             uint8_t* const threaded_loop_restoration_buffer,
-             int do_post_filter_mask)
+             uint8_t* const threaded_window_buffer, int do_post_filter_mask)
       : frame_header_(frame_header),
         loop_restoration_(frame_header.loop_restoration),
         dsp_(*dsp),
@@ -65,16 +64,21 @@
         planes_(sequence_header.color_config.is_monochrome
                     ? kMaxPlanesMonochrome
                     : kMaxPlanes),
+        pixel_size_(static_cast<int>((bitdepth_ == 8) ? sizeof(uint8_t)
+                                                      : sizeof(uint16_t))),
         masks_(masks),
         cdef_index_(cdef_index),
-        threaded_loop_restoration_buffer_(threaded_loop_restoration_buffer),
+        threaded_window_buffer_(threaded_window_buffer),
         restoration_info_(restoration_info),
-        restoration_window_height_(
-            GetRestorationWindowHeight(thread_pool, frame_header)),
-        block_parameters_(block_parameters),
+        window_buffer_width_(GetWindowBufferWidth(thread_pool, frame_header)),
+        window_buffer_height_(GetWindowBufferHeight(thread_pool, frame_header)),
+        block_parameters_(*block_parameters),
         source_buffer_(source_buffer),
         do_post_filter_mask_(do_post_filter_mask),
-        thread_pool_(thread_pool) {}
+        thread_pool_(thread_pool) {
+    const int8_t zero_delta_lf[kFrameLfCount] = {};
+    ComputeDeblockFilterLevels(zero_delta_lf, deblock_filter_levels_);
+  }
 
   // non copyable/movable.
   PostFilter(const PostFilter&) = delete;
@@ -118,13 +122,15 @@
   //                Inputs are source_buffer_ and cdef_buffer_.
   //                Output is source_buffer_.
   bool ApplyFiltering();
-  bool DoCdef() const {
-    return (do_post_filter_mask_ & 0x02) != 0 &&
-           (frame_header_.cdef.bits > 0 ||
-            frame_header_.cdef.y_primary_strength[0] > 0 ||
-            frame_header_.cdef.y_secondary_strength[0] > 0 ||
-            frame_header_.cdef.uv_primary_strength[0] > 0 ||
-            frame_header_.cdef.uv_secondary_strength[0] > 0);
+  bool DoCdef() const { return DoCdef(frame_header_, do_post_filter_mask_); }
+  static bool DoCdef(const ObuFrameHeader& frame_header,
+                     int do_post_filter_mask) {
+    return (do_post_filter_mask & 0x02) != 0 &&
+           (frame_header.cdef.bits > 0 ||
+            frame_header.cdef.y_primary_strength[0] > 0 ||
+            frame_header.cdef.y_secondary_strength[0] > 0 ||
+            frame_header.cdef.uv_primary_strength[0] > 0 ||
+            frame_header.cdef.uv_secondary_strength[0] > 0);
   }
   // If the filter levels for the Y plane (0 for vertical, 1 for horizontal)
   // are all zero, the deblock filter will not be applied.
@@ -137,6 +143,17 @@
   bool DoDeblock() const {
     return DoDeblock(frame_header_, do_post_filter_mask_);
   }
+  uint8_t GetZeroDeltaDeblockFilterLevel(int segment_id, int level_index,
+                                         ReferenceFrameType type,
+                                         int mode_id) const {
+    return deblock_filter_levels_[segment_id][level_index][type][mode_id];
+  }
+  // Computes the deblock filter levels using |delta_lf| and stores them in
+  // |deblock_filter_levels|.
+  void ComputeDeblockFilterLevels(
+      const int8_t delta_lf[kFrameLfCount],
+      uint8_t deblock_filter_levels[kMaxSegments][kFrameLfCount]
+                                   [kNumReferenceFrameTypes][2]) const;
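
As a caller-side illustration of the two entry points above (the real call sites are outside this patch, so this is only a sketch), a superblock might pick its filter level as follows; delta_lf, delta_lf_is_all_zero, segment_id, level_index, reference_frame and mode_id are assumed caller-side values, and the libgav1 headers and namespace are assumed to be in scope:

uint8_t SelectDeblockLevel(const PostFilter& post_filter,
                           const int8_t delta_lf[kFrameLfCount],
                           bool delta_lf_is_all_zero, int segment_id,
                           int level_index, ReferenceFrameType reference_frame,
                           int mode_id) {
  if (delta_lf_is_all_zero) {
    // Reuse the table precomputed in the PostFilter constructor.
    return post_filter.GetZeroDeltaDeblockFilterLevel(segment_id, level_index,
                                                      reference_frame, mode_id);
  }
  // Nonzero deltas need their own table.
  uint8_t levels[kMaxSegments][kFrameLfCount][kNumReferenceFrameTypes][2];
  post_filter.ComputeDeblockFilterLevels(delta_lf, levels);
  return levels[segment_id][level_index][reference_frame][mode_id];
}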
   bool DoRestoration() const;
   // Returns true if loop restoration will be performed for the given parameters
   // and mask.
@@ -172,45 +189,82 @@
                            ptrdiff_t stride, int left, int right, int top,
                            int bottom);
 
-  // For multi-threaded loop restoration, window height is the minimum of the
-  // following two quantities:
+  static int GetWindowBufferWidth(const ThreadPool* const thread_pool,
+                                  const ObuFrameHeader& frame_header) {
+    return (thread_pool == nullptr) ? 0
+                                    : Align(frame_header.upscaled_width, 64);
+  }
+
+  // For multi-threaded cdef and loop restoration, window height is the minimum
+  // of the following two quantities:
   //  1) thread_count * 64
   //  2) frame_height rounded up to the nearest multiple of 64
-  // Where 64 is the block size for loop restoration.
-  static int GetRestorationWindowHeight(const ThreadPool* const thread_pool,
-                                        const ObuFrameHeader& frame_header) {
+  // Where 64 is the block size for cdef and loop restoration.
+  static int GetWindowBufferHeight(const ThreadPool* const thread_pool,
+                                   const ObuFrameHeader& frame_header) {
     if (thread_pool == nullptr) return 0;
     const int thread_count = 1 + thread_pool->num_threads();
     const int window_height = MultiplyBy64(thread_count);
-    const int adjusted_frame_height = (frame_header.height + 63) & ~63;
+    const int adjusted_frame_height = Align(frame_header.height, 64);
     return std::min(adjusted_frame_height, window_height);
   }
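
To make the window sizing concrete, a worked example under assumed inputs (not taken from the patch): a 1920x1080 10-bit frame decoded with 7 worker threads, so thread_count = 8 and pixel_size_ = 2.

  window_buffer_width_  = Align(1920, 64)              = 1920
  window_buffer_height_ = min(Align(1080, 64), 8 * 64) = min(1088, 512) = 512
  bytes per plane       = 1920 * 512 * 2               = 1,966,080

The threaded cdef path needs |planes_| such windows, roughly 5.6 MiB for three planes, which matches the sizing rule stated in the |threaded_window_buffer_| comment further down.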
 
  private:
+  // The type of the HorizontalDeblockFilter and VerticalDeblockFilter member
+  // functions.
+  using DeblockFilter = void (PostFilter::*)(Plane plane, int row4x4_start,
+                                             int column4x4_start, int unit_id);
+  // Represents a job for a worker thread to apply the deblock filter.
+  struct DeblockFilterJob : public Allocable {
+    int plane;
+    int row4x4;
+    int row_unit;
+  };
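
Together with the std::atomic<int> job counter taken by DeblockFilterWorker below, this struct replaces the mutex and condition-variable bookkeeping removed elsewhere in this change. A minimal standalone sketch of the claim-a-job pattern (an illustration, not the actual worker body, which is outside this hunk; DrainDeblockJobs and the run callback are hypothetical names):

#include <atomic>
#include <functional>

// Each worker repeatedly claims the next unprocessed index until the shared
// job array is exhausted; fetch_add makes the claim race-free without a lock.
template <typename Job>
void DrainDeblockJobs(const Job* jobs, int num_jobs,
                      std::atomic<int>* job_counter,
                      const std::function<void(const Job&)>& run) {
  int job_index;
  while ((job_index = job_counter->fetch_add(1, std::memory_order_relaxed)) <
         num_jobs) {
    run(jobs[job_index]);
  }
}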
   // Buffers for loop restoration intermediate results. Depending on the filter
   // type, only one member of the union is used.
   union IntermediateBuffers {
     // For Wiener filter.
     // The array |intermediate| in Section 7.17.4, the intermediate results
     // between the horizontal and vertical filters.
-    alignas(
-        32) uint16_t wiener[(kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1) *
-                            kMaxSuperBlockSizeInPixels];
+    alignas(kMaxAlignment)
+        uint16_t wiener[(kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1) *
+                        kMaxSuperBlockSizeInPixels];
     // For self-guided filter.
     struct {
       // The arrays flt0 and flt1 in Section 7.17.2, the outputs of the box
       // filter process in pass 0 and pass 1.
-      alignas(32) int32_t output[2][kMaxBoxFilterProcessOutputPixels];
+      alignas(
+          kMaxAlignment) int32_t output[2][kMaxBoxFilterProcessOutputPixels];
       // The 2d arrays A and B in Section 7.17.3, the intermediate results in
       // the box filter process. Reused for pass 0 and pass 1.
-      alignas(32) uint32_t intermediate_a[kBoxFilterProcessIntermediatePixels];
-      alignas(32) uint32_t intermediate_b[kBoxFilterProcessIntermediatePixels];
+      alignas(kMaxAlignment) uint32_t
+          intermediate_a[kBoxFilterProcessIntermediatePixels];
+      alignas(kMaxAlignment) uint32_t
+          intermediate_b[kBoxFilterProcessIntermediatePixels];
     } box_filter;
   };
 
   bool ApplyDeblockFilter();
+  void DeblockFilterWorker(const DeblockFilterJob* jobs, int num_jobs,
+                           std::atomic<int>* job_counter,
+                           DeblockFilter deblock_filter);
   bool ApplyDeblockFilterThreaded();
+
+  uint8_t* GetCdefBufferAndStride(int start_x, int start_y, int plane,
+                                  int subsampling_x, int subsampling_y,
+                                  int window_buffer_plane_size,
+                                  int vertical_shift, int horizontal_shift,
+                                  int* cdef_stride);
+  template <typename Pixel>
+  void ApplyCdefForOneUnit(uint16_t* cdef_block, int index, int block_width4x4,
+                           int block_height4x4, int row4x4_start,
+                           int column4x4_start);
+  template <typename Pixel>
+  void ApplyCdefForOneRowInWindow(int row, int column);
+  template <typename Pixel>
+  bool ApplyCdefThreaded();
   bool ApplyCdef();  // Sections 7.15 and 7.15.1.
+
   bool ApplySuperRes();
   // Note for ApplyLoopRestoration():
   // First, we must differentiate loop restoration processing unit from loop
@@ -252,8 +306,7 @@
       uint8_t* deblock_buffer, ptrdiff_t deblock_buffer_stride, Plane plane,
       int plane_height, int plane_width, int x, int y, int row, int unit_row,
       int current_process_unit_height, int process_unit_width, int window_width,
-      int plane_unit_size, int num_horizontal_units, std::mutex* mutex,
-      int* pending_jobs, std::condition_variable* pending_jobs_zero_condvar);
+      int plane_unit_size, int num_horizontal_units);
   template <typename Pixel>
   void ApplyLoopRestorationForOneUnit(
       uint8_t* cdef_buffer, ptrdiff_t cdef_buffer_stride,
@@ -273,10 +326,17 @@
                                int column4x4_start, int unit_id);
   void VerticalDeblockFilter(Plane plane, int row4x4_start, int column4x4_start,
                              int unit_id);
+  // HorizontalDeblockFilter and VerticalDeblockFilter must have the correct
+  // signature.
+  static_assert(std::is_same<decltype(&PostFilter::HorizontalDeblockFilter),
+                             DeblockFilter>::value,
+                "");
+  static_assert(std::is_same<decltype(&PostFilter::VerticalDeblockFilter),
+                             DeblockFilter>::value,
+                "");
   void InitDeblockFilterParams();  // Part of 7.14.4.
   void GetDeblockFilterParams(uint8_t level, int* outer_thresh,
                               int* inner_thresh, int* hev_thresh) const;
-  void PrepareRestorationBuffer(RestorationBuffer* restoration_buffer);
   // Applies super resolution and writes result to input_buffer.
   void FrameSuperRes(YuvBuffer* input_buffer);  // Section 7.16.
 
@@ -291,21 +351,31 @@
   const int8_t subsampling_x_;
   const int8_t subsampling_y_;
   const int8_t planes_;
+  const int pixel_size_;
   // This class does not take ownership of the masks/restoration_info, but it
   // could change their values.
   LoopFilterMask* const masks_;
   uint8_t inner_thresh_[kMaxLoopFilterValue + 1] = {};
   uint8_t outer_thresh_[kMaxLoopFilterValue + 1] = {};
   uint8_t hev_thresh_[kMaxLoopFilterValue + 1] = {};
+  // This stores the deblocking filter levels assuming that the delta is zero.
+  // This will be used by all superblocks whose delta is zero (without having to
+  // recompute them). The dimensions (in order) are: segment_id, level_index
+  // (based on plane and direction), reference_frame and mode_id.
+  uint8_t deblock_filter_levels_[kMaxSegments][kFrameLfCount]
+                                [kNumReferenceFrameTypes][2];
   const Array2D<int16_t>& cdef_index_;
-  // Pointer to the data buffer used for multi-threaded loop restoration. The
-  // size of this buffer must be at least kRestorationWindowWidth *
-  // |restoration_window_height_| * sizeof(Pixel). If |thread_pool_| is nullptr,
-  // then this buffer is not used and can be nullptr as well.
-  uint8_t* const threaded_loop_restoration_buffer_;
+  // Pointer to the data buffer used for multi-threaded cdef or loop
+  // restoration. The size of this buffer must be at least
+  // |window_buffer_width_| * |window_buffer_height_| * |pixel_size_|.
+  // Or |planes_| times that for multi-threaded cdef.
+  // If |thread_pool_| is nullptr, then this buffer is not used and can be
+  // nullptr as well.
+  uint8_t* const threaded_window_buffer_;
   LoopRestorationInfo* const restoration_info_;
-  const int restoration_window_height_;
-  BlockParametersHolder* const block_parameters_;
+  const int window_buffer_width_;
+  const int window_buffer_height_;
+  const BlockParametersHolder& block_parameters_;
   // Frame buffer to hold cdef filtered frame.
   YuvBuffer cdef_filtered_buffer_;
   // Frame buffer to hold the copy of the buffer to be upscaled,
@@ -323,7 +393,6 @@
 
   ThreadPool* const thread_pool_;
 
-  IntermediateBuffers intermediate_buffers_;
   // A small buffer to hold input source image block for loop restoration.
   // Its size is one processing unit size + borders.
   // Self-guided filter needs an extra one-pixel border.
@@ -333,6 +402,10 @@
       block_buffer_[kRestorationProcessingUnitSizeWithBorders *
                     kRestorationProcessingUnitSizeWithBorders *
                     sizeof(uint16_t)];
+  // A block buffer to hold the input that is converted to uint16_t before
+  // cdef filtering. Only used in the single-threaded case.
+  uint16_t cdef_block_[kRestorationProcessingUnitSizeWithBorders *
+                       kRestorationProcessingUnitSizeWithBorders * 3];
 
   template <int bitdepth, typename Pixel>
   friend class PostFilterSuperResTest;
@@ -409,6 +482,104 @@
   }
 }
 
+template <typename Pixel>
+void CopyRows(const Pixel* src, const ptrdiff_t src_stride,
+              const int block_width, const int unit_width,
+              const bool is_frame_top, const bool is_frame_bottom,
+              const bool is_frame_left, const bool is_frame_right,
+              const bool copy_top, const int num_rows, uint16_t* dst,
+              const ptrdiff_t dst_stride) {
+  if (is_frame_top || is_frame_bottom) {
+    if (is_frame_bottom) dst -= kCdefBorder;
+    for (int y = 0; y < num_rows; ++y) {
+      Memset(dst, PostFilter::kCdefLargeValue, unit_width + 2 * kCdefBorder);
+      dst += dst_stride;
+    }
+  } else {
+    if (copy_top) {
+      src -= kCdefBorder * src_stride;
+      dst += kCdefBorder;
+    }
+    for (int y = 0; y < num_rows; ++y) {
+      for (int x = -kCdefBorder; x < 0; ++x) {
+        dst[x] = is_frame_left ? PostFilter::kCdefLargeValue : src[x];
+      }
+      for (int x = 0; x < block_width; ++x) {
+        dst[x] = src[x];
+      }
+      for (int x = block_width; x < unit_width + kCdefBorder; ++x) {
+        dst[x] = is_frame_right ? PostFilter::kCdefLargeValue : src[x];
+      }
+      dst += dst_stride;
+      src += src_stride;
+    }
+  }
+}
+
+// This function prepares the input source block for cdef filtering.
+// The input source block contains a 12x12 block, with the inner 8x8 as the
+// desired filter region.
+// If the 12x12 block includes pixels that lie outside the frame, they are
+// padded with a large value (kCdefLargeValue).
+// This achieves the required behavior defined in section 5.11.52 of the spec.
+template <typename Pixel>
+void PrepareCdefBlock(const YuvBuffer* const source_buffer, const int planes,
+                      const int subsampling_x, const int subsampling_y,
+                      const int frame_width, const int frame_height,
+                      const int block_width4x4, const int block_height4x4,
+                      const int row_64x64, const int column_64x64,
+                      uint16_t* cdef_source, const ptrdiff_t cdef_stride) {
+  for (int plane = kPlaneY; plane < planes; ++plane) {
+    uint16_t* cdef_src =
+        cdef_source + plane * kRestorationProcessingUnitSizeWithBorders *
+                          kRestorationProcessingUnitSizeWithBorders;
+    const int plane_subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x;
+    const int plane_subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y;
+    const int start_x = MultiplyBy4(column_64x64) >> plane_subsampling_x;
+    const int start_y = MultiplyBy4(row_64x64) >> plane_subsampling_y;
+    const int plane_width =
+        RightShiftWithRounding(frame_width, plane_subsampling_x);
+    const int plane_height =
+        RightShiftWithRounding(frame_height, plane_subsampling_y);
+    const int block_width = MultiplyBy4(block_width4x4) >> plane_subsampling_x;
+    const int block_height =
+        MultiplyBy4(block_height4x4) >> plane_subsampling_y;
+    // unit_width and unit_height are the same as block_width and block_height
+    // unless the block reaches the frame boundary, where block_width < 64 or
+    // block_height < 64. unit_width and unit_height guarantee that we build
+    // blocks whose dimensions are a multiple of 8.
+    const int unit_width =
+        Align(block_width, (plane_subsampling_x > 0) ? 4 : 8);
+    const int unit_height =
+        Align(block_height, (plane_subsampling_y > 0) ? 4 : 8);
+    const bool is_frame_left = column_64x64 == 0;
+    const bool is_frame_right = start_x + block_width >= plane_width;
+    const bool is_frame_top = row_64x64 == 0;
+    const bool is_frame_bottom = start_y + block_height >= plane_height;
+    const int src_stride = source_buffer->stride(plane) / sizeof(Pixel);
+    const Pixel* src_buffer =
+        reinterpret_cast<const Pixel*>(source_buffer->data(plane)) +
+        start_y * src_stride + start_x;
+    // Copy to the top 2 rows.
+    CopyRows(src_buffer, src_stride, block_width, unit_width, is_frame_top,
+             false, is_frame_left, is_frame_right, true, kCdefBorder, cdef_src,
+             cdef_stride);
+    cdef_src += kCdefBorder * cdef_stride + kCdefBorder;
+
+    // Copy the body.
+    CopyRows(src_buffer, src_stride, block_width, unit_width, false, false,
+             is_frame_left, is_frame_right, false, block_height, cdef_src,
+             cdef_stride);
+    src_buffer += block_height * src_stride;
+    cdef_src += block_height * cdef_stride;
+
+    // Copy to bottom rows.
+    CopyRows(src_buffer, src_stride, block_width, unit_width, false,
+             is_frame_bottom, is_frame_left, is_frame_right, false,
+             kCdefBorder + unit_height - block_height, cdef_src, cdef_stride);
+  }
+}
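
A small usage sketch of the padding behavior of CopyRows() above (an editor's illustration, not part of the header). It assumes kCdefBorder is 2 and pads a single 6-pixel row whose block is treated as touching both the left and right frame edges, so no pixels outside src_row are read:

void CopyRowsPaddingExample() {
  const uint8_t src_row[6] = {10, 11, 12, 13, 14, 15};
  uint16_t dst_row[8 + 2 * 2] = {};  // unit_width + 2 * kCdefBorder columns
  CopyRows(src_row, /*src_stride=*/6, /*block_width=*/6, /*unit_width=*/8,
           /*is_frame_top=*/false, /*is_frame_bottom=*/false,
           /*is_frame_left=*/true, /*is_frame_right=*/true,
           /*copy_top=*/false, /*num_rows=*/1,
           dst_row + 2 /* skip the left border */, /*dst_stride=*/12);
  // dst_row now holds {L, L, 10, 11, 12, 13, 14, 15, L, L, L, L}, where L is
  // kCdefLargeValue marking positions that lie outside the frame.
}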
+
 }  // namespace libgav1
 
 #endif  // LIBGAV1_SRC_POST_FILTER_H_
diff --git a/libgav1/src/prediction_mask.cc b/libgav1/src/prediction_mask.cc
index b77ba0c..8a2a63f 100644
--- a/libgav1/src/prediction_mask.cc
+++ b/libgav1/src/prediction_mask.cc
@@ -17,7 +17,7 @@
 namespace libgav1 {
 namespace {
 
-const int kWedgeDirectionTypes = 16;
+constexpr int kWedgeDirectionTypes = 16;
 
 enum kWedgeDirection : uint8_t {
   kWedgeHorizontal,
@@ -77,7 +77,7 @@
                                                {kWedgeOblique117, 2, 4},
                                                {kWedgeOblique117, 6, 4}}};
 
-const uint8_t kWedgeFlipSignLookup[9][16] = {
+constexpr uint8_t kWedgeFlipSignLookup[9][16] = {
     {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock8x8
     {1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock8x16
     {1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1},  // kBlock8x32
@@ -89,19 +89,19 @@
     {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock32x32
 };
 
-const uint8_t kWedgeMasterObliqueOdd[kWedgeMaskMasterSize] = {
+constexpr uint8_t kWedgeMasterObliqueOdd[kWedgeMaskMasterSize] = {
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  2,  6,  18,
     37, 53, 60, 63, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64};
 
-const uint8_t kWedgeMasterObliqueEven[kWedgeMaskMasterSize] = {
+constexpr uint8_t kWedgeMasterObliqueEven[kWedgeMaskMasterSize] = {
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  4,  11, 27,
     46, 58, 62, 63, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64};
 
-const uint8_t kWedgeMasterVertical[kWedgeMaskMasterSize] = {
+constexpr uint8_t kWedgeMasterVertical[kWedgeMaskMasterSize] = {
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  2,  7,  21,
     43, 57, 62, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
@@ -140,10 +140,14 @@
                           const ptrdiff_t stride_1,
                           const uint16_t* prediction_2,
                           const ptrdiff_t stride_2, const int bitdepth,
-                          const bool mask_is_inverse,
-                          const int inter_post_round_bits, const int width,
+                          const bool mask_is_inverse, const int width,
                           const int height, uint8_t* mask,
                           const ptrdiff_t mask_stride) {
+#if LIBGAV1_MAX_BITDEPTH == 12
+  const int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
+#else
+  constexpr int inter_post_round_bits = 4;
+#endif
   for (int y = 0; y < height; ++y) {
     for (int x = 0; x < width; ++x) {
       const int rounding_bits = bitdepth - 8 + inter_post_round_bits;
@@ -232,10 +236,14 @@
   int block_size_index = 0;
   int wedge_masks_offset = 0;
   for (int size = kBlock8x8; size <= kBlock32x32; ++size) {
+    if (!kIsWedgeCompoundModeAllowed.Contains(size)) continue;
+
     const int width = kBlockWidthPixels[size];
     const int height = kBlockHeightPixels[size];
-    assert(width >= 8 && width <= 32);
-    if (height < 8 || height > 32) continue;
+    assert(width >= 8);
+    assert(width <= 32);
+    assert(height >= 8);
+    assert(height <= 32);
 
     const auto block_size = static_cast<BlockSize>(size);
     for (int index = 0; index < kWedgeDirectionTypes; ++index) {
@@ -280,12 +288,10 @@
                         const ptrdiff_t stride_1,
                         const uint16_t* const prediction_2,
                         const ptrdiff_t stride_2, const bool mask_is_inverse,
-                        const int inter_post_round_bits, const int width,
-                        const int height, const int bitdepth,
+                        const int width, const int height, const int bitdepth,
                         uint8_t* const mask, const ptrdiff_t mask_stride) {
   DifferenceWeightMask(prediction_1, stride_1, prediction_2, stride_2, bitdepth,
-                       mask_is_inverse, inter_post_round_bits, width, height,
-                       mask, mask_stride);
+                       mask_is_inverse, width, height, mask, mask_stride);
 }
 
 void GenerateInterIntraMask(const int mode, const int width, const int height,
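With this change, inter_post_round_bits is no longer threaded through the call chain; it is derived from the bitdepth inside the mask generation (2 for 12-bit input, otherwise 4). A small standalone sketch of the resulting rounding shift per bitdepth, following the values in the hunk above:

#include <cstdio>
#include <initializer_list>

// Mirrors the derivation in DifferenceWeightMask() above: 12-bit streams use
// 2 post-round bits, everything else uses 4, and the final rounding shift
// also grows with the bitdepth.
int RoundingBits(int bitdepth) {
  const int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
  return bitdepth - 8 + inter_post_round_bits;
}

int main() {
  for (const int bitdepth : {8, 10, 12}) {
    // Prints 4 for 8-bit, 6 for 10-bit and 6 for 12-bit.
    std::printf("bitdepth %d -> rounding_bits %d\n", bitdepth,
                RoundingBits(bitdepth));
  }
  return 0;
}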
diff --git a/libgav1/src/prediction_mask.h b/libgav1/src/prediction_mask.h
index 86b083d..ce6c5cc 100644
--- a/libgav1/src/prediction_mask.h
+++ b/libgav1/src/prediction_mask.h
@@ -4,10 +4,17 @@
 #include <cstddef>
 #include <cstdint>
 
+#include "src/utils/bit_mask_set.h"
 #include "src/utils/types.h"
 
 namespace libgav1 {
 
+constexpr BitMaskSet kIsWedgeCompoundModeAllowed(kBlock8x8, kBlock8x16,
+                                                 kBlock8x32, kBlock16x8,
+                                                 kBlock16x16, kBlock16x32,
+                                                 kBlock32x8, kBlock32x16,
+                                                 kBlock32x32);
+
 // This function generates wedge masks. It should be called only once for the
 // decoder. If the video is key frame only, we don't have to call this
 // function.
@@ -18,9 +25,8 @@
 // 7.11.3.12.
 void GenerateWeightMask(const uint16_t* prediction_1, ptrdiff_t stride_1,
                         const uint16_t* prediction_2, ptrdiff_t stride_2,
-                        bool mask_is_inverse, int inter_post_round_bits,
-                        int width, int height, int bitdepth, uint8_t* mask,
-                        ptrdiff_t mask_stride);
+                        bool mask_is_inverse, int width, int height,
+                        int bitdepth, uint8_t* mask, ptrdiff_t mask_stride);
 
 // 7.11.3.13.
 void GenerateInterIntraMask(int mode, int width, int height, uint8_t* mask,
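kIsWedgeCompoundModeAllowed turns the earlier width/height range checks into a single set-membership test. A rough sketch of how such a bit-mask set works, using a simplified stand-in class and made-up enum values rather than the real BitMaskSet (src/utils/bit_mask_set.h) and BlockSize constants:

#include <cstdint>
#include <cstdio>

// Simplified stand-in for a bit-mask set: each allowed value sets one bit and
// Contains() is a single mask test. The enum values below are illustrative.
enum FakeBlockSize : int {
  kFakeBlock8x8 = 3,
  kFakeBlock8x16 = 4,
  kFakeBlock64x64 = 12
};

class SimpleBitMaskSet {
 public:
  explicit constexpr SimpleBitMaskSet(uint32_t mask) : mask_(mask) {}
  constexpr bool Contains(int value) const {
    return ((mask_ >> value) & 1) != 0;
  }

 private:
  uint32_t mask_;
};

constexpr SimpleBitMaskSet kAllowed((1u << kFakeBlock8x8) |
                                    (1u << kFakeBlock8x16));

int main() {
  std::printf("8x8 allowed: %d\n",
              static_cast<int>(kAllowed.Contains(kFakeBlock8x8)));    // 1
  std::printf("64x64 allowed: %d\n",
              static_cast<int>(kAllowed.Contains(kFakeBlock64x64)));  // 0
  return 0;
}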
diff --git a/libgav1/src/quantizer.h b/libgav1/src/quantizer.h
index 5478167..8e87ab0 100644
--- a/libgav1/src/quantizer.h
+++ b/libgav1/src/quantizer.h
@@ -10,7 +10,8 @@
 
 // Stores the quantization parameters of Section 5.9.12.
 struct QuantizerParameters {
-  int16_t base_index;
+  // base_index is in the range [0, 255].
+  uint8_t base_index;
   int8_t delta_dc[kMaxPlanes];
   // delta_ac[kPlaneY] is always 0.
   int8_t delta_ac[kMaxPlanes];
diff --git a/libgav1/src/reconstruction.cc b/libgav1/src/reconstruction.cc
index a113557..83968ed 100644
--- a/libgav1/src/reconstruction.cc
+++ b/libgav1/src/reconstruction.cc
@@ -38,9 +38,9 @@
 
 template <typename Residual, typename Pixel>
 void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
-                 TransformSize tx_size, int8_t bitdepth, bool lossless,
-                 Residual* const buffer, int start_x, int start_y,
-                 Array2DView<Pixel>* frame, int16_t non_zero_coeff_count) {
+                 TransformSize tx_size, bool lossless, Residual* const buffer,
+                 int start_x, int start_y, Array2DView<Pixel>* frame,
+                 int16_t non_zero_coeff_count) {
   static_assert(sizeof(Residual) == 2 || sizeof(Residual) == 4, "");
   const int tx_width_log2 = kTransformWidthLog2[tx_size];
   const int tx_height_log2 = kTransformHeightLog2[tx_size];
@@ -54,8 +54,8 @@
       dsp.inverse_transforms[row_transform_size][row_transform];
   assert(row_transform_func != nullptr);
 
-  row_transform_func(tx_type, tx_size, bitdepth, buffer, start_x, start_y,
-                     frame, /*is_row=*/true, non_zero_coeff_count);
+  row_transform_func(tx_type, tx_size, buffer, start_x, start_y, frame,
+                     /*is_row=*/true, non_zero_coeff_count);
 
   // Column transform.
   const dsp::TransformSize1D column_transform_size =
@@ -66,19 +66,18 @@
       dsp.inverse_transforms[column_transform_size][column_transform];
   assert(column_transform_func != nullptr);
 
-  column_transform_func(tx_type, tx_size, bitdepth, buffer, start_x, start_y,
-                        frame, /*is_row=*/false, non_zero_coeff_count);
+  column_transform_func(tx_type, tx_size, buffer, start_x, start_y, frame,
+                        /*is_row=*/false, non_zero_coeff_count);
 }
 
 template void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
-                          TransformSize tx_size, int8_t bitdepth, bool lossless,
-                          int16_t* buffer, int start_x, int start_y,
-                          Array2DView<uint8_t>* frame,
+                          TransformSize tx_size, bool lossless, int16_t* buffer,
+                          int start_x, int start_y, Array2DView<uint8_t>* frame,
                           int16_t non_zero_coeff_count);
 #if LIBGAV1_MAX_BITDEPTH >= 10
 template void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
-                          TransformSize tx_size, int8_t bitdepth, bool lossless,
-                          int32_t* buffer, int start_x, int start_y,
+                          TransformSize tx_size, bool lossless, int32_t* buffer,
+                          int start_x, int start_y,
                           Array2DView<uint16_t>* frame,
                           int16_t non_zero_coeff_count);
 #endif
diff --git a/libgav1/src/reconstruction.h b/libgav1/src/reconstruction.h
index 8b4849e..e73f567 100644
--- a/libgav1/src/reconstruction.h
+++ b/libgav1/src/reconstruction.h
@@ -17,20 +17,20 @@
 // transform block size |tx_size| starting at position |start_x| and |start_y|.
 template <typename Residual, typename Pixel>
 void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
-                 TransformSize tx_size, int8_t bitdepth, bool lossless,
-                 Residual* buffer, int start_x, int start_y,
-                 Array2DView<Pixel>* frame, int16_t non_zero_coeff_count);
+                 TransformSize tx_size, bool lossless, Residual* buffer,
+                 int start_x, int start_y, Array2DView<Pixel>* frame,
+                 int16_t non_zero_coeff_count);
 
 extern template void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
-                                 TransformSize tx_size, int8_t bitdepth,
-                                 bool lossless, int16_t* buffer, int start_x,
-                                 int start_y, Array2DView<uint8_t>* frame,
+                                 TransformSize tx_size, bool lossless,
+                                 int16_t* buffer, int start_x, int start_y,
+                                 Array2DView<uint8_t>* frame,
                                  int16_t non_zero_coeff_count);
 #if LIBGAV1_MAX_BITDEPTH >= 10
 extern template void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
-                                 TransformSize tx_size, int8_t bitdepth,
-                                 bool lossless, int32_t* buffer, int start_x,
-                                 int start_y, Array2DView<uint16_t>* frame,
+                                 TransformSize tx_size, bool lossless,
+                                 int32_t* buffer, int start_x, int start_y,
+                                 Array2DView<uint16_t>* frame,
                                  int16_t non_zero_coeff_count);
 #endif
 
diff --git a/libgav1/src/residual_buffer_pool.cc b/libgav1/src/residual_buffer_pool.cc
index 3e16a4b..c21bb9a 100644
--- a/libgav1/src/residual_buffer_pool.cc
+++ b/libgav1/src/residual_buffer_pool.cc
@@ -1,7 +1,6 @@
 #include "src/residual_buffer_pool.h"
 
-#include <algorithm>
-#include <new>
+#include <mutex>  // NOLINT (unapproved c++11 header)
 #include <utility>
 
 namespace libgav1 {
@@ -37,6 +36,36 @@
 
 }  // namespace
 
+ResidualBufferStack::~ResidualBufferStack() {
+  while (top_ != nullptr) {
+    ResidualBuffer* top = top_;
+    top_ = top_->next_;
+    delete top;
+  }
+}
+
+void ResidualBufferStack::Push(std::unique_ptr<ResidualBuffer> buffer) {
+  buffer->next_ = top_;
+  top_ = buffer.release();
+  ++num_buffers_;
+}
+
+std::unique_ptr<ResidualBuffer> ResidualBufferStack::Pop() {
+  std::unique_ptr<ResidualBuffer> top;
+  if (top_ != nullptr) {
+    top.reset(top_);
+    top_ = top_->next_;
+    top->next_ = nullptr;
+    --num_buffers_;
+  }
+  return top;
+}
+
+void ResidualBufferStack::Swap(ResidualBufferStack* other) {
+  std::swap(top_, other->top_);
+  std::swap(num_buffers_, other->num_buffers_);
+}
+
 ResidualBufferPool::ResidualBufferPool(bool use_128x128_superblock,
                                        int subsampling_x, int subsampling_y,
                                        size_t residual_size)
@@ -61,36 +90,39 @@
   queue_size_ = queue_size;
   // The existing buffers (if any) are no longer valid since the buffer size or
   // the queue size has changed. Clear the stack.
-  std::lock_guard<std::mutex> lock(mutex_);
-  while (!buffers_.empty()) {
-    buffers_.pop();
+  ResidualBufferStack buffers;
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    // Move the buffers in the stack to the local variable |buffers| and clear
+    // the stack.
+    buffers.Swap(&buffers_);
+    // Release mutex_ before freeing the buffers.
   }
+  // As the local variable |buffers| goes out of scope, its destructor frees
+  // the buffers that were in the stack.
 }
 
 std::unique_ptr<ResidualBuffer> ResidualBufferPool::Get() {
   std::unique_ptr<ResidualBuffer> buffer = nullptr;
   {
     std::lock_guard<std::mutex> lock(mutex_);
-    if (!buffers_.empty()) {
-      buffer = std::move(buffers_.top());
-      buffers_.pop();
-    }
+    buffer = buffers_.Pop();
   }
   if (buffer == nullptr) {
-    buffer.reset(new (std::nothrow) ResidualBuffer(buffer_size_, queue_size_));
+    buffer = ResidualBuffer::Create(buffer_size_, queue_size_);
   }
   return buffer;
 }
 
 void ResidualBufferPool::Release(std::unique_ptr<ResidualBuffer> buffer) {
-  buffer->transform_parameters.Reset();
+  buffer->transform_parameters()->Reset();
   std::lock_guard<std::mutex> lock(mutex_);
-  buffers_.push(std::move(buffer));
+  buffers_.Push(std::move(buffer));
 }
 
 size_t ResidualBufferPool::Size() const {
   std::lock_guard<std::mutex> lock(mutex_);
-  return buffers_.size();
+  return buffers_.Size();
 }
 
 }  // namespace libgav1
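SetBufferSize() now empties the pool by swapping the stack into a local while holding the mutex and letting the local's destructor free the buffers after the lock is released. A minimal sketch of that pattern, with a std::vector standing in for ResidualBufferStack:

#include <mutex>
#include <vector>

// Sketch of the "swap under the lock, free outside" pattern used by
// ResidualBufferPool::SetBufferSize() above.
class Pool {
 public:
  void Add(int value) {
    std::lock_guard<std::mutex> lock(mutex_);
    buffers_.push_back(value);
  }

  void Clear() {
    std::vector<int> stale;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      // Move the guarded contents into a local so that the (potentially slow)
      // deallocation happens after the mutex is released.
      stale.swap(buffers_);
    }
    // |stale| is destroyed here, outside the critical section.
  }

 private:
  std::mutex mutex_;
  std::vector<int> buffers_;
};

int main() {
  Pool pool;
  pool.Add(1);
  pool.Add(2);
  pool.Clear();
  return 0;
}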
diff --git a/libgav1/src/residual_buffer_pool.h b/libgav1/src/residual_buffer_pool.h
index a2751bd..ab9c1e6 100644
--- a/libgav1/src/residual_buffer_pool.h
+++ b/libgav1/src/residual_buffer_pool.h
@@ -6,7 +6,6 @@
 #include <memory>
 #include <mutex>  // NOLINT (unapproved c++11 header)
 #include <new>
-#include <stack>
 
 #include "src/utils/common.h"
 #include "src/utils/compiler_attributes.h"
@@ -21,17 +20,21 @@
 // boundary checks since we always push data into the queue before accessing it.
 class TransformParameterQueue {
  public:
-  explicit TransformParameterQueue(int max_size) : max_size_(max_size) {
-    // No initialization is necessary since the data will be always written to
-    // before being read.
-    non_zero_coeff_count_.reset(new (std::nothrow) int16_t[max_size_]);
-    tx_type_.reset(new (std::nothrow) TransformType[max_size_]);
-  }
+  TransformParameterQueue() = default;
 
   // Move only.
   TransformParameterQueue(TransformParameterQueue&& other) = default;
   TransformParameterQueue& operator=(TransformParameterQueue&& other) = default;
 
+  LIBGAV1_MUST_USE_RESULT bool Init(int max_size) {
+    max_size_ = max_size;
+    // No initialization is necessary since the data will always be written to
+    // before being read.
+    non_zero_coeff_count_.reset(new (std::nothrow) int16_t[max_size_]);
+    tx_type_.reset(new (std::nothrow) TransformType[max_size_]);
+    return non_zero_coeff_count_ != nullptr && tx_type_ != nullptr;
+  }
+
   // Adds the |non_zero_coeff_count| and the |tx_type| to the back of the queue.
   void Push(int16_t non_zero_coeff_count, TransformType tx_type) {
     assert(back_ < max_size_);
@@ -65,22 +68,31 @@
   }
 
   // Used only in the tests. Returns the number of elements in the queue.
-  int Size() { return back_ - front_; }
+  int Size() const { return back_ - front_; }
 
  private:
-  const int max_size_;
+  int max_size_ = 0;
   std::unique_ptr<int16_t[]> non_zero_coeff_count_;
   std::unique_ptr<TransformType[]> tx_type_;
   int front_ = 0;
   int back_ = 0;
 };
 
-// This struct is used for parsing and decoding a superblock. Members of this
-// struct are populated in the "parse" step and consumed in the "decode" step.
-struct ResidualBuffer : public Allocable {
-  ResidualBuffer(size_t buffer_size, int queue_size)
-      : transform_parameters(queue_size) {
-    buffer = MakeAlignedUniquePtr<uint8_t>(32, buffer_size);
+// This class is used for parsing and decoding a superblock. Members of this
+// class are populated in the "parse" step and consumed in the "decode" step.
+class ResidualBuffer : public Allocable {
+ public:
+  static std::unique_ptr<ResidualBuffer> Create(size_t buffer_size,
+                                                int queue_size) {
+    std::unique_ptr<ResidualBuffer> buffer(new (std::nothrow) ResidualBuffer);
+    if (buffer != nullptr) {
+      buffer->buffer_ = MakeAlignedUniquePtr<uint8_t>(32, buffer_size);
+      if (buffer->buffer_ == nullptr ||
+          !buffer->transform_parameters_.Init(queue_size)) {
+        buffer = nullptr;
+      }
+    }
+    return buffer;
   }
 
   // Move only.
@@ -88,9 +100,52 @@
   ResidualBuffer& operator=(ResidualBuffer&& other) = default;
 
   // Buffer used to store the residual values.
-  AlignedUniquePtr<uint8_t> buffer;
+  uint8_t* buffer() { return buffer_.get(); }
   // Queue used to store the transform parameters.
-  TransformParameterQueue transform_parameters;
+  TransformParameterQueue* transform_parameters() {
+    return &transform_parameters_;
+  }
+
+ private:
+  friend class ResidualBufferStack;
+
+  ResidualBuffer() = default;
+
+  AlignedUniquePtr<uint8_t> buffer_;
+  TransformParameterQueue transform_parameters_;
+  // Used by ResidualBufferStack to form a chain of ResidualBuffers.
+  ResidualBuffer* next_ = nullptr;
+};
+
+// A LIFO stack of ResidualBuffers. Owns the buffers in the stack.
+class ResidualBufferStack {
+ public:
+  ResidualBufferStack() = default;
+
+  // Not copyable or movable.
+  ResidualBufferStack(const ResidualBufferStack&) = delete;
+  ResidualBufferStack& operator=(const ResidualBufferStack&) = delete;
+
+  ~ResidualBufferStack();
+
+  // Pushes |buffer| to the top of the stack.
+  void Push(std::unique_ptr<ResidualBuffer> buffer);
+
+  // If the stack is non-empty, returns the buffer at the top of the stack and
+  // removes it from the stack. If the stack is empty, returns nullptr.
+  std::unique_ptr<ResidualBuffer> Pop();
+
+  // Swaps the contents of this stack and |other|.
+  void Swap(ResidualBufferStack* other);
+
+  // Returns the number of buffers in the stack.
+  size_t Size() const { return num_buffers_; }
+
+ private:
+  // A singly-linked list of ResidualBuffers, chained together using the next_
+  // field of ResidualBuffer.
+  ResidualBuffer* top_ = nullptr;
+  size_t num_buffers_ = 0;
 };
 
 // Utility class used to manage the residual buffers (and the transform
@@ -122,8 +177,7 @@
 
  private:
   mutable std::mutex mutex_;
-  std::stack<std::unique_ptr<ResidualBuffer>> buffers_
-      LIBGAV1_GUARDED_BY(mutex_);
+  ResidualBufferStack buffers_ LIBGAV1_GUARDED_BY(mutex_);
   size_t buffer_size_;
   int queue_size_;
 };
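ResidualBufferStack replaces the previous std::stack with an intrusive singly-linked LIFO: each buffer carries its own next pointer, so Push() and Pop() are just pointer swaps and the stack itself never allocates. A simplified sketch with a stand-in Node type instead of ResidualBuffer:

#include <cassert>
#include <memory>

// Minimal intrusive LIFO in the spirit of ResidualBufferStack above.
struct Node {
  int id = 0;
  Node* next = nullptr;
};

class Stack {
 public:
  ~Stack() {
    while (top_ != nullptr) {
      Node* const top = top_;
      top_ = top->next;
      delete top;
    }
  }

  void Push(std::unique_ptr<Node> node) {
    node->next = top_;
    top_ = node.release();
  }

  std::unique_ptr<Node> Pop() {
    std::unique_ptr<Node> top;
    if (top_ != nullptr) {
      top.reset(top_);
      top_ = top_->next;
      top->next = nullptr;
    }
    return top;
  }

 private:
  Node* top_ = nullptr;
};

int main() {
  Stack stack;
  for (int i = 1; i <= 3; ++i) {
    std::unique_ptr<Node> node(new Node);
    node->id = i;
    stack.Push(std::move(node));
  }
  // LIFO order: the most recently pushed buffer comes back first. The node
  // left in the stack is freed by the Stack destructor.
  assert(stack.Pop()->id == 3);
  assert(stack.Pop()->id == 2);
  return 0;
}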
diff --git a/libgav1/src/utils/scan.cc b/libgav1/src/scan_tables.inc
similarity index 92%
rename from libgav1/src/utils/scan.cc
rename to libgav1/src/scan_tables.inc
index 3cc096d..3caacde 100644
--- a/libgav1/src/utils/scan.cc
+++ b/libgav1/src/scan_tables.inc
@@ -1,10 +1,4 @@
-#include "src/utils/scan.h"
-
-#include "src/utils/common.h"
-#include "src/utils/constants.h"
-
-namespace libgav1 {
-namespace {
+// This file contains all the scan order tables.
 
 const uint16_t kDefaultScan4x4[16] = {0, 1,  4,  8,  5, 2,  3,  6,
                                       9, 12, 13, 10, 7, 11, 14, 15};
@@ -409,38 +403,24 @@
     95,  251, 220, 189, 158, 127, 252, 221, 190, 159, 253, 222, 191, 254, 223,
     255};
 
-const uint16_t* kDefaultScan[kNumTransformSizes] = {
-    kDefaultScan4x4,   kDefaultScan4x8,   kDefaultScan4x16,  kDefaultScan8x4,
-    kDefaultScan8x8,   kDefaultScan8x16,  kDefaultScan8x32,  kDefaultScan16x4,
-    kDefaultScan16x8,  kDefaultScan16x16, kDefaultScan16x32, kDefaultScan32x32,
-    kDefaultScan32x8,  kDefaultScan32x16, kDefaultScan32x32, kDefaultScan32x32,
-    kDefaultScan32x32, kDefaultScan32x32, kDefaultScan32x32};
-
-const uint16_t* kColumnScan[kNumTransformSizes] = {
-    kColumnScan4x4,  kColumnScan4x8,   kColumnScan4x16, kColumnScan8x4,
-    kColumnScan8x8,  kColumnScan8x16,  kColumnScan16x4, kColumnScan16x4,
-    kColumnScan16x8, kColumnScan16x16, kColumnScan16x4, kColumnScan16x4,
-    kColumnScan16x4, kColumnScan16x4,  kColumnScan16x4, kColumnScan16x4,
-    kColumnScan16x4, kColumnScan16x4,  kColumnScan16x4};
-
-const uint16_t* kRowScan[kNumTransformSizes] = {
-    kRowScan4x4,  kRowScan4x8,  kRowScan4x16, kRowScan8x4,  kRowScan8x8,
-    kRowScan8x16, kRowScan16x4, kRowScan16x4, kRowScan16x8, kRowScan16x16,
-    kRowScan16x4, kRowScan16x4, kRowScan16x4, kRowScan16x4, kRowScan16x4,
-    kRowScan16x4, kRowScan16x4, kRowScan16x4, kRowScan16x4};
-
-}  // namespace
-
-const uint16_t* GetScan(TransformSize tx_size, TransformType tx_type) {
-  if (tx_size == kTransformSize16x64) return kDefaultScan16x32;
-  if (tx_size == kTransformSize64x16) return kDefaultScan32x16;
-  if (kTransformSizeSquareMax[tx_size] == kTransformSize64x64) {
-    return kDefaultScan32x32;
-  }
-  const TransformClass tx_class = GetTransformClass(tx_type);
-  if (tx_class == kTransformClassVertical) return kRowScan[tx_size];
-  if (tx_class == kTransformClassHorizontal) return kColumnScan[tx_size];
-  return kDefaultScan[tx_size];
-}
-
-}  // namespace libgav1
+// 5.11.41 (implemented as a simple lookup keyed by transform class and
+// transform size).
+const uint16_t* kScan[3][kNumTransformSizes] = {
+    // kTransformClass2D
+    {kDefaultScan4x4, kDefaultScan4x8, kDefaultScan4x16, kDefaultScan8x4,
+     kDefaultScan8x8, kDefaultScan8x16, kDefaultScan8x32, kDefaultScan16x4,
+     kDefaultScan16x8, kDefaultScan16x16, kDefaultScan16x32, kDefaultScan16x32,
+     kDefaultScan32x8, kDefaultScan32x16, kDefaultScan32x32, kDefaultScan32x32,
+     kDefaultScan32x16, kDefaultScan32x32, kDefaultScan32x32},
+    // kTransformClassHorizontal
+    {kColumnScan4x4, kColumnScan4x8, kColumnScan4x16, kColumnScan8x4,
+     kColumnScan8x8, kColumnScan8x16, kColumnScan16x4, kColumnScan16x4,
+     kColumnScan16x8, kColumnScan16x16, kColumnScan16x4, kDefaultScan16x32,
+     kColumnScan16x4, kColumnScan16x4, kColumnScan16x4, kDefaultScan32x32,
+     kDefaultScan32x16, kDefaultScan32x32, kDefaultScan32x32},
+    // kTransformClassVertical
+    {kRowScan4x4, kRowScan4x8, kRowScan4x16, kRowScan8x4, kRowScan8x8,
+     kRowScan8x16, kRowScan16x4, kRowScan16x4, kRowScan16x8, kRowScan16x16,
+     kRowScan16x4, kDefaultScan16x32, kRowScan16x4, kRowScan16x4, kRowScan16x4,
+     kDefaultScan32x32, kDefaultScan32x16, kDefaultScan32x32,
+     kDefaultScan32x32}};
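The kScan table replaces the GetScan() helper with a direct two-dimensional lookup keyed by transform class and transform size. A standalone sketch of that style of lookup; the column/row tables and size constants below are toy data, not the real scan orders:

#include <cstdint>
#include <cstdio>

enum ToyTransformClass : int {
  kToyClass2D = 0,
  kToyClassHorizontal = 1,
  kToyClassVertical = 2
};
enum ToyTransformSize : int { kToySize4x4 = 0, kToyNumSizes = 1 };

const uint16_t kToyDefaultScan4x4[16] = {0, 1,  4,  8,  5, 2,  3,  6,
                                         9, 12, 13, 10, 7, 11, 14, 15};
const uint16_t kToyColumnScan4x4[16] = {0, 4, 8,  12, 1, 5, 9,  13,
                                        2, 6, 10, 14, 3, 7, 11, 15};
const uint16_t kToyRowScan4x4[16] = {0, 1, 2,  3,  4,  5,  6,  7,
                                     8, 9, 10, 11, 12, 13, 14, 15};

// One row per transform class, one column per transform size.
const uint16_t* const kToyScan[3][kToyNumSizes] = {
    {kToyDefaultScan4x4}, {kToyColumnScan4x4}, {kToyRowScan4x4}};

int main() {
  const uint16_t* const scan = kToyScan[kToyClassHorizontal][kToySize4x4];
  std::printf("first three scan positions: %d %d %d\n", scan[0], scan[1],
              scan[2]);
  return 0;
}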
diff --git a/libgav1/src/symbol_decoder_context.cc b/libgav1/src/symbol_decoder_context.cc
index aa65be9..c7bef4a 100644
--- a/libgav1/src/symbol_decoder_context.cc
+++ b/libgav1/src/symbol_decoder_context.cc
@@ -10,7 +10,7 @@
 // Import all the constants in the anonymous namespace.
 #include "src/symbol_decoder_context_cdfs.inc"
 
-inline uint8_t GetQuantizerContext(int base_quantizer_index) {
+uint8_t GetQuantizerContext(int base_quantizer_index) {
   if (base_quantizer_index <= 20) return 0;
   if (base_quantizer_index <= 60) return 1;
   if (base_quantizer_index <= 120) return 2;
@@ -84,11 +84,6 @@
 
 }  // namespace
 
-const int kEobPtSymbolCount[7] = {kEobPt16SymbolCount,  kEobPt32SymbolCount,
-                                  kEobPt64SymbolCount,  kEobPt128SymbolCount,
-                                  kEobPt256SymbolCount, kEobPt512SymbolCount,
-                                  kEobPt1024SymbolCount};
-
 #define CDF_COPY(source, destination) \
   memcpy(destination, source, sizeof(source))
 
diff --git a/libgav1/src/symbol_decoder_context.h b/libgav1/src/symbol_decoder_context.h
index 6a44a3f..63f37ce 100644
--- a/libgav1/src/symbol_decoder_context.h
+++ b/libgav1/src/symbol_decoder_context.h
@@ -15,8 +15,7 @@
   kSkipContexts = 3,
   kSkipModeContexts = 3,
   kBooleanFieldCdfSize = 3,
-  kDeltaQSymbolCount = 4,
-  kDeltaLfSymbolCount = 4,
+  kDeltaSymbolCount = 4,  // Used for both delta_q and delta_lf.
   kIntraModeContexts = 5,
   kYModeContexts = 4,
   kAngleDeltaSymbolCount = 2 * kMaxAngleDelta + 1,
@@ -73,8 +72,6 @@
   kNumMvComponents = 2,
 };  // anonymous enum
 
-extern const int kEobPtSymbolCount[7];
-
 struct SymbolDecoderContext {
   SymbolDecoderContext() = default;
   explicit SymbolDecoderContext(int base_quantizer_index) {
@@ -104,9 +101,9 @@
                                        [kBooleanFieldCdfSize];
   uint16_t skip_cdf[kSkipContexts][kBooleanFieldCdfSize];
   uint16_t skip_mode_cdf[kSkipModeContexts][kBooleanFieldCdfSize];
-  uint16_t delta_q_cdf[kDeltaQSymbolCount + 1];
-  uint16_t delta_lf_cdf[kDeltaLfSymbolCount + 1];
-  uint16_t delta_lf_multi_cdf[kFrameLfCount][kDeltaLfSymbolCount + 1];
+  uint16_t delta_q_cdf[kDeltaSymbolCount + 1];
+  uint16_t delta_lf_cdf[kDeltaSymbolCount + 1];
+  uint16_t delta_lf_multi_cdf[kFrameLfCount][kDeltaSymbolCount + 1];
   uint16_t intra_block_copy_cdf[kBooleanFieldCdfSize];
   uint16_t intra_frame_y_mode_cdf[kIntraModeContexts][kIntraModeContexts]
                                  [kIntraPredictionModesY + 1];
diff --git a/libgav1/src/symbol_decoder_context_cdfs.inc b/libgav1/src/symbol_decoder_context_cdfs.inc
index 6c57ce8..9b720ff 100644
--- a/libgav1/src/symbol_decoder_context_cdfs.inc
+++ b/libgav1/src/symbol_decoder_context_cdfs.inc
@@ -47,7 +47,7 @@
     {147, 0, 0}, {12060, 0, 0}, {24641, 0, 0}};
 
 // This constant is also used for DeltaLf and DeltaLfMulti.
-const uint16_t kDefaultDeltaQCdf[kDeltaQSymbolCount + 1] = {4608, 648, 91, 0,
+const uint16_t kDefaultDeltaQCdf[kDeltaSymbolCount + 1] = {4608, 648, 91, 0,
                                                             0};
 
 const uint16_t kDefaultIntraBlockCopyCdf[kBooleanFieldCdfSize] = {2237, 0, 0};
diff --git a/libgav1/src/symbol_visibility.h b/libgav1/src/symbol_visibility.h
new file mode 100644
index 0000000..5c79687
--- /dev/null
+++ b/libgav1/src/symbol_visibility.h
@@ -0,0 +1,72 @@
+#ifndef LIBGAV1_SRC_SYMBOL_VISIBILITY_H_
+#define LIBGAV1_SRC_SYMBOL_VISIBILITY_H_
+
+// This module defines the LIBGAV1_PUBLIC macro. LIBGAV1_PUBLIC, when combined
+// with the flags -fvisibility=hidden and -fvisibility-inlines-hidden, restricts
+// symbol availability when users use the shared object form of libgav1. The
+// intent is to prevent exposure of libgav1 internals to users of the library,
+// and to avoid ABI compatibility problems that changes to libgav1 internals
+// would cause for users of the libgav1 shared object.
+//
+// Examples:
+//
+// This form makes a class and all of its members part of the public API:
+//
+// class LIBGAV1_PUBLIC A {
+//  public:
+//   A();
+//   ~A();
+//   void Foo();
+//   int Bar();
+// };
+//
+// A::A(), A::~A(), A::Foo(), and A::Bar() are all available to code linking to
+// the shared object when this form is used.
+//
+// This form exposes a single class method as part of the public API:
+//
+// class B {
+//  public:
+//   B();
+//   ~B();
+//   LIBGAV1_PUBLIC int Foo();
+// };
+//
+// In this example, only B::Foo() is available to the user of the shared object.
+//
+// Non-class member functions can also be exposed individually:
+//
+// LIBGAV1_PUBLIC void Bar();
+//
+// In this example Bar() would be available to users of the shared object.
+//
+// Much of the above information and more can be found at
+// https://gcc.gnu.org/wiki/Visibility
+
+#if !defined(LIBGAV1_PUBLIC)
+#if defined(_WIN32)
+#if defined(LIBGAV1_BUILDING_DLL) && LIBGAV1_BUILDING_DLL
+#if defined(__GNUC__)
+#define LIBGAV1_PUBLIC __attribute__((dllexport))
+#else
+#define LIBGAV1_PUBLIC __declspec(dllexport)
+#endif  // defined(__GNUC__)
+#elif defined(LIBGAV1_BUILDING_DLL)
+#ifdef __GNUC__
+#define LIBGAV1_PUBLIC __attribute__((dllimport))
+#else
+#define LIBGAV1_PUBLIC __declspec(dllimport)
+#endif  // defined(__GNUC__)
+#else
+#define LIBGAV1_PUBLIC
+#endif  // defined(LIBGAV1_BUILDING_DLL) && LIBGAV1_BUILDING_DLL
+#else
+#if defined(__GNUC__) && __GNUC__ >= 4
+#define LIBGAV1_PUBLIC __attribute__((visibility("default")))
+#else
+#define LIBGAV1_PUBLIC
+#endif
+#endif  // defined(_WIN32)
+#endif  // defined(LIBGAV1_PUBLIC)
+
+#endif  // LIBGAV1_SRC_SYMBOL_VISIBILITY_H_
diff --git a/libgav1/src/threading_strategy.cc b/libgav1/src/threading_strategy.cc
index 36d9802..440cad0 100644
--- a/libgav1/src/threading_strategy.cc
+++ b/libgav1/src/threading_strategy.cc
@@ -1,6 +1,7 @@
 #include "src/threading_strategy.h"
 
 #include <algorithm>
+#include <cassert>
 
 #include "src/utils/logging.h"
 
@@ -14,14 +15,15 @@
 
 bool ThreadingStrategy::Reset(const ObuFrameHeader& frame_header,
                               int thread_count) {
-  if (thread_count <= 1) {
+  assert(thread_count > 0);
+  if (thread_count == 1) {
     thread_pool_.reset(nullptr);
-    use_tile_threads_ = false;
+    tile_thread_count_ = 0;
     max_tile_index_for_row_threads_ = 0;
     return true;
   }
 
-  // We do work in the main thread, so it is sufficient to create
+  // We do work in the current thread, so it is sufficient to create
   // |thread_count|-1 threads in the threadpool.
   thread_count = std::min(thread_count - 1, kMaxThreads);
 
@@ -30,7 +32,7 @@
     if (thread_pool_ == nullptr) {
       LIBGAV1_DLOG(ERROR, "Failed to create a thread pool with %d threads.",
                    thread_count);
-      use_tile_threads_ = false;
+      tile_thread_count_ = 0;
       max_tile_index_for_row_threads_ = 0;
       return false;
     }
@@ -39,16 +41,44 @@
   // Prefer tile threads first (but only if there is more than one tile).
   const int tile_count = frame_header.tile_info.tile_count;
   if (tile_count > 1) {
-    use_tile_threads_ = true;
-    thread_count -= tile_count;
-    if (thread_count <= 0) {
+    // We want 1 + tile_thread_count_ <= tile_count because the current thread
+    // is also used to decode tiles. This is equivalent to
+    // tile_thread_count_ <= tile_count - 1.
+    tile_thread_count_ = std::min(thread_count, tile_count - 1);
+    thread_count -= tile_thread_count_;
+    if (thread_count == 0) {
       max_tile_index_for_row_threads_ = 0;
       return true;
     }
   } else {
-    use_tile_threads_ = false;
+    tile_thread_count_ = 0;
   }
 
+#if defined(__ANDROID__)
+  // Assign the remaining threads for each Tile. The heuristic used here is that
+  // we will assign two threads for each Tile. So for example, if |thread_count|
+  // is 2, for a stream with 2 tiles the first tile would get both threads
+  // and the second tile would have row multi-threading turned off. This
+  // heuristic is based on the fact that row multi-threading is fast enough only
+  // when there are at least two threads to do the decoding (since one thread
+  // always does the parsing).
+  //
+  // This heuristic might stop working when SIMD optimizations make the decoding
+  // much faster and the parsing thread is only as fast as the decoding threads.
+  // So we will have to revisit this later to make sure that this is still
+  // optimal.
+  //
+  // Note that while this heuristic significantly improves performance on high
+  // end devices (like the Pixel 3), there are some performance regressions on
+  // some lower end devices in some cases, and those need to be revisited as we
+  // bring in more optimizations. Overall, the gains from this heuristic seem
+  // to be much larger than the regressions.
+  for (int i = 0; i < tile_count; ++i) {
+    max_tile_index_for_row_threads_ = i + 1;
+    thread_count -= 2;
+    if (thread_count <= 0) break;
+  }
+#else   // !defined(__ANDROID__)
   // Assign the remaining threads to each Tile.
   for (int i = 0; i < tile_count; ++i) {
     const int count = thread_count / tile_count +
@@ -60,6 +90,7 @@
     }
     max_tile_index_for_row_threads_ = i + 1;
   }
+#endif  // defined(__ANDROID__)
   return true;
 }
 
diff --git a/libgav1/src/threading_strategy.h b/libgav1/src/threading_strategy.h
index eebd8b5..2745195 100644
--- a/libgav1/src/threading_strategy.h
+++ b/libgav1/src/threading_strategy.h
@@ -35,9 +35,11 @@
   // Returns a pointer to the ThreadPool that is to be used for Tile
   // multi-threading.
   ThreadPool* tile_thread_pool() const {
-    return use_tile_threads_ ? thread_pool_.get() : nullptr;
+    return (tile_thread_count_ != 0) ? thread_pool_.get() : nullptr;
   }
 
+  int tile_thread_count() const { return tile_thread_count_; }
+
   // Returns a pointer to the ThreadPool that is to be used within the Tile at
   // index |tile_index| for superblock row multi-threading.
   ThreadPool* row_thread_pool(int tile_index) const {
@@ -51,7 +53,7 @@
 
  private:
   std::unique_ptr<ThreadPool> thread_pool_;
-  bool use_tile_threads_;
+  int tile_thread_count_;
   int max_tile_index_for_row_threads_;
 };
 
diff --git a/libgav1/src/tile.h b/libgav1/src/tile.h
index 524229b..3895f5e 100644
--- a/libgav1/src/tile.h
+++ b/libgav1/src/tile.h
@@ -24,6 +24,7 @@
 #include "src/symbol_decoder_context.h"
 #include "src/utils/array_2d.h"
 #include "src/utils/block_parameters_holder.h"
+#include "src/utils/blocking_counter.h"
 #include "src/utils/common.h"
 #include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
@@ -64,7 +65,8 @@
        const SegmentationMap* prev_segment_ids, PostFilter* post_filter,
        BlockParametersHolder* block_parameters, Array2D<int16_t>* cdef_index,
        Array2D<TransformSize>* inter_transform_sizes, const dsp::Dsp* dsp,
-       ThreadPool* thread_pool, ResidualBufferPool* residual_buffer_pool);
+       ThreadPool* thread_pool, ResidualBufferPool* residual_buffer_pool,
+       BlockingCounterWithStatus* pending_tiles);
 
   // Move only.
   Tile(Tile&& tile) noexcept;
@@ -74,7 +76,7 @@
 
   struct Block;  // Defined after this class.
 
-  bool Decode();  // 5.11.2.
+  bool Decode(bool is_main_thread);  // 5.11.2.
   const ObuSequenceHeader& sequence_header() const { return sequence_header_; }
   const ObuFrameHeader& frame_header() const { return frame_header_; }
   const RefCountedBuffer& current_frame() const { return current_frame_; }
@@ -92,51 +94,42 @@
   int superblock_columns() const { return superblock_columns_; }
 
  private:
-  struct EntropyContext : public Allocable {
-    static const int kLeft = 0;
-    static const int kTop = 1;
-
-    EntropyContext() = default;
-    ~EntropyContext() = default;
-    EntropyContext(const EntropyContext&) = default;
-    EntropyContext& operator=(const EntropyContext&) = default;
-
-    uint8_t coefficient_level = 0;
-    uint8_t dc_category = 0;
-  };
-
-  // Stores the transform tree state when reading variable size transform trees.
+  // Stores the transform tree state when reading variable size transform trees
+  // and when applying the transform tree. When applying the transform tree,
+  // |depth| is not used.
   struct TransformTreeNode {
-    TransformTreeNode(int row4x4, int column4x4, TransformSize tx_size,
-                      int depth)
-        : row4x4(row4x4),
-          column4x4(column4x4),
-          tx_size(tx_size),
-          depth(depth) {}
+    // The default constructor is invoked by the Stack<TransformTreeNode, n>
+    // constructor. Stack<> does not use the default-constructed elements, so it
+    // is safe for the default constructor to not initialize the members.
+    TransformTreeNode() = default;
+    TransformTreeNode(int x, int y, TransformSize tx_size, int depth = -1)
+        : x(x), y(y), tx_size(tx_size), depth(depth) {}
 
-    int row4x4;
-    int column4x4;
+    int x;
+    int y;
     TransformSize tx_size;
     int depth;
   };
 
+  static constexpr int kBlockDecodedStride = 34;
   // Buffer to facilitate decoding a superblock. When |split_parse_and_decode_|
   // is true, each superblock that is being decoded will get its own instance of
   // this buffer.
   struct SuperBlockBuffer {
-    uint8_t prediction_mask[kMaxSuperBlockSizeInPixels *
-                            kMaxSuperBlockSizeInPixels];
-    // This stores the decoded state of every 4x4 block in a superblock. It has
-    // 1 row/column border on all 4 sides. The left and top borders are handled
-    // by the |BlockDecoded()| function. The bottom and right borders are
-    // included in the array itself (hence the 33x33 dimension instead of
-    // 32x32).
-    bool block_decoded[kMaxPlanes][33][33];
-    // Stores the thresholds for determining if top-right and bottom-left pixels
-    // are available. Equivalent to the sbWidth4 and sbHeight4 variables in
-    // section 5.11.3 of the spec.
-    int block_decoded_width_threshold;
-    int block_decoded_height_threshold;
+    // The size of the mask is 128x128.
+    AlignedUniquePtr<uint8_t> prediction_mask;
+    // Buffers used for the inter prediction process. The buffers have an
+    // alignment of 8 bytes when allocated.
+    AlignedUniquePtr<uint16_t> prediction_buffer[2];
+    size_t prediction_buffer_size[2] = {};
+    // Equivalent to BlockDecoded array in the spec. This stores the decoded
+    // state of every 4x4 block in a superblock. It has 1 row/column border on
+    // all 4 sides (hence the 34x34 dimension instead of 32x32). Note that the
+    // spec uses "-1" as an index to access the left and top borders. In the
+    // code, we treat the index (1, 1) as equivalent to the spec's (0, 0). So
+    // all accesses into this array will be offset by +1 when compared with the
+    // spec.
+    bool block_decoded[kMaxPlanes][kBlockDecodedStride][kBlockDecodedStride];
     // Buffer used for storing subsampled luma samples needed for CFL
     // prediction. This buffer is used to avoid repetition of the subsampling
     // for the V plane when it is already done for the U plane.
@@ -176,27 +169,28 @@
     std::condition_variable pending_jobs_zero_condvar;
   };
 
+  // Performs member initializations that may fail. Called by Decode().
+  LIBGAV1_MUST_USE_RESULT bool Init();
+
   // Entry point for multi-threaded decoding. This function performs the same
   // functionality as Decode(). The current thread does the "parse" step while
   // the worker threads do the "decode" step.
   bool ThreadedDecode();
 
   // Returns whether or not the prerequisites for decoding the superblock at
-  // |row_index| and |column_index| are satisfied. |threading_parameters.mutex|
-  // must be held when calling this function.
-  bool CanDecode(int row_index, int column_index,
-                 const Array2D<SuperBlockState>& sb_state);
+  // |row_index| and |column_index| are satisfied. |threading_.mutex| must be
+  // held when calling this function.
+  bool CanDecode(int row_index, int column_index) const;
 
   // This function is run by the worker threads when multi-threaded decoding is
   // enabled. Once a superblock is decoded, this function will set the
-  // corresponding |threading->sb_state| entry to kSuperBlockStateDecoded. On
-  // failure, |threading->abort| will be set to true. If at any point
-  // |threading->abort| becomes true, this function will return as early as it
+  // corresponding |threading_.sb_state| entry to kSuperBlockStateDecoded. On
+  // failure, |threading_.abort| will be set to true. If at any point
+  // |threading_.abort| becomes true, this function will return as early as it
   // can. If the decoding succeeds, this function will also schedule the
   // decoding jobs for the superblock to the bottom-left and the superblock to
   // the right of this superblock (if it is allowed).
-  void DecodeSuperBlock(int row_index, int column_index, int block_width4x4,
-                        ThreadingParameters* threading);
+  void DecodeSuperBlock(int row_index, int column_index, int block_width4x4);
 
   uint16_t* GetPartitionCdf(int row4x4, int column4x4, BlockSize block_size);
   bool ReadPartition(int row4x4, int column4x4, BlockSize block_size,
@@ -219,6 +213,8 @@
   // function in the spec is equivalent to ProcessBlock() in the code.
   bool DecodeBlock(ParameterTree* tree, SuperBlockBuffer* sb_buffer);
 
+  void ClearBlockDecoded(SuperBlockBuffer* sb_buffer, int row4x4,
+                         int column4x4);  // 5.11.3.
   bool ProcessSuperBlock(int row4x4, int column4x4, int block_width4x4,
                          SuperBlockBuffer* sb_buffer, ProcessingMode mode);
   void ResetLoopRestorationParams();
@@ -239,12 +235,14 @@
   void ReadSkip(const Block& block);            // 5.11.11.
   void ReadSkipMode(const Block& block);        // 5.11.10.
   void ReadCdef(const Block& block);            // 5.11.56.
-  // Returns the new value.
-  int ReadAndClipDelta(uint16_t* cdf, int symbol_count, int delta_small,
-                       int scale, int min_value, int max_value, int value);
+  // Returns the new value. |cdf| is an array of size kDeltaSymbolCount + 1.
+  int ReadAndClipDelta(uint16_t* cdf, int delta_small, int scale, int min_value,
+                       int max_value, int value);
   void ReadQuantizerIndexDelta(const Block& block);  // 5.11.12.
   void ReadLoopFilterDelta(const Block& block);      // 5.11.13.
-  void ComputeDeblockFilterLevel(const Block& block);
+  // Populates |BlockParameters::deblock_filter_level| for the given |block|
+  // using |deblock_filter_levels_|.
+  void PopulateDeblockFilterLevel(const Block& block);
   void ReadPredictionModeY(const Block& block, bool intra_y_mode);
   void ReadIntraAngleInfo(const Block& block,
                           PlaneType plane_type);  // 5.11.42 and 5.11.43.
@@ -267,18 +265,26 @@
                               bool intra_y_mode);  // 5.11.22.
   int GetUseCompoundReferenceContext(const Block& block);
   CompoundReferenceType ReadCompoundReferenceType(const Block& block);
-  int GetReferenceContext(const Block& block,
-                          const std::vector<ReferenceFrameType>& types1,
-                          const std::vector<ReferenceFrameType>& types2) const;
-  uint16_t* GetReferenceCdf(
-      const Block& block, bool is_single, bool is_backward, int index,
-      CompoundReferenceType type = kNumCompoundReferenceTypes);
+  // Calculates count0 by calling block.CountReferences() on the frame types
+  // from type0_start to type0_end, inclusive, and summing the results.
+  // Calculates count1 by calling block.CountReferences() on the frame types
+  // from type1_start to type1_end, inclusive, and summing the results.
+  // Compares count0 with count1 and returns 0, 1 or 2.
+  //
+  // See count_refs and ref_count_ctx in 8.3.2.
+  int GetReferenceContext(const Block& block, ReferenceFrameType type0_start,
+                          ReferenceFrameType type0_end,
+                          ReferenceFrameType type1_start,
+                          ReferenceFrameType type1_end) const;
+  template <bool is_single, bool is_backward, int index>
+  uint16_t* GetReferenceCdf(const Block& block, CompoundReferenceType type =
+                                                    kNumCompoundReferenceTypes);
   void ReadReferenceFrames(const Block& block);  // 5.11.25.
   void ReadInterPredictionModeY(const Block& block,
                                 const MvContexts& mode_contexts);
   void ReadRefMvIndex(const Block& block);
   void ReadInterIntraMode(const Block& block, bool is_compound);  // 5.11.28.
-  bool IsScaled(ReferenceFrameType type);  // Part of 5.11.27.
+  bool IsScaled(ReferenceFrameType type) const;  // Part of 5.11.27.
   void ReadMotionMode(const Block& block, bool is_compound);  // 5.11.27.
   uint16_t* GetIsExplicitCompoundTypeCdf(const Block& block);
   uint16_t* GetIsCompoundTypeAverageCdf(const Block& block);
@@ -300,7 +306,7 @@
                                  TransformSize tx_size);
   void DecodeTransformSize(const Block& block);  // 5.11.16.
   bool ComputePrediction(const Block& block);    // 5.11.33.
-  // |x4| and |y4| are the row and column positions of the 4x4 block. |w4| and
+  // |x4| and |y4| are the column and row positions of the 4x4 block. |w4| and
   // |h4| are the width and height in 4x4 units of |tx_size|.
   int GetTransformAllZeroContext(const Block& block, Plane plane,
                                  TransformSize tx_size, int x4, int y4, int w4,
@@ -324,21 +330,35 @@
   int GetCoeffBaseRangeContextVertical(int adjusted_tx_width_log2, int pos);
   int GetDcSignContext(int x4, int y4, int w4, int h4, Plane plane);
   void SetEntropyContexts(int x4, int y4, int w4, int h4, Plane plane,
-                          uint8_t coefficient_level, uint8_t dc_category);
-  void InterIntraPrediction(
+                          uint8_t coefficient_level, int8_t dc_category);
+  bool InterIntraPrediction(
       uint16_t* prediction[2], ptrdiff_t prediction_stride,
       const uint8_t* prediction_mask, ptrdiff_t prediction_mask_stride,
       const PredictionParameters& prediction_parameters, int prediction_width,
       int prediction_height, int subsampling_x, int subsampling_y,
-      uint8_t post_round_bits, uint8_t* dest,
+      uint8_t* dest,
       ptrdiff_t dest_stride);  // Part of section 7.11.3.1 in the spec.
-  void CompoundInterPrediction(
-      const Block& block, uint16_t* prediction[2], ptrdiff_t prediction_stride,
+  // Several prediction modes need a prediction mask:
+  // kCompoundPredictionTypeDiffWeighted, kCompoundPredictionTypeWedge,
+  // kCompoundPredictionTypeIntra. They are mutually exclusive. So the mask is
+  // allocated in each case. The mask only needs to be allocated for kPlaneY
+  // and then used for other planes.
+  LIBGAV1_MUST_USE_RESULT static bool AllocatePredictionMask(
+      SuperBlockBuffer* sb_buffer);
+  bool CompoundInterPrediction(
+      const Block& block, ptrdiff_t prediction_stride,
       ptrdiff_t prediction_mask_stride, int prediction_width,
       int prediction_height, Plane plane, int subsampling_x, int subsampling_y,
       int bitdepth, int candidate_row, int candidate_column, uint8_t* dest,
-      ptrdiff_t dest_stride,
-      uint8_t post_round_bits);  // Part of section 7.11.3.1 in the spec.
+      ptrdiff_t dest_stride);  // Part of section 7.11.3.1 in the spec.
+  GlobalMotion* GetWarpParams(const Block& block, Plane plane,
+                              int prediction_width, int prediction_height,
+                              const PredictionParameters& prediction_parameters,
+                              ReferenceFrameType reference_type,
+                              bool* is_local_valid,
+                              GlobalMotion* global_motion_params,
+                              GlobalMotion* local_warp_params)
+      const;  // Part of section 7.11.3.1 in the spec.
   bool InterPrediction(const Block& block, Plane plane, int x, int y,
                        int prediction_width, int prediction_height,
                        int candidate_row, int candidate_column,
@@ -369,24 +389,42 @@
                             const uint8_t* round_bits, bool is_compound,
                             bool is_inter_intra, uint8_t* dest,
                             ptrdiff_t dest_stride);  // 7.11.3.4.
-  bool BlockWarpProcess(const Block& block, Plane plane, int index, int width,
-                        int height, uint16_t* prediction,
-                        ptrdiff_t prediction_stride, GlobalMotion* warp_params,
-                        const uint8_t* round_bits, bool is_compound,
-                        bool is_inter_intra, uint8_t* dest,
+  bool BlockWarpProcess(const Block& block, Plane plane, int index,
+                        int block_start_x, int block_start_y, int width,
+                        int height, ptrdiff_t prediction_stride,
+                        GlobalMotion* warp_params, const uint8_t* round_bits,
+                        bool is_compound, bool is_inter_intra, uint8_t* dest,
                         ptrdiff_t dest_stride);  // 7.11.3.5.
   void ObmcBlockPrediction(const MotionVector& mv, Plane plane,
                            int reference_frame_index, int width, int height,
                            int x, int y, int candidate_row,
-                           int candidate_column, const uint8_t* mask,
-                           int blending_direction, const uint8_t* round_bits);
+                           int candidate_column,
+                           ObmcDirection blending_direction,
+                           const uint8_t* round_bits);
   void ObmcPrediction(const Block& block, Plane plane, int width, int height,
                       const uint8_t* round_bits);  // 7.11.3.9.
-  void DistanceWeightedPrediction(
-      uint16_t* prediction_0, ptrdiff_t prediction_stride_0,
-      uint16_t* prediction_1, ptrdiff_t prediction_stride_1, int width,
-      int height, int candidate_row, int candidate_column, uint8_t* dest,
-      ptrdiff_t dest_stride, uint8_t post_round_bits);  // 7.11.3.15.
+  void DistanceWeightedPrediction(uint16_t* prediction_0,
+                                  ptrdiff_t prediction_stride_0,
+                                  uint16_t* prediction_1,
+                                  ptrdiff_t prediction_stride_1, int width,
+                                  int height, int candidate_row,
+                                  int candidate_column, uint8_t* dest,
+                                  ptrdiff_t dest_stride);  // 7.11.3.15.
+  // This function specializes the parsing of DC coefficient by removing some of
+  // the branches when i == 0 (since scan[0] is always 0 and scan[i] is always
+  // non-zero for all other possible values of i). |dc_category| is an output
+  // parameter that is populated when |is_dc_coefficient| is true.
+  // |coefficient_level| is an output parameter which accumulates the
+  // coefficient level.
+  template <bool is_dc_coefficient>
+  bool ReadSignAndApplyDequantization(
+      const Block& block, const uint16_t* scan, int i,
+      int adjusted_tx_width_log2, int tx_width, int q_value,
+      const uint8_t* quantizer_matrix, int shift, int min_value, int max_value,
+      uint16_t* dc_sign_cdf, int8_t* dc_category,
+      int* coefficient_level);  // Part of 5.11.39.
+  int ReadCoeffBaseRange(int clamped_tx_size_context, int cdf_context,
+                         int plane_type);  // Part of 5.11.39.
   // Returns the number of non-zero coefficients that were read. |tx_type| is an
   // output parameter that stores the computed transform type for the plane
   // whose coefficients were read. Returns -1 on failure.
@@ -397,8 +435,9 @@
   bool TransformBlock(const Block& block, Plane plane, int base_x, int base_y,
                       TransformSize tx_size, int x, int y,
                       ProcessingMode mode);  // 5.11.35.
-  bool TransformTree(const Block& block, int start_x, int start_y, int width,
-                     int height, ProcessingMode mode);  // 5.11.36.
+  // Iterative implementation of 5.11.36.
+  bool TransformTree(const Block& block, int start_x, int start_y,
+                     BlockSize plane_size, ProcessingMode mode);
   void ReconstructBlock(const Block& block, Plane plane, int start_x,
                         int start_y, TransformSize tx_size,
                         TransformType tx_type,
@@ -406,21 +445,19 @@
   bool Residual(const Block& block, ProcessingMode mode);  // 5.11.34.
   // part of 5.11.5 (reset_block_context() in the spec).
   void ResetEntropyContext(const Block& block);
-  int GetPaletteColorContext(const Block& block, PlaneType plane_type, int row,
-                             int column, int palette_size,
-                             uint8_t color_order[kMaxPaletteSize]);  // 5.11.50.
-  void ReadPaletteTokens(const Block& block);                        // 5.11.49.
-  // Helper function for handling the border cases in 5.11.3 of the spec.
-  // Early return if has_top_or_left is false, since has_bottom_left
-  // (has_top_right) must be false if has_left(has_top) is false.
-  bool BlockDecoded(const Block& block, Plane plane, int row4x4, int column4x4,
-                    bool has_top_or_left) const;
+  // Populates the |color_context| and |color_order| for the |i|th iteration
+  // with entries counting down from |start| to |end| (|start| > |end|).
+  void PopulatePaletteColorContexts(
+      const Block& block, PlaneType plane_type, int i, int start, int end,
+      uint8_t color_order[kMaxPaletteSquare][kMaxPaletteSize],
+      uint8_t color_context[kMaxPaletteSquare]);                 // 5.11.50.
+  bool ReadPaletteTokens(const Block& block);                    // 5.11.49.
   template <typename Pixel>
   void IntraPrediction(const Block& block, Plane plane, int x, int y,
                        bool has_left, bool has_top, bool has_top_right,
                        bool has_bottom_left, PredictionMode mode,
                        TransformSize tx_size);
-  bool UsesSmoothPrediction(int row, int column, Plane plane) const;
+  bool IsSmoothPrediction(int row, int column, Plane plane) const;
   int GetIntraEdgeFilterType(const Block& block,
                              Plane plane) const;  // 7.11.2.8.
   template <typename Pixel>
@@ -449,12 +486,6 @@
     return sequence_header_.color_config.is_monochrome ? kMaxPlanesMonochrome
                                                        : kMaxPlanes;
   }
-  int SubsamplingX(Plane plane) const {
-    return (plane > kPlaneY) ? sequence_header_.color_config.subsampling_x : 0;
-  }
-  int SubsamplingY(Plane plane) const {
-    return (plane > kPlaneY) ? sequence_header_.color_config.subsampling_y : 0;
-  }
 
   const int number_;
   int row_;
@@ -468,11 +499,46 @@
   int superblock_rows_;
   int superblock_columns_;
   bool read_deltas_;
+  const int8_t subsampling_x_[kMaxPlanes];
+  const int8_t subsampling_y_[kMaxPlanes];
+
+  // The dimensions (in order) are: segment_id, level_index (based on plane and
+  // direction), reference_frame and mode_id.
+  uint8_t deblock_filter_levels_[kMaxSegments][kFrameLfCount]
+                                [kNumReferenceFrameTypes][2];
+
   // current_quantizer_index_ is in the range [0, 255].
-  int current_quantizer_index_;
-  // First dimension: left/top; Second dimension: plane; Third dimension:
-  // row4x4/column4x4.
-  std::array<Array2D<EntropyContext>, 2> entropy_contexts_;
+  uint8_t current_quantizer_index_;
+  // These two arrays (|coefficient_levels_| and |dc_categories_|) are used to
+  // store the entropy context. Their dimensions are as follows: First -
+  // left/top; Second - plane; Third - row4x4 (if first dimension is
+  // left)/column4x4 (if first dimension is top).
+  //
+  // This is equivalent to the LeftLevelContext and AboveLevelContext arrays in
+  // the spec. In the spec, it stores values from 0 through 63 (inclusive). The
+  // stored values are used to compute the left and top contexts in
+  // GetTransformAllZeroContext. In that function, we only care about the
+  // following values: 0, 1, 2, 3 and >= 4. So instead of clamping to 63, we
+  // clamp to 4 (i.e., all values greater than 4 are stored as 4).
+  std::array<Array2D<uint8_t>, 2> coefficient_levels_;
+  // This is equivalent to the LeftDcContext and AboveDcContext arrays in the
+  // spec. In the spec, it can store 3 possible values: 0, 1 and 2 (where 1
+  // means the value is < 0, 2 means the value is > 0 and 0 means the value is
+  // equal to 0).
+  //
+  // The stored values are used in two places:
+  //  * GetTransformAllZeroContext: Here, we only care about whether the
+  //  value is 0 or not (whether it is 1 or 2 is irrelevant).
+  //  * GetDcSignContext: Here, we do the following computation: if the
+  //  stored value is 1, we decrement a counter. If the stored value is 2
+  //  we increment a counter.
+  //
+  // Based on this usage, we can simply replace 1 with -1 and 2 with 1 and
+  // use that value to compute the counter.
+  //
+  // The usage in GetTransformAllZeroContext is unaffected since there we
+  // only care about whether it is 0 or not.
+  std::array<Array2D<int8_t>, 2> dc_categories_;
   const ObuSequenceHeader& sequence_header_;
   const ObuFrameHeader& frame_header_;
   RefCountedBuffer& current_frame_;
@@ -496,6 +562,18 @@
   // array will be used by each call to ReadTransformCoefficients() depending on
   // the transform size.
   int32_t quantized_[kQuantizedCoefficientBufferSize];
+  // Stores the "color order" for a block for each iteration in
+  // ReadPaletteTokens(). The "color order" is used to populate the
+  // |color_index_map| used for palette prediction. This is declared as a class
+  // variable because only the first few values in each dimension are used by
+  // each call depending on the block size and the palette size.
+  uint8_t color_order_[kMaxPaletteSquare][kMaxPaletteSize];
+  // Stores the "color context" for a block for each iteration in
+  // ReadPaletteTokens(). The "color context" is the cdf context index used to
+  // read the |palette_color_idx_y| variable in the spec. This is declared as a
+  // class variable because only the first few values in each dimension are used
+  // by each call depending on the block size and the palette size.
+  uint8_t color_context_[kMaxPaletteSquare];
   // When there is no multi-threading within the Tile, |residual_buffer_| is
   // used. When there is multi-threading within the Tile,
   // |residual_buffer_threaded_| is used. In the following comment,
@@ -542,7 +620,9 @@
   // thread will do the parsing while the thread pool workers will do the
   // decoding.
   ThreadPool* const thread_pool_;
+  ThreadingParameters threading_;
   ResidualBufferPool* const residual_buffer_pool_;
+  BlockingCounterWithStatus* const pending_tiles_;
   bool split_parse_and_decode_;
   // This is used only when |split_parse_and_decode_| is false.
   std::unique_ptr<PredictionParameters> prediction_parameters_ = nullptr;
@@ -552,6 +632,9 @@
   TransformType transform_types_[32][32];
   // delta_lf_[i] is in the range [-63, 63].
   int8_t delta_lf_[kFrameLfCount];
+  // True if all the values in |delta_lf_| are zero. False otherwise.
+  bool delta_lf_all_zero_;
+  bool build_bit_mask_when_parsing_;
 };
 
 struct Tile::Block {
@@ -590,7 +673,7 @@
 
   bool TopAvailableChroma() const {
     if (!HasChroma()) return false;
-    if ((tile.sequence_header_.color_config.subsampling_y |
+    if ((tile.sequence_header_.color_config.subsampling_y &
          kNum4x4BlocksHigh[size]) == 1) {
       return tile.IsInside(row4x4 - 2, column4x4);
     }
@@ -599,7 +682,7 @@
 
   bool LeftAvailableChroma() const {
     if (!HasChroma()) return false;
-    if ((tile.sequence_header_.color_config.subsampling_x |
+    if ((tile.sequence_header_.color_config.subsampling_x &
          kNum4x4BlocksWide[size]) == 1) {
       return tile.IsInside(row4x4, column4x4 - 2);
     }
diff --git a/libgav1/src/tile/bitstream/mode_info.cc b/libgav1/src/tile/bitstream/mode_info.cc
index be71f86..b2e6039 100644
--- a/libgav1/src/tile/bitstream/mode_info.cc
+++ b/libgav1/src/tile/bitstream/mode_info.cc
@@ -11,9 +11,11 @@
 #include "src/dsp/constants.h"
 #include "src/motion_vector.h"
 #include "src/obu_parser.h"
+#include "src/prediction_mask.h"
 #include "src/symbol_decoder_context.h"
 #include "src/tile.h"
 #include "src/utils/array_2d.h"
+#include "src/utils/bit_mask_set.h"
 #include "src/utils/block_parameters_holder.h"
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
@@ -30,41 +32,42 @@
 constexpr int kDeltaLfSmall = 3;
 constexpr int kNoScale = 1 << kReferenceFrameScalePrecision;
 
-const uint8_t kIntraYModeContext[kIntraPredictionModesY] = {0, 1, 2, 3, 4, 4, 4,
-                                                            4, 3, 0, 1, 2, 0};
+constexpr uint8_t kIntraYModeContext[kIntraPredictionModesY] = {
+    0, 1, 2, 3, 4, 4, 4, 4, 3, 0, 1, 2, 0};
 
-const uint8_t kSizeGroup[kMaxBlockSizes] = {0, 0, 0, 0, 1, 1, 1, 0, 1, 2, 2,
-                                            2, 1, 2, 3, 3, 2, 3, 3, 3, 3, 3};
+constexpr uint8_t kSizeGroup[kMaxBlockSizes] = {
+    0, 0, 0, 0, 1, 1, 1, 0, 1, 2, 2, 2, 1, 2, 3, 3, 2, 3, 3, 3, 3, 3};
 
-const int kCompoundModeNewMvContexts = 5;
-const uint8_t kCompoundModeContextMap[3][kCompoundModeNewMvContexts] = {
+constexpr int kCompoundModeNewMvContexts = 5;
+constexpr uint8_t kCompoundModeContextMap[3][kCompoundModeNewMvContexts] = {
     {0, 1, 1, 1, 1}, {1, 2, 3, 4, 4}, {4, 4, 5, 6, 7}};
 
-const uint8_t kWedgeBits[kMaxBlockSizes] = {0, 0, 0, 0, 4, 4, 4, 0, 4, 4, 4,
-                                            0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0};
-
 enum CflSign : uint8_t {
   kCflSignZero = 0,
   kCflSignNegative = 1,
   kCflSignPositive = 2
 };
 
-inline bool IsBackwardReference(ReferenceFrameType type) {
+constexpr BitMaskSet kPredictionModeHasNearMvMask(kPredictionModeNearMv,
+                                                  kPredictionModeNearNearMv,
+                                                  kPredictionModeNearNewMv,
+                                                  kPredictionModeNewNearMv);
+
+constexpr BitMaskSet kIsInterIntraModeAllowedMask(kBlock8x8, kBlock8x16,
+                                                  kBlock16x8, kBlock16x16,
+                                                  kBlock16x32, kBlock32x16,
+                                                  kBlock32x32);
+
+bool IsBackwardReference(ReferenceFrameType type) {
   return type >= kReferenceFrameBackward && type <= kReferenceFrameAlternate;
 }
 
-inline bool IsSameDirectionReferencePair(ReferenceFrameType type1,
-                                         ReferenceFrameType type2) {
+bool IsSameDirectionReferencePair(ReferenceFrameType type1,
+                                  ReferenceFrameType type2) {
   return (type1 >= kReferenceFrameBackward) ==
          (type2 >= kReferenceFrameBackward);
 }
 
-inline bool IsInterIntraModeAllowed(BlockSize size) {
-  return size == kBlock8x8 || size == kBlock8x16 || size == kBlock16x8 ||
-         size == kBlock16x16 || size == kBlock16x32 || size == kBlock32x16 ||
-         size == kBlock32x32;
-}
-
 // This is called neg_deinterleave() in the spec.
 int DecodeSegmentId(int diff, int reference, int max) {
   if (reference == 0) return diff;
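
For illustration, a minimal sketch of the bit-mask membership idea behind the new constants (the real helper is in src/utils/bit_mask_set.h and its interface may differ; the class and constant below are hypothetical):

#include <cstdint>

// Membership in a small set of enum values is tested with a precomputed bit
// mask, so checks such as the removed IsInterIntraModeAllowed() chain of
// comparisons become a single shift-and-AND.
class BitMaskSetSketch {
 public:
  explicit constexpr BitMaskSetSketch(uint32_t mask) : mask_(mask) {}

  constexpr bool Contains(int value) const {
    return MaskContainsValue(mask_, value);
  }

  static constexpr bool MaskContainsValue(uint32_t mask, int value) {
    return ((mask >> value) & 1) != 0;
  }

 private:
  const uint32_t mask_;
};

constexpr BitMaskSetSketch kExampleMask((1u << 3) | (1u << 5) | (1u << 7));
static_assert(kExampleMask.Contains(5), "");
static_assert(!kExampleMask.Contains(6), "");
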
@@ -98,6 +101,17 @@
   return 0;
 }
 
+// Returns true if either the width or the height of the block is equal to
+// four.
+bool IsBlockDimension4(BlockSize size) {
+  return size < kBlock8x8 || size == kBlock16x4;
+}
+
+// Returns true if both the width and height of the block are less than 64.
+bool IsBlockDimensionLessThan64(BlockSize size) {
+  return size <= kBlock32x32 && size != kBlock16x64;
+}
+
 }  // namespace
 
 bool Tile::ReadSegmentId(const Block& block) {
@@ -139,7 +153,7 @@
   uint16_t* const segment_id_cdf =
       symbol_decoder_context_.segment_id_cdf[context];
   const int encoded_segment_id =
-      reader_.ReadSymbol(segment_id_cdf, kMaxSegments);
+      reader_.ReadSymbol<kMaxSegments>(segment_id_cdf);
   bp.segment_id =
       DecodeSegmentId(encoded_segment_id, pred,
                       frame_header_.segmentation.last_active_segment_id + 1);
@@ -181,7 +195,7 @@
     ++context;
   }
   uint16_t* const skip_cdf = symbol_decoder_context_.skip_cdf[context];
-  bp.skip = static_cast<bool>(reader_.ReadSymbol(skip_cdf, 2));
+  bp.skip = reader_.ReadSymbol(skip_cdf);
 }
 
 void Tile::ReadSkipMode(const Block& block) {
@@ -193,7 +207,7 @@
                                                kSegmentFeatureReferenceFrame) ||
       frame_header_.segmentation.FeatureActive(bp.segment_id,
                                                kSegmentFeatureGlobalMv) ||
-      kBlockWidthPixels[block.size] < 8 || kBlockHeightPixels[block.size] < 8) {
+      IsBlockDimension4(block.size)) {
     bp.skip_mode = false;
     return;
   }
@@ -229,10 +243,9 @@
   }
 }
 
-int Tile::ReadAndClipDelta(uint16_t* const cdf, int symbol_count,
-                           int delta_small, int scale, int min_value,
-                           int max_value, int value) {
-  int abs = reader_.ReadSymbol(cdf, symbol_count);
+int Tile::ReadAndClipDelta(uint16_t* const cdf, int delta_small, int scale,
+                           int min_value, int max_value, int value) {
+  int abs = reader_.ReadSymbol<kDeltaSymbolCount>(cdf);
   if (abs == delta_small) {
     const int remaining_bit_count =
         static_cast<int>(reader_.ReadLiteral(3)) + 1;
@@ -251,19 +264,21 @@
 }
 
 void Tile::ReadQuantizerIndexDelta(const Block& block) {
+  assert(read_deltas_);
   BlockParameters& bp = *block.bp;
-  if (!read_deltas_ || (block.size == SuperBlockSize() && bp.skip)) {
+  if (block.size == SuperBlockSize() && bp.skip) {
     return;
   }
-  current_quantizer_index_ = ReadAndClipDelta(
-      symbol_decoder_context_.delta_q_cdf, kDeltaQSymbolCount, kDeltaQSmall,
-      frame_header_.delta_q.scale, kMinLossyQuantizer, kMaxQuantizer,
-      current_quantizer_index_);
+  current_quantizer_index_ =
+      ReadAndClipDelta(symbol_decoder_context_.delta_q_cdf, kDeltaQSmall,
+                       frame_header_.delta_q.scale, kMinLossyQuantizer,
+                       kMaxQuantizer, current_quantizer_index_);
 }
 
 void Tile::ReadLoopFilterDelta(const Block& block) {
+  assert(read_deltas_);
   BlockParameters& bp = *block.bp;
-  if (!read_deltas_ || !frame_header_.delta_lf.present ||
+  if (!frame_header_.delta_lf.present ||
       (block.size == SuperBlockSize() && bp.skip)) {
     return;
   }
@@ -271,15 +286,23 @@
   if (frame_header_.delta_lf.multi) {
     frame_lf_count = kFrameLfCount - (PlaneCount() > 1 ? 0 : 2);
   }
+  bool recompute_deblock_filter_levels = false;
   for (int i = 0; i < frame_lf_count; ++i) {
     uint16_t* const delta_lf_abs_cdf =
         frame_header_.delta_lf.multi
             ? symbol_decoder_context_.delta_lf_multi_cdf[i]
             : symbol_decoder_context_.delta_lf_cdf;
-    delta_lf_[i] =
-        ReadAndClipDelta(delta_lf_abs_cdf, kDeltaLfSymbolCount, kDeltaLfSmall,
-                         frame_header_.delta_lf.scale, -kMaxLoopFilterValue,
-                         kMaxLoopFilterValue, delta_lf_[i]);
+    const int8_t old_delta_lf = delta_lf_[i];
+    delta_lf_[i] = ReadAndClipDelta(
+        delta_lf_abs_cdf, kDeltaLfSmall, frame_header_.delta_lf.scale,
+        -kMaxLoopFilterValue, kMaxLoopFilterValue, delta_lf_[i]);
+    recompute_deblock_filter_levels =
+        recompute_deblock_filter_levels || (old_delta_lf != delta_lf_[i]);
+  }
+  delta_lf_all_zero_ =
+      (delta_lf_[0] | delta_lf_[1] | delta_lf_[2] | delta_lf_[3]) == 0;
+  if (!delta_lf_all_zero_ && recompute_deblock_filter_levels) {
+    post_filter_.ComputeDeblockFilterLevels(delta_lf_, deblock_filter_levels_);
   }
 }
 
@@ -298,7 +321,7 @@
     cdf = symbol_decoder_context_.y_mode_cdf[kSizeGroup[block.size]];
   }
   block.bp->y_mode = static_cast<PredictionMode>(
-      reader_.ReadSymbol(cdf, static_cast<int>(kIntraPredictionModesY)));
+      reader_.ReadSymbol<kIntraPredictionModesY>(cdf));
 }
 
 void Tile::ReadIntraAngleInfo(const Block& block, PlaneType plane_type) {
@@ -312,13 +335,13 @@
   uint16_t* const cdf =
       symbol_decoder_context_.angle_delta_cdf[mode - kPredictionModeVertical];
   prediction_parameters.angle_delta[plane_type] =
-      reader_.ReadSymbol(cdf, kAngleDeltaSymbolCount);
+      reader_.ReadSymbol<kAngleDeltaSymbolCount>(cdf);
   prediction_parameters.angle_delta[plane_type] -= kMaxAngleDelta;
 }
 
 void Tile::ReadCflAlpha(const Block& block) {
-  const int signs = reader_.ReadSymbol(
-      symbol_decoder_context_.cfl_alpha_signs_cdf, kCflAlphaSignsSymbolCount);
+  const int signs = reader_.ReadSymbol<kCflAlphaSignsSymbolCount>(
+      symbol_decoder_context_.cfl_alpha_signs_cdf);
   const auto sign_u = static_cast<CflSign>((signs + 1) / 3);
   const auto sign_v = static_cast<CflSign>((signs + 1) % 3);
   PredictionParameters& prediction_parameters =
@@ -326,8 +349,8 @@
   prediction_parameters.cfl_alpha_u = 0;
   if (sign_u != kCflSignZero) {
     prediction_parameters.cfl_alpha_u =
-        reader_.ReadSymbol(symbol_decoder_context_.cfl_alpha_cdf[signs - 2],
-                           kCflAlphaSymbolCount) +
+        reader_.ReadSymbol<kCflAlphaSymbolCount>(
+            symbol_decoder_context_.cfl_alpha_cdf[signs - 2]) +
         1;
     if (sign_u == kCflSignNegative) prediction_parameters.cfl_alpha_u *= -1;
   }
@@ -335,8 +358,8 @@
   if (sign_v != kCflSignZero) {
     const int context = (sign_v - 1) * 3 + sign_u;
     prediction_parameters.cfl_alpha_v =
-        reader_.ReadSymbol(symbol_decoder_context_.cfl_alpha_cdf[context],
-                           kCflAlphaSymbolCount) +
+        reader_.ReadSymbol<kCflAlphaSymbolCount>(
+            symbol_decoder_context_.cfl_alpha_cdf[context]) +
         1;
     if (sign_v == kCflSignNegative) prediction_parameters.cfl_alpha_v *= -1;
   }
@@ -347,11 +370,10 @@
   bool chroma_from_luma_allowed;
   if (frame_header_.segmentation.lossless[bp.segment_id]) {
     chroma_from_luma_allowed =
-        kPlaneResidualSize[block.size][SubsamplingX(kPlaneU)]
-                          [SubsamplingY(kPlaneU)] == kBlock4x4;
+        kPlaneResidualSize[block.size][subsampling_x_[kPlaneU]]
+                          [subsampling_y_[kPlaneU]] == kBlock4x4;
   } else {
-    chroma_from_luma_allowed = kBlockWidthPixels[block.size] <= 32 &&
-                               kBlockHeightPixels[block.size] <= 32;
+    chroma_from_luma_allowed = IsBlockDimensionLessThan64(block.size);
   }
   uint16_t* const cdf =
       symbol_decoder_context_
@@ -367,9 +389,8 @@
       static_cast<int>(block.bp->prediction_parameters->use_intra_block_copy);
   const bool sign = reader_.ReadSymbol(
       symbol_decoder_context_.mv_sign_cdf[component][context]);
-  const int mv_class = reader_.ReadSymbol(
-      symbol_decoder_context_.mv_class_cdf[component][context],
-      kMvClassSymbolCount);
+  const int mv_class = reader_.ReadSymbol<kMvClassSymbolCount>(
+      symbol_decoder_context_.mv_class_cdf[component][context]);
   int magnitude = 1;
   int value;
   uint16_t* fraction_cdf;
@@ -396,7 +417,7 @@
   }
   const int fraction =
       (frame_header_.force_integer_mv == 0)
-          ? reader_.ReadSymbol(fraction_cdf, kMvFractionSymbolCount)
+          ? reader_.ReadSymbol<kMvFractionSymbolCount>(fraction_cdf)
           : 3;
   const int precision =
       frame_header_.allow_high_precision_mv
@@ -430,16 +451,15 @@
   prediction_parameters.use_filter_intra = false;
   if (!sequence_header_.enable_filter_intra || bp.y_mode != kPredictionModeDc ||
       bp.palette_mode_info.size[kPlaneTypeY] != 0 ||
-      kBlockWidthPixels[block.size] > 32 ||
-      kBlockHeightPixels[block.size] > 32) {
+      !IsBlockDimensionLessThan64(block.size)) {
     return;
   }
   prediction_parameters.use_filter_intra = reader_.ReadSymbol(
       symbol_decoder_context_.use_filter_intra_cdf[block.size]);
   if (prediction_parameters.use_filter_intra) {
     prediction_parameters.filter_intra_mode = static_cast<FilterIntraPredictor>(
-        reader_.ReadSymbol(symbol_decoder_context_.filter_intra_mode_cdf,
-                           static_cast<int>(kNumFilterIntraPredictors)));
+        reader_.ReadSymbol<kNumFilterIntraPredictors>(
+            symbol_decoder_context_.filter_intra_mode_cdf));
   }
 }
 
@@ -457,9 +477,11 @@
     return false;
   }
   ReadCdef(block);
-  ReadQuantizerIndexDelta(block);
-  ReadLoopFilterDelta(block);
-  read_deltas_ = false;
+  if (read_deltas_) {
+    ReadQuantizerIndexDelta(block);
+    ReadLoopFilterDelta(block);
+    read_deltas_ = false;
+  }
   PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
   prediction_parameters.use_intra_block_copy = false;
@@ -551,29 +573,32 @@
   BlockParameters& bp = *block.bp;
   if (bp.skip_mode) {
     bp.is_inter = true;
-  } else if (frame_header_.segmentation.FeatureActive(
-                 bp.segment_id, kSegmentFeatureReferenceFrame)) {
+    return;
+  }
+  if (frame_header_.segmentation.FeatureActive(bp.segment_id,
+                                               kSegmentFeatureReferenceFrame)) {
     bp.is_inter =
         frame_header_.segmentation
             .feature_data[bp.segment_id][kSegmentFeatureReferenceFrame] !=
         kReferenceFrameIntra;
-  } else if (frame_header_.segmentation.FeatureActive(
-                 bp.segment_id, kSegmentFeatureGlobalMv)) {
-    bp.is_inter = true;
-  } else {
-    int context = 0;
-    if (block.top_available && block.left_available) {
-      context =
-          (block.IsTopIntra() && block.IsLeftIntra())
-              ? 3
-              : static_cast<int>(block.IsTopIntra() || block.IsLeftIntra());
-    } else if (block.top_available || block.left_available) {
-      context = 2 * static_cast<int>(block.top_available ? block.IsTopIntra()
-                                                         : block.IsLeftIntra());
-    }
-    bp.is_inter =
-        reader_.ReadSymbol(symbol_decoder_context_.is_inter_cdf[context]);
+    return;
   }
+  if (frame_header_.segmentation.FeatureActive(bp.segment_id,
+                                               kSegmentFeatureGlobalMv)) {
+    bp.is_inter = true;
+    return;
+  }
+  int context = 0;
+  if (block.top_available && block.left_available) {
+    context = (block.IsTopIntra() && block.IsLeftIntra())
+                  ? 3
+                  : static_cast<int>(block.IsTopIntra() || block.IsLeftIntra());
+  } else if (block.top_available || block.left_available) {
+    context = 2 * static_cast<int>(block.top_available ? block.IsTopIntra()
+                                                       : block.IsLeftIntra());
+  }
+  bp.is_inter =
+      reader_.ReadSymbol(symbol_decoder_context_.is_inter_cdf[context]);
 }
 
 bool Tile::ReadIntraBlockModeInfo(const Block& block, bool intra_y_mode) {
@@ -677,62 +702,67 @@
       symbol_decoder_context_.compound_reference_type_cdf[context]));
 }
 
-int Tile::GetReferenceContext(
-    const Block& block, const std::vector<ReferenceFrameType>& types1,
-    const std::vector<ReferenceFrameType>& types2) const {
-  int count[2] = {};
-  for (int i = 0; i < 2; ++i) {
-    for (const auto& type : (i == 0) ? types1 : types2) {
-      count[i] += block.CountReferences(type);
-    }
+int Tile::GetReferenceContext(const Block& block,
+                              ReferenceFrameType type0_start,
+                              ReferenceFrameType type0_end,
+                              ReferenceFrameType type1_start,
+                              ReferenceFrameType type1_end) const {
+  int count0 = 0;
+  int count1 = 0;
+  for (int type = type0_start; type <= type0_end; ++type) {
+    count0 += block.CountReferences(static_cast<ReferenceFrameType>(type));
   }
-  return (count[0] < count[1]) ? 0 : (count[0] == count[1] ? 1 : 2);
+  for (int type = type1_start; type <= type1_end; ++type) {
+    count1 += block.CountReferences(static_cast<ReferenceFrameType>(type));
+  }
+  return (count0 < count1) ? 0 : (count0 == count1 ? 1 : 2);
 }
 
+template <bool is_single, bool is_backward, int index>
 uint16_t* Tile::GetReferenceCdf(
-    const Block& block, bool is_single, bool is_backward, int index,
+    const Block& block,
     CompoundReferenceType type /*= kNumCompoundReferenceTypes*/) {
   int context = 0;
   if ((type == kCompoundReferenceUnidirectional && index == 0) ||
       (is_single && index == 1)) {
     // uni_comp_ref and single_ref_p1.
     context =
-        GetReferenceContext(block,
-                            {kReferenceFrameLast, kReferenceFrameLast2,
-                             kReferenceFrameLast3, kReferenceFrameGolden},
-                            {kReferenceFrameBackward, kReferenceFrameAlternate2,
-                             kReferenceFrameAlternate});
+        GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameGolden,
+                            kReferenceFrameBackward, kReferenceFrameAlternate);
   } else if (type == kCompoundReferenceUnidirectional && index == 1) {
     // uni_comp_ref_p1.
     context =
-        GetReferenceContext(block, {kReferenceFrameLast2},
-                            {kReferenceFrameLast3, kReferenceFrameGolden});
+        GetReferenceContext(block, kReferenceFrameLast2, kReferenceFrameLast2,
+                            kReferenceFrameLast3, kReferenceFrameGolden);
   } else if ((type == kCompoundReferenceUnidirectional && index == 2) ||
              (type == kCompoundReferenceBidirectional && index == 2) ||
              (is_single && index == 5)) {
     // uni_comp_ref_p2, comp_ref_p2 and single_ref_p5.
-    context = GetReferenceContext(block, {kReferenceFrameLast3},
-                                  {kReferenceFrameGolden});
+    context =
+        GetReferenceContext(block, kReferenceFrameLast3, kReferenceFrameLast3,
+                            kReferenceFrameGolden, kReferenceFrameGolden);
   } else if ((type == kCompoundReferenceBidirectional && index == 0) ||
              (is_single && index == 3)) {
     // comp_ref and single_ref_p3.
     context =
-        GetReferenceContext(block, {kReferenceFrameLast, kReferenceFrameLast2},
-                            {kReferenceFrameLast3, kReferenceFrameGolden});
+        GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameLast2,
+                            kReferenceFrameLast3, kReferenceFrameGolden);
   } else if ((type == kCompoundReferenceBidirectional && index == 1) ||
              (is_single && index == 4)) {
     // comp_ref_p1 and single_ref_p4.
-    context = GetReferenceContext(block, {kReferenceFrameLast},
-                                  {kReferenceFrameLast2});
+    context =
+        GetReferenceContext(block, kReferenceFrameLast, kReferenceFrameLast,
+                            kReferenceFrameLast2, kReferenceFrameLast2);
   } else if ((is_single && index == 2) || (is_backward && index == 0)) {
     // single_ref_p2 and comp_bwdref.
     context = GetReferenceContext(
-        block, {kReferenceFrameBackward, kReferenceFrameAlternate2},
-        {kReferenceFrameAlternate});
+        block, kReferenceFrameBackward, kReferenceFrameAlternate2,
+        kReferenceFrameAlternate, kReferenceFrameAlternate);
   } else if ((is_single && index == 6) || (is_backward && index == 1)) {
     // single_ref_p6 and comp_bwdref_p1.
-    context = GetReferenceContext(block, {kReferenceFrameBackward},
-                                  {kReferenceFrameAlternate2});
+    context = GetReferenceContext(
+        block, kReferenceFrameBackward, kReferenceFrameBackward,
+        kReferenceFrameAlternate2, kReferenceFrameAlternate2);
   }
   if (is_single) {
     // The index parameter for single references is offset by one since the spec
@@ -781,21 +811,21 @@
     if (reference_type == kCompoundReferenceUnidirectional) {
       // uni_comp_ref.
       if (reader_.ReadSymbol(
-              GetReferenceCdf(block, false, false, 0, reference_type))) {
+              GetReferenceCdf<false, false, 0>(block, reference_type))) {
         bp.reference_frame[0] = kReferenceFrameBackward;
         bp.reference_frame[1] = kReferenceFrameAlternate;
         return;
       }
       // uni_comp_ref_p1.
       if (!reader_.ReadSymbol(
-              GetReferenceCdf(block, false, false, 1, reference_type))) {
+              GetReferenceCdf<false, false, 1>(block, reference_type))) {
         bp.reference_frame[0] = kReferenceFrameLast;
         bp.reference_frame[1] = kReferenceFrameLast2;
         return;
       }
       // uni_comp_ref_p2.
       if (reader_.ReadSymbol(
-              GetReferenceCdf(block, false, false, 2, reference_type))) {
+              GetReferenceCdf<false, false, 2>(block, reference_type))) {
         bp.reference_frame[0] = kReferenceFrameLast;
         bp.reference_frame[1] = kReferenceFrameGolden;
         return;
@@ -807,26 +837,28 @@
     assert(reference_type == kCompoundReferenceBidirectional);
     // comp_ref.
     if (reader_.ReadSymbol(
-            GetReferenceCdf(block, false, false, 0, reference_type))) {
+            GetReferenceCdf<false, false, 0>(block, reference_type))) {
       // comp_ref_p2.
-      bp.reference_frame[0] = reader_.ReadSymbol(GetReferenceCdf(
-                                  block, false, false, 2, reference_type))
-                                  ? kReferenceFrameGolden
-                                  : kReferenceFrameLast3;
+      bp.reference_frame[0] =
+          reader_.ReadSymbol(
+              GetReferenceCdf<false, false, 2>(block, reference_type))
+              ? kReferenceFrameGolden
+              : kReferenceFrameLast3;
     } else {
       // comp_ref_p1.
-      bp.reference_frame[0] = reader_.ReadSymbol(GetReferenceCdf(
-                                  block, false, false, 1, reference_type))
-                                  ? kReferenceFrameLast2
-                                  : kReferenceFrameLast;
+      bp.reference_frame[0] =
+          reader_.ReadSymbol(
+              GetReferenceCdf<false, false, 1>(block, reference_type))
+              ? kReferenceFrameLast2
+              : kReferenceFrameLast;
     }
     // comp_bwdref.
-    if (reader_.ReadSymbol(GetReferenceCdf(block, false, true, 0))) {
+    if (reader_.ReadSymbol(GetReferenceCdf<false, true, 0>(block))) {
       bp.reference_frame[1] = kReferenceFrameAlternate;
     } else {
       // comp_bwdref_p1.
       bp.reference_frame[1] =
-          reader_.ReadSymbol(GetReferenceCdf(block, false, true, 1))
+          reader_.ReadSymbol(GetReferenceCdf<false, true, 1>(block))
               ? kReferenceFrameAlternate2
               : kReferenceFrameBackward;
     }
@@ -835,31 +867,31 @@
   assert(!use_compound_reference);
   bp.reference_frame[1] = kReferenceFrameNone;
   // single_ref_p1.
-  if (reader_.ReadSymbol(GetReferenceCdf(block, true, false, 1))) {
+  if (reader_.ReadSymbol(GetReferenceCdf<true, false, 1>(block))) {
     // single_ref_p2.
-    if (reader_.ReadSymbol(GetReferenceCdf(block, true, false, 2))) {
+    if (reader_.ReadSymbol(GetReferenceCdf<true, false, 2>(block))) {
       bp.reference_frame[0] = kReferenceFrameAlternate;
       return;
     }
     // single_ref_p6.
     bp.reference_frame[0] =
-        reader_.ReadSymbol(GetReferenceCdf(block, true, false, 6))
+        reader_.ReadSymbol(GetReferenceCdf<true, false, 6>(block))
             ? kReferenceFrameAlternate2
             : kReferenceFrameBackward;
     return;
   }
   // single_ref_p3.
-  if (reader_.ReadSymbol(GetReferenceCdf(block, true, false, 3))) {
+  if (reader_.ReadSymbol(GetReferenceCdf<true, false, 3>(block))) {
     // single_ref_p5.
     bp.reference_frame[0] =
-        reader_.ReadSymbol(GetReferenceCdf(block, true, false, 5))
+        reader_.ReadSymbol(GetReferenceCdf<true, false, 5>(block))
             ? kReferenceFrameGolden
             : kReferenceFrameLast3;
     return;
   }
   // single_ref_p4.
   bp.reference_frame[0] =
-      reader_.ReadSymbol(GetReferenceCdf(block, true, false, 4))
+      reader_.ReadSymbol(GetReferenceCdf<true, false, 4>(block))
           ? kReferenceFrameLast2
           : kReferenceFrameLast;
 }
@@ -883,9 +915,8 @@
     const int idx1 =
         std::min(mode_contexts.new_mv, kCompoundModeNewMvContexts - 1);
     const int context = kCompoundModeContextMap[idx0][idx1];
-    const int offset = reader_.ReadSymbol(
-        symbol_decoder_context_.compound_prediction_mode_cdf[context],
-        static_cast<int>(kNumCompoundInterPredictionModes));
+    const int offset = reader_.ReadSymbol<kNumCompoundInterPredictionModes>(
+        symbol_decoder_context_.compound_prediction_mode_cdf[context]);
     bp.y_mode =
         static_cast<PredictionMode>(kPredictionModeNearestNearestMv + offset);
     return;
@@ -916,10 +947,12 @@
       *block.bp->prediction_parameters;
   prediction_parameters.ref_mv_index = 0;
   if (bp.y_mode != kPredictionModeNewMv &&
-      bp.y_mode != kPredictionModeNewNewMv && !HasNearMv(bp.y_mode)) {
+      bp.y_mode != kPredictionModeNewNewMv &&
+      !kPredictionModeHasNearMvMask.Contains(bp.y_mode)) {
     return;
   }
-  const int start = static_cast<int>(HasNearMv(bp.y_mode));
+  const int start =
+      static_cast<int>(kPredictionModeHasNearMvMask.Contains(bp.y_mode));
   prediction_parameters.ref_mv_index = start;
   for (int i = start; i < start + 2; ++i) {
     if (prediction_parameters.ref_mv_count <= i + 1) continue;
@@ -940,7 +973,7 @@
   prediction_parameters.inter_intra_mode = kNumInterIntraModes;
   prediction_parameters.is_wedge_inter_intra = false;
   if (bp.skip_mode || !sequence_header_.enable_interintra_compound ||
-      is_compound || !IsInterIntraModeAllowed(block.size)) {
+      is_compound || !kIsInterIntraModeAllowedMask.Contains(block.size)) {
     return;
   }
   // kSizeGroup[block.size] is guaranteed to be non-zero because of the block
@@ -952,10 +985,10 @@
     prediction_parameters.inter_intra_mode = kNumInterIntraModes;
     return;
   }
-  prediction_parameters
-      .inter_intra_mode = static_cast<InterIntraMode>(reader_.ReadSymbol(
-      symbol_decoder_context_.inter_intra_mode_cdf[kSizeGroup[block.size] - 1],
-      static_cast<int>(kNumInterIntraModes)));
+  prediction_parameters.inter_intra_mode =
+      static_cast<InterIntraMode>(reader_.ReadSymbol<kNumInterIntraModes>(
+          symbol_decoder_context_
+              .inter_intra_mode_cdf[kSizeGroup[block.size] - 1]));
   bp.reference_frame[1] = kReferenceFrameIntra;
   prediction_parameters.angle_delta[kPlaneTypeY] = 0;
   prediction_parameters.angle_delta[kPlaneTypeUV] = 0;
@@ -964,12 +997,12 @@
       symbol_decoder_context_.is_wedge_inter_intra_cdf[block.size]);
   if (!prediction_parameters.is_wedge_inter_intra) return;
   prediction_parameters.wedge_index =
-      reader_.ReadSymbol(symbol_decoder_context_.wedge_index_cdf[block.size],
-                         kWedgeIndexSymbolCount);
+      reader_.ReadSymbol<kWedgeIndexSymbolCount>(
+          symbol_decoder_context_.wedge_index_cdf[block.size]);
   prediction_parameters.wedge_sign = 0;
 }
 
-bool Tile::IsScaled(ReferenceFrameType type) {
+bool Tile::IsScaled(ReferenceFrameType type) const {
   const int index =
       frame_header_.reference_frame_index[type - kReferenceFrameLast];
   const int x_scale = ((reference_frames_[index]->upscaled_width()
@@ -991,7 +1024,7 @@
   const auto global_motion_type =
       frame_header_.global_motion[bp.reference_frame[0]].type;
   if (bp.skip_mode || !frame_header_.is_motion_mode_switchable ||
-      kBlockWidthPixels[block.size] < 8 || kBlockHeightPixels[block.size] < 8 ||
+      IsBlockDimension4(block.size) ||
       (frame_header_.force_integer_mv == 0 &&
        (bp.y_mode == kPredictionModeGlobalMv ||
         bp.y_mode == kPredictionModeGlobalGlobalMv) &&
@@ -1017,9 +1050,9 @@
             : kMotionModeSimple;
     return;
   }
-  prediction_parameters.motion_mode = static_cast<MotionMode>(
-      reader_.ReadSymbol(symbol_decoder_context_.motion_mode_cdf[block.size],
-                         static_cast<int>(kNumMotionModes)));
+  prediction_parameters.motion_mode =
+      static_cast<MotionMode>(reader_.ReadSymbol<kNumMotionModes>(
+          symbol_decoder_context_.motion_mode_cdf[block.size]));
 }
 
 uint16_t* Tile::GetIsExplicitCompoundTypeCdf(const Block& block) {
@@ -1087,16 +1120,15 @@
           reader_.ReadSymbol(GetIsExplicitCompoundTypeCdf(block));
     }
     if (bp.is_explicit_compound_type) {
-      if (kWedgeBits[block.size] == 0) {
-        prediction_parameters.compound_prediction_type =
-            kCompoundPredictionTypeDiffWeighted;
-      } else {
+      if (kIsWedgeCompoundModeAllowed.Contains(block.size)) {
         // Only kCompoundPredictionTypeWedge and
         // kCompoundPredictionTypeDiffWeighted are signaled explicitly.
         prediction_parameters.compound_prediction_type =
             static_cast<CompoundPredictionType>(reader_.ReadSymbol(
-                symbol_decoder_context_.compound_type_cdf[block.size],
-                static_cast<int>(kNumExplicitCompoundPredictionTypes)));
+                symbol_decoder_context_.compound_type_cdf[block.size]));
+      } else {
+        prediction_parameters.compound_prediction_type =
+            kCompoundPredictionTypeDiffWeighted;
       }
     } else {
       if (sequence_header_.enable_jnt_comp) {
@@ -1113,9 +1145,9 @@
     }
     if (prediction_parameters.compound_prediction_type ==
         kCompoundPredictionTypeWedge) {
-      prediction_parameters.wedge_index = reader_.ReadSymbol(
-          symbol_decoder_context_.wedge_index_cdf[block.size],
-          kWedgeIndexSymbolCount);
+      prediction_parameters.wedge_index =
+          reader_.ReadSymbol<kWedgeIndexSymbolCount>(
+              symbol_decoder_context_.wedge_index_cdf[block.size]);
       prediction_parameters.wedge_sign =
           static_cast<int>(reader_.ReadLiteral(1));
     } else if (prediction_parameters.compound_prediction_type ==
@@ -1179,18 +1211,16 @@
     }
     return;
   }
-  const bool is_block_larger_than_8x8 =
-      std::min(kBlockWidthPixels[block.size], kBlockHeightPixels[block.size]) >=
-      8;
   bool interpolation_filter_present = true;
   if (bp.skip_mode ||
       block.bp->prediction_parameters->motion_mode == kMotionModeLocalWarp) {
     interpolation_filter_present = false;
-  } else if (is_block_larger_than_8x8 && bp.y_mode == kPredictionModeGlobalMv) {
+  } else if (!IsBlockDimension4(block.size) &&
+             bp.y_mode == kPredictionModeGlobalMv) {
     interpolation_filter_present =
         frame_header_.global_motion[bp.reference_frame[0]].type ==
         kGlobalMotionTransformationTypeTranslation;
-  } else if (is_block_larger_than_8x8 &&
+  } else if (!IsBlockDimension4(block.size) &&
              bp.y_mode == kPredictionModeGlobalGlobalMv) {
     interpolation_filter_present =
         frame_header_.global_motion[bp.reference_frame[0]].type ==
@@ -1201,9 +1231,9 @@
   for (int i = 0; i < (sequence_header_.enable_dual_filter ? 2 : 1); ++i) {
     bp.interpolation_filter[i] =
         interpolation_filter_present
-            ? static_cast<InterpolationFilter>(reader_.ReadSymbol(
-                  GetInterpolationFilterCdf(block, i),
-                  static_cast<int>(kNumExplicitInterpolationFilters)))
+            ? static_cast<InterpolationFilter>(
+                  reader_.ReadSymbol<kNumExplicitInterpolationFilters>(
+                      GetInterpolationFilterCdf(block, i)))
             : kInterpolationFilterEightTap;
   }
   if (!sequence_header_.enable_dual_filter) {
@@ -1250,9 +1280,11 @@
     return false;
   }
   ReadCdef(block);
-  ReadQuantizerIndexDelta(block);
-  ReadLoopFilterDelta(block);
-  read_deltas_ = false;
+  if (read_deltas_) {
+    ReadQuantizerIndexDelta(block);
+    ReadLoopFilterDelta(block);
+    read_deltas_ = false;
+  }
   ReadIsInter(block);
   return bp.is_inter ? ReadInterBlockModeInfo(block)
                      : ReadIntraBlockModeInfo(block, /*intra_y_mode=*/false);
diff --git a/libgav1/src/tile/bitstream/palette.cc b/libgav1/src/tile/bitstream/palette.cc
index 9c9bd4e..133b016 100644
--- a/libgav1/src/tile/bitstream/palette.cc
+++ b/libgav1/src/tile/bitstream/palette.cc
@@ -8,6 +8,7 @@
 #include "src/obu_parser.h"
 #include "src/symbol_decoder_context.h"
 #include "src/tile.h"
+#include "src/utils/bit_mask_set.h"
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
 #include "src/utils/entropy_decoder.h"
@@ -16,14 +17,9 @@
 namespace libgav1 {
 namespace {
 
-const int kNumPaletteNeighbors = 3;
-const uint8_t kPaletteColorHashMultiplier[kNumPaletteNeighbors] = {1, 2, 2};
-const int kPaletteColorIndexContext[kPaletteColorIndexSymbolCount + 1] = {
-    -1, -1, 0, -1, -1, 4, 3, 2, 1};
-
 // Add |value| to the |cache| if it doesn't already exist.
-inline void MaybeAddToPaletteCache(uint16_t value, uint16_t* const cache,
-                                   int* const n) {
+void MaybeAddToPaletteCache(uint16_t value, uint16_t* const cache,
+                            int* const n) {
   assert(cache != nullptr);
   assert(n != nullptr);
   if (*n > 0 && value == cache[*n - 1]) return;
@@ -170,8 +166,7 @@
 
 void Tile::ReadPaletteModeInfo(const Block& block) {
   BlockParameters& bp = *block.bp;
-  if (IsBlockSmallerThan8x8(block.size) || kBlockWidthPixels[block.size] > 64 ||
-      kBlockHeightPixels[block.size] > 64 ||
+  if (IsBlockSmallerThan8x8(block.size) || block.size > kBlock64x64 ||
       !frame_header_.allow_screen_content_tools) {
     bp.palette_mode_info.size[kPlaneTypeY] = 0;
     bp.palette_mode_info.size[kPlaneTypeUV] = 0;
@@ -186,9 +181,8 @@
     if (has_palette_y) {
       bp.palette_mode_info.size[kPlaneTypeY] =
           kMinPaletteSize +
-          reader_.ReadSymbol(
-              symbol_decoder_context_.palette_y_size_cdf[block_size_context],
-              kPaletteSizeSymbolCount);
+          reader_.ReadSymbol<kPaletteSizeSymbolCount>(
+              symbol_decoder_context_.palette_y_size_cdf[block_size_context]);
       ReadPaletteColors(block, kPlaneY);
     }
   }
@@ -201,60 +195,86 @@
     if (has_palette_uv) {
       bp.palette_mode_info.size[kPlaneTypeUV] =
           kMinPaletteSize +
-          reader_.ReadSymbol(
-              symbol_decoder_context_.palette_uv_size_cdf[block_size_context],
-              kPaletteSizeSymbolCount);
+          reader_.ReadSymbol<kPaletteSizeSymbolCount>(
+              symbol_decoder_context_.palette_uv_size_cdf[block_size_context]);
       ReadPaletteColors(block, kPlaneU);
     }
   }
 }
 
-int Tile::GetPaletteColorContext(const Block& block, PlaneType plane_type,
-                                 int row, int column, int palette_size,
-                                 uint8_t color_order[kMaxPaletteSize]) {
-  for (int i = 0; i < kMaxPaletteSize; ++i) {
-    color_order[i] = i;
-  }
-  int scores[kMaxPaletteSize] = {};
+void Tile::PopulatePaletteColorContexts(
+    const Block& block, PlaneType plane_type, int i, int start, int end,
+    uint8_t color_order[kMaxPaletteSquare][kMaxPaletteSize],
+    uint8_t color_context[kMaxPaletteSquare]) {
   const PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
-  if (row > 0 && column > 0) {
-    ++scores[prediction_parameters
-                 .color_index_map[plane_type][row - 1][column - 1]];
-  }
-  if (row > 0) {
-    scores[prediction_parameters
-               .color_index_map[plane_type][row - 1][column]] += 2;
-  }
-  if (column > 0) {
-    scores[prediction_parameters
-               .color_index_map[plane_type][row][column - 1]] += 2;
-  }
-  // Move the top 3 scores (largest first) and the corresponding color_order
-  // entry to the front of the array.
-  for (int i = 0; i < kNumPaletteNeighbors; ++i) {
-    const auto max_element =
-        std::max_element(scores + i, scores + palette_size);
-    const auto max_score = *max_element;
-    const auto max_index = static_cast<int>(std::distance(scores, max_element));
-    if (max_index != i) {
-      const uint8_t max_color_order = color_order[max_index];
-      for (int j = max_index; j > i; --j) {
-        scores[j] = scores[j - 1];
-        color_order[j] = color_order[j - 1];
+  for (int column = start, counter = 0; column >= end; --column, ++counter) {
+    const int row = i - column;
+    assert(row > 0 || column > 0);
+    const uint8_t top =
+        (row > 0)
+            ? prediction_parameters.color_index_map[plane_type][row - 1][column]
+            : 0;
+    const uint8_t left =
+        (column > 0)
+            ? prediction_parameters.color_index_map[plane_type][row][column - 1]
+            : 0;
+    uint8_t index_mask;
+    static_assert(kMaxPaletteSize <= 8, "");
+    int index;
+    if (column <= 0) {
+      color_context[counter] = 0;
+      color_order[counter][0] = top;
+      index_mask = 1 << top;
+      index = 1;
+    } else if (row <= 0) {
+      color_context[counter] = 0;
+      color_order[counter][0] = left;
+      index_mask = 1 << left;
+      index = 1;
+    } else {
+      const uint8_t top_left =
+          prediction_parameters
+              .color_index_map[plane_type][row - 1][column - 1];
+      index_mask = (1 << top) | (1 << left) | (1 << top_left);
+      if (top == left && top == top_left) {
+        color_context[counter] = 4;
+        color_order[counter][0] = top;
+        index = 1;
+      } else if (top == left) {
+        color_context[counter] = 3;
+        color_order[counter][0] = top;
+        color_order[counter][1] = top_left;
+        index = 2;
+      } else if (top == top_left) {
+        color_context[counter] = 2;
+        color_order[counter][0] = top_left;
+        color_order[counter][1] = left;
+        index = 2;
+      } else if (left == top_left) {
+        color_context[counter] = 2;
+        color_order[counter][0] = top_left;
+        color_order[counter][1] = top;
+        index = 2;
+      } else {
+        color_context[counter] = 1;
+        color_order[counter][0] = std::min(top, left);
+        color_order[counter][1] = std::max(top, left);
+        color_order[counter][2] = top_left;
+        index = 3;
       }
-      scores[i] = max_score;
-      color_order[i] = max_color_order;
+    }
+    // Even though only the first |palette_size| entries of this array are ever
+    // used, it is faster to populate all 8 because of the vectorization of the
+    // constant-sized loop.
+    for (uint8_t j = 0; j < kMaxPaletteSize; ++j) {
+      if (BitMaskSet::MaskContainsValue(index_mask, j)) continue;
+      color_order[counter][index++] = j;
     }
   }
-  int context = 0;
-  for (int i = 0; i < kNumPaletteNeighbors; ++i) {
-    context += scores[i] * kPaletteColorHashMultiplier[i];
-  }
-  return kPaletteColorIndexContext[context];
 }
 
-void Tile::ReadPaletteTokens(const Block& block) {
+bool Tile::ReadPaletteTokens(const Block& block) {
   const PaletteModeInfo& palette_mode_info = block.bp->palette_mode_info;
   PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
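
For illustration, a standalone restatement of the interior-position classification used by PopulatePaletteColorContexts() above (hypothetical helper; assumes both a top and a left neighbor exist):

#include <algorithm>
#include <cstdint>

// Returns the cdf context (1..4) and fills the leading entries of the color
// order from the top, left and top-left neighbor indices; the remaining
// entries are then filled with the unused palette indices in increasing
// order, as in the loop over |index_mask| above.
int ClassifyPaletteColorContext(uint8_t top, uint8_t left, uint8_t top_left,
                                uint8_t color_order[3]) {
  if (top == left && top == top_left) {
    color_order[0] = top;
    return 4;
  }
  if (top == left) {
    color_order[0] = top;
    color_order[1] = top_left;
    return 3;
  }
  if (top == top_left || left == top_left) {
    color_order[0] = top_left;
    color_order[1] = (top == top_left) ? left : top;
    return 2;
  }
  color_order[0] = std::min(top, left);
  color_order[1] = std::max(top, left);
  color_order[2] = top_left;
  return 1;
}
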
@@ -283,24 +303,29 @@
         screen_width += 2;
       }
     }
-    uint8_t color_order[kMaxPaletteSize];
+    if (!prediction_parameters.color_index_map[plane_type].Reset(
+            block_height, block_width, /*zero_initialize=*/false)) {
+      return false;
+    }
     int first_value = 0;
     reader_.DecodeUniform(palette_size, &first_value);
     prediction_parameters.color_index_map[plane_type][0][0] = first_value;
     for (int i = 1; i < screen_height + screen_width - 1; ++i) {
-      for (int j = std::min(i, screen_width - 1);
-           j >= std::max(0, i - screen_height + 1); --j) {
-        const int context =
-            GetPaletteColorContext(block, static_cast<PlaneType>(plane_type),
-                                   i - j, j, palette_size, color_order);
-        assert(context >= 0);
+      const int start = std::min(i, screen_width - 1);
+      const int end = std::max(0, i - screen_height + 1);
+      uint8_t color_order[kMaxPaletteSquare][kMaxPaletteSize];
+      uint8_t color_context[kMaxPaletteSquare];
+      PopulatePaletteColorContexts(block, static_cast<PlaneType>(plane_type), i,
+                                   start, end, color_order, color_context);
+      for (int j = start, counter = 0; j >= end; --j, ++counter) {
         uint16_t* const cdf =
             symbol_decoder_context_
-                .palette_color_index_cdf[plane_type][palette_size -
-                                                     kMinPaletteSize][context];
+                .palette_color_index_cdf[plane_type]
+                                        [palette_size - kMinPaletteSize]
+                                        [color_context[counter]];
         const int color_order_index = reader_.ReadSymbol(cdf, palette_size);
         prediction_parameters.color_index_map[plane_type][i - j][j] =
-            color_order[color_order_index];
+            color_order[counter][color_order_index];
       }
     }
     if (screen_width < block_width) {
@@ -319,6 +344,7 @@
           block_width);
     }
   }
+  return true;
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/tile/bitstream/partition.cc b/libgav1/src/tile/bitstream/partition.cc
index 0709b2f..3a31095 100644
--- a/libgav1/src/tile/bitstream/partition.cc
+++ b/libgav1/src/tile/bitstream/partition.cc
@@ -12,12 +12,11 @@
 namespace libgav1 {
 namespace {
 
-inline uint16_t InverseCdfProbability(uint16_t probability) {
+uint16_t InverseCdfProbability(uint16_t probability) {
   return kCdfMaxProbability - probability;
 }
 
-inline uint16_t CdfElementProbability(const uint16_t* const cdf,
-                                      uint8_t element) {
+uint16_t CdfElementProbability(const uint16_t* const cdf, uint8_t element) {
   return (element > 0 ? cdf[element - 1] : uint16_t{kCdfMaxProbability}) -
          cdf[element];
 }
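
For illustration, a worked example of the inverse-CDF representation these helpers operate on (assuming kCdfMaxProbability is 1 << 15, i.e. AV1's 15-bit probabilities; the names below are hypothetical):

#include <cstdint>

// A 3-symbol alphabet with probabilities 1/2, 1/4, 1/4; the CDF stores
// P(symbol > i) scaled to 15 bits.
constexpr uint16_t kExampleMaxProbability = 1 << 15;
constexpr uint16_t kExampleCdf[3] = {16384, 8192, 0};

// Mirrors CdfElementProbability() above:
//   element 0: 32768 - 16384 = 16384 (1/2)
//   element 1: 16384 -  8192 =  8192 (1/4)
//   element 2:  8192 -     0 =  8192 (1/4)
uint16_t ExampleElementProbability(int element) {
  return (element > 0 ? kExampleCdf[element - 1] : kExampleMaxProbability) -
         kExampleCdf[element];
}
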
diff --git a/libgav1/src/tile/bitstream/transform_size.cc b/libgav1/src/tile/bitstream/transform_size.cc
index f11fb10..36867f1 100644
--- a/libgav1/src/tile/bitstream/transform_size.cc
+++ b/libgav1/src/tile/bitstream/transform_size.cc
@@ -1,7 +1,6 @@
 #include <algorithm>
 #include <cstdint>
 #include <cstring>
-#include <vector>
 
 #include "src/dsp/constants.h"
 #include "src/obu_parser.h"
@@ -13,14 +12,30 @@
 #include "src/utils/constants.h"
 #include "src/utils/entropy_decoder.h"
 #include "src/utils/segmentation.h"
+#include "src/utils/stack.h"
 #include "src/utils/types.h"
 
 namespace libgav1 {
 namespace {
 
-const uint8_t kMaxVariableTransformTreeDepth = 2;
+constexpr uint8_t kMaxVariableTransformTreeDepth = 2;
+// Max_Tx_Depth array from section 5.11.5 in the spec with the following
+// modification: if the element is not zero, one is subtracted from it. That
+// is the only way in which this array is used.
+constexpr int kTxDepthCdfIndex[kMaxBlockSizes] = {
+    0, 0, 1, 0, 0, 1, 2, 1, 1, 1, 2, 3, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3};
 
-inline TransformSize GetSquareTransformSize(uint8_t pixels) {
+constexpr TransformSize kMaxTransformSizeRectangle[kMaxBlockSizes] = {
+    kTransformSize4x4,   kTransformSize4x8,   kTransformSize4x16,
+    kTransformSize8x4,   kTransformSize8x8,   kTransformSize8x16,
+    kTransformSize8x32,  kTransformSize16x4,  kTransformSize16x8,
+    kTransformSize16x16, kTransformSize16x32, kTransformSize16x64,
+    kTransformSize32x8,  kTransformSize32x16, kTransformSize32x32,
+    kTransformSize32x64, kTransformSize64x16, kTransformSize64x32,
+    kTransformSize64x64, kTransformSize64x64, kTransformSize64x64,
+    kTransformSize64x64};
+
+TransformSize GetSquareTransformSize(uint8_t pixels) {
   switch (pixels) {
     case 128:
     case 64:
@@ -71,7 +86,7 @@
   }
   const TransformSize max_rect_tx_size = kMaxTransformSizeRectangle[block.size];
   const bool allow_select = !bp.skip || !bp.is_inter;
-  if (block.size <= kBlock4x4 || !allow_select ||
+  if (block.size == kBlock4x4 || !allow_select ||
       frame_header_.tx_mode != kTxModeSelect) {
     return max_rect_tx_size;
   }
@@ -87,16 +102,16 @@
           : 0;
   const auto context = static_cast<int>(top_width >= max_tx_width) +
                        static_cast<int>(left_height >= max_tx_height);
-  const int max_tx_depth = kMaxTransformDepth[block.size];
-  const int cdf_index = (max_tx_depth > 0) ? max_tx_depth - 1 : 0;
-  const int symbol_count = (cdf_index == 0) ? 2 : 3;
+  const int cdf_index = kTxDepthCdfIndex[block.size];
+  const int symbol_count = 3 - static_cast<int>(cdf_index == 0);
   const int tx_depth = reader_.ReadSymbol(
       symbol_decoder_context_.tx_depth_cdf[cdf_index][context], symbol_count);
+  assert(tx_depth < 3);
   TransformSize tx_size = max_rect_tx_size;
-  for (int i = 0; i < tx_depth; ++i) {
-    tx_size = kSplitTransformSize[tx_size];
-  }
-  return tx_size;
+  if (tx_depth == 0) return tx_size;
+  tx_size = kSplitTransformSize[tx_size];
+  if (tx_depth == 1) return tx_size;
+  return kSplitTransformSize[tx_size];
 }
 
 void Tile::ReadVariableTransformTree(const Block& block, int row4x4,
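
For illustration, the depth-to-size mapping of the unrolled return path above, written as a standalone helper (hypothetical name; reuses the existing TransformSize and kSplitTransformSize declarations):

// tx_depth is 0, 1 or 2 (per the assert above), so at most two splits apply.
TransformSize ApplyTransformDepth(TransformSize max_rect_tx_size,
                                  int tx_depth) {
  TransformSize tx_size = max_rect_tx_size;
  if (tx_depth == 0) return tx_size;
  tx_size = kSplitTransformSize[tx_size];
  if (tx_depth == 1) return tx_size;
  return kSplitTransformSize[tx_size];
}
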
@@ -108,24 +123,22 @@
                              TransformSizeToSquareTransformIndex(max_tx_size)) *
                             6;
 
-  std::vector<TransformTreeNode> stack;
   // Branching factor is 4 and maximum depth is 2. So the maximum stack size
-  // necessary is 8.
-  stack.reserve(8);
-  stack.emplace_back(row4x4, column4x4, tx_size, 0);
+  // necessary is (4 - 1) + 4 = 7.
+  Stack<TransformTreeNode, 7> stack;
+  stack.Push(TransformTreeNode(column4x4, row4x4, tx_size, 0));
 
-  while (!stack.empty()) {
-    TransformTreeNode node = stack.back();
-    stack.pop_back();
-    const int tx_width4x4 = DivideBy4(kTransformWidth[node.tx_size]);
-    const int tx_height4x4 = DivideBy4(kTransformHeight[node.tx_size]);
+  while (!stack.Empty()) {
+    TransformTreeNode node = stack.Pop();
+    const int tx_width4x4 = kTransformWidth4x4[node.tx_size];
+    const int tx_height4x4 = kTransformHeight4x4[node.tx_size];
     if (node.tx_size != kTransformSize4x4 &&
         node.depth != kMaxVariableTransformTreeDepth) {
-      const auto top = static_cast<int>(
-          GetTopTransformWidth(block, node.row4x4, node.column4x4, false) <
-          kTransformWidth[node.tx_size]);
+      const auto top =
+          static_cast<int>(GetTopTransformWidth(block, node.y, node.x, false) <
+                           kTransformWidth[node.tx_size]);
       const auto left = static_cast<int>(
-          GetLeftTransformHeight(block, node.row4x4, node.column4x4, false) <
+          GetLeftTransformHeight(block, node.y, node.x, false) <
           kTransformHeight[node.tx_size]);
       const int context =
           static_cast<int>(max_tx_size > kTransformSize8x8 &&
@@ -136,20 +149,20 @@
       // tx_split.
       if (reader_.ReadSymbol(symbol_decoder_context_.tx_split_cdf[context])) {
         const TransformSize sub_tx_size = kSplitTransformSize[node.tx_size];
-        const int step_width4x4 = DivideBy4(kTransformWidth[sub_tx_size]);
-        const int step_height4x4 = DivideBy4(kTransformHeight[sub_tx_size]);
+        const int step_width4x4 = kTransformWidth4x4[sub_tx_size];
+        const int step_height4x4 = kTransformHeight4x4[sub_tx_size];
         // The loops have to run in reverse order because we use a stack for
         // DFS.
         for (int i = tx_height4x4 - step_height4x4; i >= 0;
              i -= step_height4x4) {
           for (int j = tx_width4x4 - step_width4x4; j >= 0;
                j -= step_width4x4) {
-            if (node.row4x4 + i >= frame_header_.rows4x4 ||
-                node.column4x4 + j >= frame_header_.columns4x4) {
+            if (node.y + i >= frame_header_.rows4x4 ||
+                node.x + j >= frame_header_.columns4x4) {
               continue;
             }
-            stack.emplace_back(node.row4x4 + i, node.column4x4 + j, sub_tx_size,
-                               node.depth + 1);
+            stack.Push(TransformTreeNode(node.x + j, node.y + i, sub_tx_size,
+                                         node.depth + 1));
           }
         }
         continue;
@@ -158,10 +171,10 @@
     // tx_split is false.
     for (int i = 0; i < tx_height4x4; ++i) {
       static_assert(sizeof(TransformSize) == 1, "");
-      memset(&inter_transform_sizes_[node.row4x4 + i][node.column4x4],
-             node.tx_size, tx_width4x4);
+      memset(&inter_transform_sizes_[node.y + i][node.x], node.tx_size,
+             tx_width4x4);
     }
-    block_parameters_holder_.Find(node.row4x4, node.column4x4)->transform_size =
+    block_parameters_holder_.Find(node.y, node.x)->transform_size =
         node.tx_size;
   }
 }
@@ -174,8 +187,8 @@
       bp.is_inter && !bp.skip &&
       !frame_header_.segmentation.lossless[bp.segment_id]) {
     const TransformSize max_tx_size = kMaxTransformSizeRectangle[block.size];
-    const int tx_width4x4 = kTransformWidth[max_tx_size] / 4;
-    const int tx_height4x4 = kTransformHeight[max_tx_size] / 4;
+    const int tx_width4x4 = kTransformWidth4x4[max_tx_size];
+    const int tx_height4x4 = kTransformHeight4x4[max_tx_size];
     for (int row = block.row4x4; row < block.row4x4 + block_height4x4;
          row += tx_height4x4) {
       for (int column = block.column4x4;
diff --git a/libgav1/src/tile/prediction.cc b/libgav1/src/tile/prediction.cc
index b180516..6759838 100644
--- a/libgav1/src/tile/prediction.cc
+++ b/libgav1/src/tile/prediction.cc
@@ -15,6 +15,7 @@
 #include "src/prediction_mask.h"
 #include "src/tile.h"
 #include "src/utils/array_2d.h"
+#include "src/utils/bit_mask_set.h"
 #include "src/utils/block_parameters_holder.h"
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
@@ -27,7 +28,8 @@
 namespace libgav1 {
 namespace {
 
-const int kAngleStep = 3;
+constexpr int kObmcBufferSize = 4096;  // 64x64
+constexpr int kAngleStep = 3;
 constexpr int kPredictionModeToAngle[kIntraPredictionModesUV] = {
     0, 90, 180, 45, 135, 113, 157, 203, 67, 0, 0, 0, 0};
 
@@ -39,33 +41,22 @@
 // The values for directional and dc modes are not used since the left/top
 // requirement for those modes depends on the prediction angle and the type of dc
 // mode.
-constexpr uint8_t kPredictionModeNeeds[kIntraPredictionModesY] = {
-    0,                       // kPredictionModeDc
-    kNeedsTop,               // kPredictionModeVertical
-    kNeedsLeft,              // kPredictionModeHorizontal
-    kNeedsTop,               // kPredictionModeD45
-    kNeedsLeft | kNeedsTop,  // kPredictionModeD135
-    kNeedsLeft | kNeedsTop,  // kPredictionModeD113
-    kNeedsLeft | kNeedsTop,  // kPredictionModeD157
-    kNeedsLeft,              // kPredictionModeD203
-    kNeedsTop,               // kPredictionModeD67
-    kNeedsLeft | kNeedsTop,  // kPredictionModeSmooth
-    kNeedsLeft | kNeedsTop,  // kPredictionModeSmoothVertical
-    kNeedsLeft | kNeedsTop,  // kPredictionModeSmoothHorizontal
-    kNeedsLeft | kNeedsTop   // kPredictionModePaeth
+constexpr BitMaskSet kPredictionModeNeedsMask[kIntraPredictionModesY] = {
+    BitMaskSet(0),                      // kPredictionModeDc
+    BitMaskSet(kNeedsTop),              // kPredictionModeVertical
+    BitMaskSet(kNeedsLeft),             // kPredictionModeHorizontal
+    BitMaskSet(kNeedsTop),              // kPredictionModeD45
+    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeD135
+    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeD113
+    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeD157
+    BitMaskSet(kNeedsLeft),             // kPredictionModeD203
+    BitMaskSet(kNeedsTop),              // kPredictionModeD67
+    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeSmooth
+    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeSmoothVertical
+    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeSmoothHorizontal
+    BitMaskSet(kNeedsLeft, kNeedsTop)   // kPredictionModePaeth
 };
 
-const int kBlendFromAbove = 0;
-const int kBlendFromLeft = 1;
-constexpr uint8_t kObmcMask2[2] = {45, 64};
-constexpr uint8_t kObmcMask4[4] = {39, 50, 59, 64};
-constexpr uint8_t kObmcMask8[8] = {36, 42, 48, 53, 57, 61, 64, 64};
-constexpr uint8_t kObmcMask16[16] = {34, 37, 40, 43, 46, 49, 52, 54,
-                                     56, 58, 60, 61, 64, 64, 64, 64};
-constexpr uint8_t kObmcMask32[32] = {33, 35, 36, 38, 40, 41, 43, 44, 45, 47, 48,
-                                     50, 51, 52, 53, 55, 56, 57, 58, 59, 60, 60,
-                                     61, 62, 64, 64, 64, 64, 64, 64, 64, 64};
-
 int16_t GetDirectionalIntraPredictorDerivative(const int angle) {
   assert(angle >= 3);
   assert(angle <= 87);
@@ -98,21 +89,6 @@
   }
 }
 
-const uint8_t* GetObmcMask(int length) {
-  switch (length) {
-    case 2:
-      return kObmcMask2;
-    case 4:
-      return kObmcMask4;
-    case 8:
-      return kObmcMask8;
-    case 16:
-      return kObmcMask16;
-    default:
-      return kObmcMask32;
-  }
-}
-
 // 7.11.2.9.
 int GetIntraEdgeFilterStrength(int width, int height, int filter_type,
                                int delta) {
@@ -236,16 +212,18 @@
 }
 
 // 7.11.3.2.
-void SetInterRoundingBits(const bool is_compound, const int bitdepth,
-                          uint8_t round_bits[2],
-                          uint8_t* const post_round_bits) {
+void GetInterRoundingBits(const bool is_compound, const int bitdepth,
+                          uint8_t round_bits[2]) {
   round_bits[0] = 3;
   round_bits[1] = is_compound ? 7 : 11;
+#if LIBGAV1_MAX_BITDEPTH == 12
   if (bitdepth == 12) {
     round_bits[0] += 2;
     if (!is_compound) round_bits[1] -= 2;
   }
-  *post_round_bits = 2 * kFilterBits - round_bits[0] - round_bits[1];
+#else
+  static_cast<void>(bitdepth);
+#endif
 }
 
 uint8_t* GetStartPoint(Array2DView<uint8_t>* const buffer, const int plane,
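
For illustration, the values produced by GetInterRoundingBits() above, restated as a standalone sketch (hypothetical free function; behavior copied from the patch, with the 12-bit adjustment applied unconditionally):

#include <cstdint>

void ExampleInterRoundingBits(bool is_compound, int bitdepth,
                              uint8_t round_bits[2]) {
  round_bits[0] = 3;
  round_bits[1] = is_compound ? 7 : 11;
  if (bitdepth == 12) {
    round_bits[0] += 2;
    if (!is_compound) round_bits[1] -= 2;
  }
}
// Resulting values:
//   8/10-bit: {3, 11} single reference, {3, 7} compound.
//   12-bit:   {5, 9}  single reference, {5, 7} compound.
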
@@ -262,31 +240,36 @@
   return &buffer[plane][y][x];
 }
 
-inline int GetPixelPositionFromHighScale(int start, int step, int offset) {
+int GetPixelPositionFromHighScale(int start, int step, int offset) {
   return (start + step * offset) >> kScaleSubPixelBits;
 }
 
-}  // namespace
-
-bool Tile::BlockDecoded(const Block& block, Plane plane, int row4x4,
-                        int column4x4, bool has_top_or_left) const {
-  if (!has_top_or_left) return false;
-  if (row4x4 >= 0 && column4x4 >= 0) {
-    return block.sb_buffer->block_decoded[plane][row4x4][column4x4];
-  }
-  if (row4x4 < 0) {
-    return column4x4 < (block.sb_buffer->block_decoded_width_threshold >>
-                        SubsamplingX(plane));
-  }
-  assert(column4x4 < 0);
-  const int sb_size4x4 =
-      kNum4x4BlocksWide[sequence_header_.use_128x128_superblock ? kBlock128x128
-                                                                : kBlock64x64];
-  return row4x4 < (block.sb_buffer->block_decoded_height_threshold >>
-                   SubsamplingY(plane)) &&
-         row4x4 != (sb_size4x4 >> SubsamplingY(plane));
+dsp::MaskBlendFunc GetMaskBlendFunc(const dsp::Dsp& dsp,
+                                    InterIntraMode inter_intra_mode,
+                                    bool is_wedge_inter_intra,
+                                    int subsampling_x, int subsampling_y) {
+  const int is_inter_intra =
+      static_cast<int>(inter_intra_mode != kNumInterIntraModes);
+  return (is_inter_intra == 1 && !is_wedge_inter_intra)
+             ? dsp.mask_blend[0][is_inter_intra]
+             : dsp.mask_blend[subsampling_x + subsampling_y][is_inter_intra];
 }
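
The pair of indices GetMaskBlendFunc() selects from the two-dimensional mask_blend table can be sketched on its own. MaskBlendIndex below is a hypothetical standalone helper, shown only to make the row/column selection explicit:

    #include <cassert>
    #include <utility>

    // Row: 0 for non-wedge inter-intra, otherwise the chroma subsampling sum.
    // Column: whether the blend is inter-intra.
    std::pair<int, int> MaskBlendIndex(bool is_inter_intra, bool is_wedge,
                                       int subsampling_x, int subsampling_y) {
      if (is_inter_intra && !is_wedge) return {0, 1};
      return {subsampling_x + subsampling_y, static_cast<int>(is_inter_intra)};
    }

    int main() {
      // Non-wedge inter-intra always uses row 0, regardless of subsampling.
      assert((MaskBlendIndex(true, false, 1, 1) == std::pair<int, int>{0, 1}));
      // Wedge inter-intra with 4:2:0 subsampling uses row 2.
      assert((MaskBlendIndex(true, true, 1, 1) == std::pair<int, int>{2, 1}));
      // Plain compound prediction with 4:2:2 subsampling uses row 1.
      assert((MaskBlendIndex(false, false, 1, 0) == std::pair<int, int>{1, 0}));
      return 0;
    }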
 
+void PopulatePredictionMaskFromWedgeMask(const uint8_t* wedge_mask,
+                                         int wedge_mask_stride,
+                                         int prediction_width,
+                                         int prediction_height,
+                                         uint8_t* prediction_mask,
+                                         int prediction_mask_stride) {
+  for (int y = 0; y < prediction_height; ++y) {
+    memcpy(prediction_mask, wedge_mask, prediction_width);
+    prediction_mask += prediction_mask_stride;
+    wedge_mask += wedge_mask_stride;
+  }
+}
+
+}  // namespace
+
 template <typename Pixel>
 void Tile::IntraPrediction(const Block& block, Plane plane, int x, int y,
                            bool has_left, bool has_top, bool has_top_right,
@@ -294,14 +277,12 @@
                            TransformSize tx_size) {
   const int width = 1 << kTransformWidthLog2[tx_size];
   const int height = 1 << kTransformHeightLog2[tx_size];
-  const int x_shift =
-      (plane == kPlaneY) ? 0 : sequence_header_.color_config.subsampling_x;
-  const int y_shift =
-      (plane == kPlaneY) ? 0 : sequence_header_.color_config.subsampling_y;
+  const int x_shift = subsampling_x_[plane];
+  const int y_shift = subsampling_y_[plane];
   const int max_x = (MultiplyBy4(frame_header_.columns4x4) >> x_shift) - 1;
   const int max_y = (MultiplyBy4(frame_header_.rows4x4) >> y_shift) - 1;
-  alignas(16) Pixel top_row_data[160] = {};
-  alignas(16) Pixel left_column_data[160] = {};
+  alignas(kMaxAlignment) Pixel top_row_data[160] = {};
+  alignas(kMaxAlignment) Pixel left_column_data[160] = {};
   // Some predictors use |top_row_data| and |left_column_data| with a negative
   // offset to access pixels to the top-left of the current block. So have some
   // space before the arrays to allow populating those without having to move
@@ -322,7 +303,7 @@
                     kAngleStep
           : 0;
   const bool needs_top = use_filter_intra ||
-                         ((kPredictionModeNeeds[mode] & kNeedsTop) != 0) ||
+                         kPredictionModeNeedsMask[mode].Contains(kNeedsTop) ||
                          (is_directional_mode && prediction_angle < 180) ||
                          (mode == kPredictionModeDc && has_top);
   Array2DView<Pixel> buffer(buffer_[plane].rows(),
@@ -346,7 +327,7 @@
     }
   }
   const bool needs_left = use_filter_intra ||
-                          ((kPredictionModeNeeds[mode] & kNeedsLeft) != 0) ||
+                          kPredictionModeNeedsMask[mode].Contains(kNeedsLeft) ||
                           (is_directional_mode && prediction_angle > 90) ||
                           (mode == kPredictionModeDc && has_left);
   if (needs_left) {
@@ -404,7 +385,11 @@
                                               TransformSize tx_size);
 #endif
 
-bool Tile::UsesSmoothPrediction(int row, int column, Plane plane) const {
+constexpr BitMaskSet kPredictionModeSmoothMask(kPredictionModeSmooth,
+                                               kPredictionModeSmoothHorizontal,
+                                               kPredictionModeSmoothVertical);
+
+bool Tile::IsSmoothPrediction(int row, int column, Plane plane) const {
   const BlockParameters& bp = *block_parameters_holder_.Find(row, column);
   PredictionMode mode;
   if (plane == kPlaneY) {
@@ -413,14 +398,12 @@
     if (bp.reference_frame[0] > kReferenceFrameIntra) return false;
     mode = bp.uv_mode;
   }
-  return mode == kPredictionModeSmooth ||
-         mode == kPredictionModeSmoothHorizontal ||
-         mode == kPredictionModeSmoothVertical;
+  return kPredictionModeSmoothMask.Contains(mode);
 }
 
 int Tile::GetIntraEdgeFilterType(const Block& block, Plane plane) const {
-  const int subsampling_x = SubsamplingX(plane);
-  const int subsampling_y = SubsamplingY(plane);
+  const int subsampling_x = subsampling_x_[plane];
+  const int subsampling_y = subsampling_y_[plane];
   if ((plane == kPlaneY && block.top_available) ||
       (plane != kPlaneY && block.TopAvailableChroma())) {
     const int row =
@@ -429,7 +412,7 @@
     const int column =
         block.column4x4 +
         static_cast<int>(subsampling_x != 0 && (block.column4x4 & 1) == 0);
-    if (UsesSmoothPrediction(row, column, plane)) return 1;
+    if (IsSmoothPrediction(row, column, plane)) return 1;
   }
   if ((plane == kPlaneY && block.left_available) ||
       (plane != kPlaneY && block.LeftAvailableChroma())) {
@@ -438,7 +421,7 @@
     const int column =
         block.column4x4 - 1 -
         static_cast<int>(subsampling_x != 0 && (block.column4x4 & 1) != 0);
-    if (UsesSmoothPrediction(row, column, plane)) return 1;
+    if (IsSmoothPrediction(row, column, plane)) return 1;
   }
   return 0;
 }
@@ -544,6 +527,8 @@
                             buffer_[plane].columns() / sizeof(Pixel),
                             reinterpret_cast<Pixel*>(&buffer_[plane][0][0]));
   for (int row = 0; row < tx_height; ++row) {
+    assert(block.bp->prediction_parameters
+               ->color_index_map[plane_type][y4 + row] != nullptr);
     for (int column = 0; column < tx_width; ++column) {
       buffer[start_y + row][start_x + column] =
           palette[block.bp->prediction_parameters
@@ -565,20 +550,20 @@
 void Tile::ChromaFromLumaPrediction(const Block& block, const Plane plane,
                                     const int start_x, const int start_y,
                                     const TransformSize tx_size) {
-  const int subsampling_x = SubsamplingX(plane);
-  const int subsampling_y = SubsamplingY(plane);
+  const int subsampling_x = subsampling_x_[plane];
+  const int subsampling_y = subsampling_y_[plane];
   const PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
-  const int luma_y = start_y << subsampling_y;
-  const int luma_x = start_x << subsampling_x;
   Array2DView<Pixel> y_buffer(
       buffer_[kPlaneY].rows(), buffer_[kPlaneY].columns() / sizeof(Pixel),
       reinterpret_cast<Pixel*>(&buffer_[kPlaneY][0][0]));
   if (!block.sb_buffer->cfl_luma_buffer_valid) {
+    const int luma_x = start_x << subsampling_x;
+    const int luma_y = start_y << subsampling_y;
     dsp_.cfl_subsamplers[tx_size][subsampling_x + subsampling_y](
         block.sb_buffer->cfl_luma_buffer,
-        prediction_parameters.max_luma_width - (start_x << subsampling_x),
-        prediction_parameters.max_luma_height - (start_y << subsampling_y),
+        prediction_parameters.max_luma_width - luma_x,
+        prediction_parameters.max_luma_height - luma_y,
         reinterpret_cast<uint8_t*>(&y_buffer[luma_y][luma_x]),
         buffer_[kPlaneY].columns());
     block.sb_buffer->cfl_luma_buffer_valid = true;
@@ -602,15 +587,15 @@
     const TransformSize tx_size);
 #endif
 
-void Tile::InterIntraPrediction(
+bool Tile::InterIntraPrediction(
     uint16_t* prediction[2], const ptrdiff_t prediction_stride,
     const uint8_t* const prediction_mask,
     const ptrdiff_t prediction_mask_stride,
     const PredictionParameters& prediction_parameters,
     const int prediction_width, const int prediction_height,
-    const int subsampling_x, const int subsampling_y,
-    const uint8_t post_round_bits, uint8_t* const dest,
+    const int subsampling_x, const int subsampling_y, uint8_t* const dest,
     const ptrdiff_t dest_stride) {
+  assert(prediction_mask != nullptr);
   assert(prediction_parameters.compound_prediction_type ==
              kCompoundPredictionTypeIntra ||
          prediction_parameters.compound_prediction_type ==
@@ -628,7 +613,7 @@
     if (!intra_prediction.Reset(prediction_height, prediction_width)) {
       LIBGAV1_DLOG(ERROR,
                    "Can't allocate memory for the intra prediction block.");
-      return;
+      return false;
     }
     uint8_t* dest_ptr = dest;
     for (int r = 0; r < prediction_height; ++r) {
@@ -643,70 +628,127 @@
     prediction[1] = reinterpret_cast<uint16_t*>(dest);
     intra_stride = dest_stride / sizeof(uint16_t);
   }
-  dsp_.mask_blend(prediction[0], prediction_stride, prediction[1], intra_stride,
-                  prediction_mask, prediction_mask_stride, prediction_width,
-                  prediction_height, subsampling_x, subsampling_y,
-                  prediction_parameters.inter_intra_mode != kNumInterIntraModes,
-                  prediction_parameters.is_wedge_inter_intra, post_round_bits,
-                  dest, dest_stride);
+  GetMaskBlendFunc(dsp_, prediction_parameters.inter_intra_mode,
+                   prediction_parameters.is_wedge_inter_intra, subsampling_x,
+                   subsampling_y)(prediction[0], prediction_stride,
+                                  prediction[1], intra_stride, prediction_mask,
+                                  prediction_mask_stride, prediction_width,
+                                  prediction_height, dest, dest_stride);
+  return true;
 }
 
-void Tile::CompoundInterPrediction(
-    const Block& block, uint16_t* prediction[2],
-    const ptrdiff_t prediction_stride, const ptrdiff_t prediction_mask_stride,
-    const int prediction_width, const int prediction_height, const Plane plane,
-    const int subsampling_x, const int subsampling_y, const int bitdepth,
-    const int candidate_row, const int candidate_column, uint8_t* dest,
-    const ptrdiff_t dest_stride, const uint8_t post_round_bits) {
+bool Tile::AllocatePredictionMask(SuperBlockBuffer* sb_buffer) {
+  if (sb_buffer->prediction_mask != nullptr) {
+    return true;
+  }
+  sb_buffer->prediction_mask = MakeAlignedUniquePtr<uint8_t>(
+      16, kMaxSuperBlockSizeInPixels * kMaxSuperBlockSizeInPixels);
+  if (sb_buffer->prediction_mask == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Allocation of prediction_mask failed.");
+    return false;
+  }
+  return true;
+}
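
AllocatePredictionMask() allocates the 128x128 mask lazily and keeps it alive so later blocks in the superblock can reuse it. A rough standalone equivalent of that pattern is sketched below; it uses std::aligned_alloc in place of libgav1's MakeAlignedUniquePtr helper, and EnsureMask/FreeDeleter are illustrative names only:

    #include <cstdint>
    #include <cstdlib>
    #include <memory>

    struct FreeDeleter {
      void operator()(void* p) const { std::free(p); }
    };
    using AlignedBuffer = std::unique_ptr<uint8_t[], FreeDeleter>;

    // Lazily allocate a 16-byte-aligned mask buffer once, then reuse it.
    bool EnsureMask(AlignedBuffer* mask, std::size_t size) {
      if (*mask != nullptr) return true;  // Already allocated; reuse it.
      // std::aligned_alloc requires |size| to be a multiple of the alignment;
      // 128 * 128 = 16384 satisfies that.
      void* const p = std::aligned_alloc(/*alignment=*/16, size);
      if (p == nullptr) return false;
      mask->reset(static_cast<uint8_t*>(p));
      return true;
    }

    int main() {
      AlignedBuffer mask;
      return EnsureMask(&mask, 128 * 128) ? 0 : 1;
    }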
+
+bool Tile::CompoundInterPrediction(
+    const Block& block, const ptrdiff_t prediction_stride,
+    const ptrdiff_t prediction_mask_stride, const int prediction_width,
+    const int prediction_height, const Plane plane, const int subsampling_x,
+    const int subsampling_y, const int bitdepth, const int candidate_row,
+    const int candidate_column, uint8_t* dest, const ptrdiff_t dest_stride) {
   const PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
+  uint16_t* prediction[2] = {block.sb_buffer->prediction_buffer[0].get(),
+                             block.sb_buffer->prediction_buffer[1].get()};
   switch (prediction_parameters.compound_prediction_type) {
     case kCompoundPredictionTypeWedge:
-      dsp_.mask_blend(
+      GetMaskBlendFunc(dsp_, prediction_parameters.inter_intra_mode,
+                       prediction_parameters.is_wedge_inter_intra,
+                       subsampling_x, subsampling_y)(
           prediction[0], prediction_stride, prediction[1], prediction_stride,
-          block.sb_buffer->prediction_mask, prediction_mask_stride,
-          prediction_width, prediction_height, subsampling_x, subsampling_y,
-          prediction_parameters.inter_intra_mode != kNumInterIntraModes,
-          prediction_parameters.is_wedge_inter_intra, post_round_bits, dest,
-          dest_stride);
+          block.sb_buffer->prediction_mask.get(), prediction_mask_stride,
+          prediction_width, prediction_height, dest, dest_stride);
       break;
     case kCompoundPredictionTypeDiffWeighted:
       if (plane == kPlaneY) {
+        if (!AllocatePredictionMask(block.sb_buffer)) {
+          return false;
+        }
         GenerateWeightMask(
             prediction[0], prediction_stride, prediction[1], prediction_stride,
-            prediction_parameters.mask_is_inverse, post_round_bits,
-            prediction_width, prediction_height, bitdepth,
-            block.sb_buffer->prediction_mask, prediction_mask_stride);
+            prediction_parameters.mask_is_inverse, prediction_width,
+            prediction_height, bitdepth, block.sb_buffer->prediction_mask.get(),
+            prediction_mask_stride);
       }
-      dsp_.mask_blend(
+      GetMaskBlendFunc(dsp_, prediction_parameters.inter_intra_mode,
+                       prediction_parameters.is_wedge_inter_intra,
+                       subsampling_x, subsampling_y)(
           prediction[0], prediction_stride, prediction[1], prediction_stride,
-          block.sb_buffer->prediction_mask, prediction_mask_stride,
-          prediction_width, prediction_height, subsampling_x, subsampling_y,
-          prediction_parameters.inter_intra_mode != kNumInterIntraModes,
-          prediction_parameters.is_wedge_inter_intra, post_round_bits, dest,
-          dest_stride);
+          block.sb_buffer->prediction_mask.get(), prediction_mask_stride,
+          prediction_width, prediction_height, dest, dest_stride);
       break;
     case kCompoundPredictionTypeDistance:
       DistanceWeightedPrediction(
           prediction[0], prediction_stride, prediction[1], prediction_stride,
           prediction_width, prediction_height, candidate_row, candidate_column,
-          dest, dest_stride, post_round_bits);
+          dest, dest_stride);
       break;
     case kCompoundPredictionTypeAverage:
       dsp_.average_blend(prediction[0], prediction_stride, prediction[1],
-                         prediction_stride, post_round_bits, prediction_width,
-                         prediction_height, dest, dest_stride);
+                         prediction_stride, prediction_width, prediction_height,
+                         dest, dest_stride);
       break;
     default:
       assert(false && "This is not a compound type.\n");
-      return;
+      return false;
   }
+  return true;
 }
 
-bool Tile::InterPrediction(const Block& block, Plane plane, int x, int y,
-                           int prediction_width, int prediction_height,
-                           int candidate_row, int candidate_column,
-                           bool* const is_local_valid,
+GlobalMotion* Tile::GetWarpParams(
+    const Block& block, const Plane plane, const int prediction_width,
+    const int prediction_height,
+    const PredictionParameters& prediction_parameters,
+    const ReferenceFrameType reference_type, bool* const is_local_valid,
+    GlobalMotion* const global_motion_params,
+    GlobalMotion* const local_warp_params) const {
+  if (prediction_width < 8 || prediction_height < 8 ||
+      frame_header_.force_integer_mv == 1) {
+    return nullptr;
+  }
+  if (plane == kPlaneY) {
+    *is_local_valid =
+        prediction_parameters.motion_mode == kMotionModeLocalWarp &&
+        WarpEstimation(
+            prediction_parameters.num_warp_samples, DivideBy4(prediction_width),
+            DivideBy4(prediction_height), block.row4x4, block.column4x4,
+            block.bp->mv[0], prediction_parameters.warp_estimate_candidates,
+            local_warp_params) &&
+        SetupShear(local_warp_params);
+  }
+  if (prediction_parameters.motion_mode == kMotionModeLocalWarp &&
+      *is_local_valid) {
+    return local_warp_params;
+  }
+  if (!IsScaled(reference_type)) {
+    GlobalMotionTransformationType global_motion_type =
+        (reference_type != kReferenceFrameIntra)
+            ? global_motion_params->type
+            : kNumGlobalMotionTransformationTypes;
+    const bool is_global_valid =
+        IsGlobalMvBlock(block.bp->y_mode, global_motion_type, block.size) &&
+        SetupShear(global_motion_params);
+    // Valid global motion type implies reference type can't be intra.
+    assert(!is_global_valid || reference_type != kReferenceFrameIntra);
+    if (is_global_valid) return global_motion_params;
+  }
+  return nullptr;
+}
+
+bool Tile::InterPrediction(const Block& block, const Plane plane, const int x,
+                           const int y, const int prediction_width,
+                           const int prediction_height, int candidate_row,
+                           int candidate_column, bool* const is_local_valid,
                            GlobalMotion* const local_warp_params) {
   const int bitdepth = sequence_header_.color_config.bitdepth;
   const BlockParameters& bp = *block.bp;
@@ -716,73 +758,54 @@
       bp_reference.reference_frame[1] > kReferenceFrameIntra;
   const bool is_inter_intra =
       bp.is_inter && bp.reference_frame[1] == kReferenceFrameIntra;
-  const ptrdiff_t prediction_stride = kMaxSuperBlockSizeInPixels;
-  AlignedUniquePtr<uint16_t> prediction[2] = {
-      MakeAlignedUniquePtr<uint16_t>(
-          8, kMaxSuperBlockSizeInPixels * prediction_stride),
-      AlignedUniquePtr<uint16_t>()};
-  if (prediction[0] == nullptr) {
-    LIBGAV1_DLOG(ERROR,
-                 "Can't allocate memory for the first prediction block.");
-    return false;
+  const ptrdiff_t prediction_stride = prediction_width;
+
+  const size_t prediction_buffer_size = prediction_height * prediction_stride;
+  if (block.sb_buffer->prediction_buffer_size[0] < prediction_buffer_size) {
+    block.sb_buffer->prediction_buffer[0] =
+        MakeAlignedUniquePtr<uint16_t>(8, prediction_buffer_size);
+    if (block.sb_buffer->prediction_buffer[0] == nullptr) {
+      block.sb_buffer->prediction_buffer_size[0] = 0;
+      LIBGAV1_DLOG(ERROR,
+                   "Can't allocate memory for the first prediction block.");
+      return false;
+    }
+    block.sb_buffer->prediction_buffer_size[0] = prediction_buffer_size;
   }
-  if (is_compound) {
-    prediction[1] = MakeAlignedUniquePtr<uint16_t>(
-        8, kMaxSuperBlockSizeInPixels * prediction_stride);
-    if (prediction[1] == nullptr) {
+  if (is_compound &&
+      block.sb_buffer->prediction_buffer_size[1] < prediction_buffer_size) {
+    block.sb_buffer->prediction_buffer[1] =
+        MakeAlignedUniquePtr<uint16_t>(8, prediction_buffer_size);
+    if (block.sb_buffer->prediction_buffer[1] == nullptr) {
+      block.sb_buffer->prediction_buffer_size[1] = 0;
       LIBGAV1_DLOG(ERROR,
                    "Can't allocate memory for the second prediction block.");
       return false;
     }
+    block.sb_buffer->prediction_buffer_size[1] = prediction_buffer_size;
   }
+
+  const PredictionParameters& prediction_parameters =
+      *block.bp->prediction_parameters;
   uint8_t* const dest = GetStartPoint(buffer_, plane, x, y, bitdepth);
   const ptrdiff_t dest_stride = buffer_[plane].columns();  // In bytes.
   uint8_t round_bits[2];
-  uint8_t post_round_bits;
-  SetInterRoundingBits(is_compound, sequence_header_.color_config.bitdepth,
-                       round_bits, &post_round_bits);
-  const PredictionParameters& prediction_parameters =
-      *block.bp->prediction_parameters;
-  if (plane == kPlaneY) {
-    *is_local_valid =
-        prediction_parameters.motion_mode == kMotionModeLocalWarp &&
-        WarpEstimation(
-            prediction_parameters.num_warp_samples, DivideBy4(prediction_width),
-            DivideBy4(prediction_height), block.row4x4, block.column4x4,
-            bp.mv[0], prediction_parameters.warp_estimate_candidates,
-            local_warp_params) &&
-        SetupShear(local_warp_params);
-  }
-
+  GetInterRoundingBits(is_compound, sequence_header_.color_config.bitdepth,
+                       round_bits);
   for (int index = 0; index < 1 + static_cast<int>(is_compound); ++index) {
     const ReferenceFrameType reference_type =
         bp_reference.reference_frame[index];
     GlobalMotion global_motion_params =
         frame_header_.global_motion[reference_type];
-    GlobalMotionTransformationType global_motion_type =
-        (reference_type != kReferenceFrameIntra)
-            ? global_motion_params.type
-            : kNumGlobalMotionTransformationTypes;
-    const bool is_global_valid =
-        IsGlobalMvBlock(bp.y_mode, global_motion_type, block.size) &&
-        SetupShear(&global_motion_params);
-    // Valid global motion type implies reference type can't be intra.
-    assert(!is_global_valid || reference_type != kReferenceFrameIntra);
-    GlobalMotion* warp_params = nullptr;
-    if (prediction_width < 8 || prediction_height < 8 ||
-        frame_header_.force_integer_mv == 1) {
-      warp_params = nullptr;
-    } else if (prediction_parameters.motion_mode == kMotionModeLocalWarp &&
-               *is_local_valid) {
-      warp_params = local_warp_params;
-    } else if (is_global_valid && !IsScaled(reference_type)) {
-      warp_params = &global_motion_params;
-    }
+    GlobalMotion* warp_params =
+        GetWarpParams(block, plane, prediction_width, prediction_height,
+                      prediction_parameters, reference_type, is_local_valid,
+                      &global_motion_params, local_warp_params);
     if (warp_params != nullptr) {
-      if (!BlockWarpProcess(block, plane, index, prediction_width,
-                            prediction_height, prediction[index].get(),
-                            prediction_stride, warp_params, round_bits,
-                            is_compound, is_inter_intra, dest, dest_stride)) {
+      if (!BlockWarpProcess(block, plane, index, x, y, prediction_width,
+                            prediction_height, prediction_stride, warp_params,
+                            round_bits, is_compound, is_inter_intra, dest,
+                            dest_stride)) {
         return false;
       }
     } else {
@@ -791,17 +814,17 @@
               ? -1
               : frame_header_.reference_frame_index[reference_type -
                                                     kReferenceFrameLast];
-      BlockInterPrediction(plane, reference_index, bp_reference.mv[index], x, y,
-                           prediction_width, prediction_height, candidate_row,
-                           candidate_column, prediction[index].get(),
-                           prediction_stride, round_bits, is_compound,
-                           is_inter_intra, dest, dest_stride);
+      BlockInterPrediction(
+          plane, reference_index, bp_reference.mv[index], x, y,
+          prediction_width, prediction_height, candidate_row, candidate_column,
+          block.sb_buffer->prediction_buffer[index].get(), prediction_stride,
+          round_bits, is_compound, is_inter_intra, dest, dest_stride);
     }
   }
 
   const ptrdiff_t prediction_mask_stride = kMaxSuperBlockSizeInPixels;
-  const int subsampling_x = SubsamplingX(plane);
-  const int subsampling_y = SubsamplingY(plane);
+  const int subsampling_x = subsampling_x_[plane];
+  const int subsampling_y = subsampling_y_[plane];
   if (prediction_parameters.compound_prediction_type ==
           kCompoundPredictionTypeWedge &&
       plane == kPlaneY) {
@@ -815,29 +838,32 @@
     const int offset = block_size_index * wedge_mask_stride_3 +
                        prediction_parameters.wedge_sign * wedge_mask_stride_2 +
                        prediction_parameters.wedge_index * wedge_mask_stride_1;
-    const int stride = kWedgeMaskMasterSize;
-    uint8_t* mask_ptr = block.sb_buffer->prediction_mask;
-    const uint8_t* wedge_mask_ptr = &wedge_masks_[offset];
-    for (int y = 0; y < prediction_height; ++y) {
-      memcpy(mask_ptr, wedge_mask_ptr, prediction_width);
-      mask_ptr += kMaxSuperBlockSizeInPixels;
-      wedge_mask_ptr += stride;
+    if (!AllocatePredictionMask(block.sb_buffer)) {
+      return false;
     }
+    PopulatePredictionMaskFromWedgeMask(
+        &wedge_masks_[offset], kWedgeMaskMasterSize, prediction_width,
+        prediction_height, block.sb_buffer->prediction_mask.get(),
+        kMaxSuperBlockSizeInPixels);
   } else if (prediction_parameters.compound_prediction_type ==
              kCompoundPredictionTypeIntra) {
+    if (plane == kPlaneY) {
+      if (!AllocatePredictionMask(block.sb_buffer)) {
+        return false;
+      }
+    }
     GenerateInterIntraMask(prediction_parameters.inter_intra_mode,
                            prediction_width, prediction_height,
-                           block.sb_buffer->prediction_mask,
+                           block.sb_buffer->prediction_mask.get(),
                            prediction_mask_stride);
   }
 
-  uint16_t* prediction_ptr[2] = {prediction[0].get(), prediction[1].get()};
+  bool ok = true;
   if (is_compound) {
-    CompoundInterPrediction(
-        block, prediction_ptr, prediction_stride, prediction_mask_stride,
-        prediction_width, prediction_height, plane, subsampling_x,
-        subsampling_y, bitdepth, candidate_row, candidate_column, dest,
-        dest_stride, post_round_bits);
+    ok = CompoundInterPrediction(
+        block, prediction_stride, prediction_mask_stride, prediction_width,
+        prediction_height, plane, subsampling_x, subsampling_y, bitdepth,
+        candidate_row, candidate_column, dest, dest_stride);
   } else {
     if (prediction_parameters.motion_mode == kMotionModeObmc) {
       // Obmc mode is allowed only for single reference (!is_compound).
@@ -845,14 +871,17 @@
                      round_bits);
     } else if (is_inter_intra) {
       // InterIntra and obmc must be mutually exclusive.
-      InterIntraPrediction(prediction_ptr, prediction_stride,
-                           block.sb_buffer->prediction_mask,
-                           prediction_mask_stride, prediction_parameters,
-                           prediction_width, prediction_height, subsampling_x,
-                           subsampling_y, post_round_bits, dest, dest_stride);
+      uint16_t* prediction_ptr[2] = {
+          block.sb_buffer->prediction_buffer[0].get(),
+          block.sb_buffer->prediction_buffer[1].get()};
+      ok = InterIntraPrediction(
+          prediction_ptr, prediction_stride,
+          block.sb_buffer->prediction_mask.get(), prediction_mask_stride,
+          prediction_parameters, prediction_width, prediction_height,
+          subsampling_x, subsampling_y, dest, dest_stride);
     }
   }
-  return true;
+  return ok;
 }
 
 void Tile::ObmcBlockPrediction(const MotionVector& mv, const Plane plane,
@@ -860,41 +889,37 @@
                                const int height, const int x, const int y,
                                const int candidate_row,
                                const int candidate_column,
-                               const uint8_t* const mask,
-                               const int blending_direction,
+                               const ObmcDirection blending_direction,
                                const uint8_t* const round_bits) {
-  uint16_t
-      obmc_prediction[kMaxSuperBlockSizeInPixels *
-                      (kMaxSuperBlockSizeInPixels + 2 * kRestorationBorder)];
-  const int obmc_prediction_stride = width + 2 * kRestorationBorder;
   const int bitdepth = sequence_header_.color_config.bitdepth;
   // Obmc's prediction needs to be clipped before blending with above/left
   // prediction blocks.
-  uint8_t obmc_clipped_prediction[2 * kMaxSuperBlockSizeInPixels *
-                                  (kMaxSuperBlockSizeInPixels +
-                                   2 * kRestorationBorder)];
+  uint8_t obmc_clipped_prediction[kObmcBufferSize
+#if LIBGAV1_MAX_BITDEPTH >= 10
+                                  * 2
+#endif
+  ];
   const ptrdiff_t obmc_clipped_prediction_stride =
-      (bitdepth == 8) ? obmc_prediction_stride
-                      : obmc_prediction_stride * sizeof(uint16_t);
+      (bitdepth == 8) ? width : width * sizeof(uint16_t);
   BlockInterPrediction(plane, reference_frame_index, mv, x, y, width, height,
-                       candidate_row, candidate_column, obmc_prediction,
-                       obmc_prediction_stride, round_bits, false, false,
-                       obmc_clipped_prediction, obmc_clipped_prediction_stride);
+                       candidate_row, candidate_column, nullptr, width,
+                       round_bits, false, false, obmc_clipped_prediction,
+                       obmc_clipped_prediction_stride);
 
   uint8_t* const prediction = GetStartPoint(buffer_, plane, x, y, bitdepth);
   const ptrdiff_t prediction_stride = buffer_[plane].columns();
-  dsp_.obmc_blend(prediction, prediction_stride, width, height,
-                  blending_direction, mask, obmc_clipped_prediction,
-                  obmc_clipped_prediction_stride);
+  dsp_.obmc_blend[blending_direction](prediction, prediction_stride, width,
+                                      height, obmc_clipped_prediction,
+                                      obmc_clipped_prediction_stride);
 }
 
 void Tile::ObmcPrediction(const Block& block, const Plane plane,
                           const int width, const int height,
                           const uint8_t* const round_bits) {
-  const int subsampling_x = SubsamplingX(plane);
-  const int subsampling_y = SubsamplingY(plane);
+  const int subsampling_x = subsampling_x_[plane];
+  const int subsampling_y = subsampling_y_[plane];
   const BlockSize plane_block_size =
-      kPlaneResidualSize[block.size][SubsamplingX(plane)][SubsamplingY(plane)];
+      kPlaneResidualSize[block.size][subsampling_x][subsampling_y];
   assert(plane_block_size != kBlockInvalid);
   const int num4x4_wide = kNum4x4BlocksWide[block.size];
   const int num4x4_high = kNum4x4BlocksHigh[block.size];
@@ -906,6 +931,7 @@
     const int candidate_row = block.row4x4 - 1;
     const int block_start_y = MultiplyBy4(block.row4x4) >> subsampling_y;
     int column4x4 = block.column4x4;
+    const int prediction_height = std::min(height >> 1, 32 >> subsampling_y);
     for (int i = 0, step; i < num_limit && column4x4 < column4x4_max;
          column4x4 += step) {
       const int candidate_column = column4x4 | 1;
@@ -920,14 +946,12 @@
                                                 kReferenceFrameLast];
         const int prediction_width =
             std::min(width, MultiplyBy4(step) >> subsampling_x);
-        const int prediction_height =
-            std::min(height >> 1, 32 >> subsampling_y);
-        const uint8_t* mask = GetObmcMask(prediction_height);
-        ObmcBlockPrediction(
-            bp_top.mv[0], plane, candidate_reference_frame_index,
-            prediction_width, prediction_height,
-            MultiplyBy4(column4x4) >> subsampling_x, block_start_y,
-            candidate_row, candidate_column, mask, kBlendFromAbove, round_bits);
+        ObmcBlockPrediction(bp_top.mv[0], plane,
+                            candidate_reference_frame_index, prediction_width,
+                            prediction_height,
+                            MultiplyBy4(column4x4) >> subsampling_x,
+                            block_start_y, candidate_row, candidate_column,
+                            kObmcDirectionVertical, round_bits);
       }
     }
   }
@@ -939,6 +963,7 @@
     const int candidate_column = block.column4x4 - 1;
     int row4x4 = block.row4x4;
     const int block_start_x = MultiplyBy4(block.column4x4) >> subsampling_x;
+    const int prediction_width = std::min(width >> 1, 32 >> subsampling_x);
     for (int i = 0, step; i < num_limit && row4x4 < row4x4_max;
          row4x4 += step) {
       const int candidate_row = row4x4 | 1;
@@ -951,15 +976,13 @@
         const int candidate_reference_frame_index =
             frame_header_.reference_frame_index[bp_left.reference_frame[0] -
                                                 kReferenceFrameLast];
-        const int prediction_width = std::min(width >> 1, 32 >> subsampling_x);
         const int prediction_height =
             std::min(height, MultiplyBy4(step) >> subsampling_y);
-        const uint8_t* mask = GetObmcMask(prediction_width);
-        ObmcBlockPrediction(bp_left.mv[0], plane,
-                            candidate_reference_frame_index, prediction_width,
-                            prediction_height, block_start_x,
-                            MultiplyBy4(row4x4) >> subsampling_y, candidate_row,
-                            candidate_column, mask, kBlendFromLeft, round_bits);
+        ObmcBlockPrediction(
+            bp_left.mv[0], plane, candidate_reference_frame_index,
+            prediction_width, prediction_height, block_start_x,
+            MultiplyBy4(row4x4) >> subsampling_y, candidate_row,
+            candidate_column, kObmcDirectionHorizontal, round_bits);
       }
     }
   }
@@ -969,7 +992,7 @@
     uint16_t* prediction_0, ptrdiff_t prediction_stride_0,
     uint16_t* prediction_1, ptrdiff_t prediction_stride_1, const int width,
     const int height, const int candidate_row, const int candidate_column,
-    uint8_t* dest, ptrdiff_t dest_stride, const uint8_t post_round_bits) {
+    uint8_t* dest, ptrdiff_t dest_stride) {
   int distance[2];
   int weight[2];
   for (int reference = 0; reference < 2; ++reference) {
@@ -987,9 +1010,9 @@
   }
   GetDistanceWeights(distance, weight);
 
-  dsp_.distance_weighted_blend(
-      prediction_0, prediction_stride_0, prediction_1, prediction_stride_1,
-      weight[0], weight[1], post_round_bits, width, height, dest, dest_stride);
+  dsp_.distance_weighted_blend(prediction_0, prediction_stride_0, prediction_1,
+                               prediction_stride_1, weight[0], weight[1], width,
+                               height, dest, dest_stride);
 }
 
 bool Tile::GetReferenceBlockPosition(
@@ -1011,12 +1034,11 @@
   *ref_block_end_y =
       GetPixelPositionFromHighScale(start_y, step_y, height - 1) +
       kConvolveBorderRightBottom;
-  int block_height =
-      height + kConvolveBorderLeftTop + kConvolveBorderRightBottom;
   if (is_scaled) {
-    block_height = (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
-                    kScaleSubPixelBits) +
-                   kSubPixelTaps;
+    const int block_height =
+        (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
+         kScaleSubPixelBits) +
+        kSubPixelTaps;
     *ref_block_end_y = *ref_block_start_y + block_height - 1;
   }
   const bool extend_left = *ref_block_start_x < ref_start_x;
@@ -1116,12 +1138,12 @@
   int step_y;
   ScaleMotionVector(mv, plane, reference_frame_index, x, y, &start_x, &start_y,
                     &step_x, &step_y);
-  // reference_frame_index equal to -1 indicates using current frame as
-  // reference.
   const int horizontal_filter_index = bp.interpolation_filter[1];
   const int vertical_filter_index = bp.interpolation_filter[0];
-  const int subsampling_x = SubsamplingX(plane);
-  const int subsampling_y = SubsamplingY(plane);
+  const int subsampling_x = subsampling_x_[plane];
+  const int subsampling_y = subsampling_y_[plane];
+  // reference_frame_index equal to -1 indicates using current frame as
+  // reference.
   const YuvBuffer* const reference_buffer =
       (reference_frame_index == -1)
           ? current_frame_.buffer()
@@ -1223,6 +1245,7 @@
       (is_compound || is_inter_intra) ? prediction : static_cast<void*>(dest);
   const ptrdiff_t output_stride =
       (is_compound || is_inter_intra) ? prediction_stride : dest_stride;
+  assert(output != nullptr);
   dsp::ConvolveFunc convolve_func =
       is_scaled ? dsp_.convolve_scale[is_compound || is_inter_intra]
                 : dsp_.convolve[reference_frame_index == -1][is_compound]
@@ -1233,17 +1256,19 @@
     convolve_func = dsp_.convolve[0][1][1][1];
   }
   convolve_func(block_start, block_stride, horizontal_filter_index,
-                vertical_filter_index, round_bits, start_x, start_y, step_x,
+                vertical_filter_index, round_bits[1], start_x, start_y, step_x,
                 step_y, width, height, output, output_stride);
 }
 
 bool Tile::BlockWarpProcess(const Block& block, const Plane plane,
-                            const int index, const int width, const int height,
-                            uint16_t* prediction, ptrdiff_t prediction_stride,
-                            GlobalMotion* warp_params,
+                            const int index, const int block_start_x,
+                            const int block_start_y, const int width,
+                            const int height, const ptrdiff_t prediction_stride,
+                            GlobalMotion* const warp_params,
                             const uint8_t* const round_bits,
                             const bool is_compound, const bool is_inter_intra,
                             uint8_t* const dest, const ptrdiff_t dest_stride) {
+  assert(width >= 8 && height >= 8);
   const BlockParameters& bp = *block.bp;
   const int reference_frame_index =
       frame_header_.reference_frame_index[bp.reference_frame[index] -
@@ -1258,16 +1283,10 @@
   const int source_height =
       reference_frames_[reference_frame_index]->buffer()->displayed_height(
           plane);
-  const int block_start_x = MultiplyBy4(block.column4x4) >> SubsamplingX(plane);
-  const int block_start_y = MultiplyBy4(block.row4x4) >> SubsamplingY(plane);
-  const bool warp_valid = SetupShear(warp_params);
-  if (!warp_valid) {
-    LIBGAV1_DLOG(ERROR, "Invalid warp parameters.");
-    return false;
-  }
+  uint16_t* const prediction = block.sb_buffer->prediction_buffer[index].get();
   dsp_.warp(source, source_stride, source_width, source_height,
-            warp_params->params, SubsamplingX(plane), SubsamplingY(plane),
-            round_bits, block_start_x, block_start_y, width, height,
+            warp_params->params, subsampling_x_[plane], subsampling_y_[plane],
+            round_bits[1], block_start_x, block_start_y, width, height,
             warp_params->alpha, warp_params->beta, warp_params->gamma,
             warp_params->delta, prediction, prediction_stride);
   if (!is_compound && !is_inter_intra) {
diff --git a/libgav1/src/tile/tile.cc b/libgav1/src/tile/tile.cc
index 9a5c812..378410d 100644
--- a/libgav1/src/tile/tile.cc
+++ b/libgav1/src/tile/tile.cc
@@ -7,20 +7,23 @@
 #include <cstring>
 #include <memory>
 #include <new>
+#include <numeric>
 #include <type_traits>
 #include <utility>
 
 #include "src/motion_vector.h"
 #include "src/reconstruction.h"
+#include "src/utils/bit_mask_set.h"
 #include "src/utils/logging.h"
-#include "src/utils/scan.h"
 #include "src/utils/segmentation.h"
+#include "src/utils/stack.h"
 
 namespace libgav1 {
 namespace {
 
 // Import all the constants in the anonymous namespace.
 #include "src/quantizer_tables.inc"
+#include "src/scan_tables.inc"
 
 // Precision bits when scaling reference frames.
 constexpr int kReferenceScaleShift = 14;
@@ -32,36 +35,63 @@
     kQuantizerCoefficientBaseRange + kNumQuantizerBaseLevels + 1;
 constexpr int kCoeffBaseRangeMaxIterations =
     kQuantizerCoefficientBaseRange / (kCoeffBaseRangeSymbolCount - 1);
+constexpr int kEntropyContextLeft = 0;
+constexpr int kEntropyContextTop = 1;
 
-const uint8_t kAllZeroContextsByMinMax[5][5] = {{1, 2, 2, 2, 3},
-                                                {1, 4, 4, 4, 5},
-                                                {1, 4, 4, 4, 5},
-                                                {1, 4, 4, 4, 5},
-                                                {1, 4, 4, 4, 6}};
+constexpr uint8_t kAllZeroContextsByTopLeft[5][5] = {{1, 2, 2, 2, 3},
+                                                     {2, 4, 4, 4, 5},
+                                                     {2, 4, 4, 4, 5},
+                                                     {2, 4, 4, 4, 5},
+                                                     {3, 5, 5, 5, 6}};
 
 // The space complexity of DFS is O(branching_factor * max_depth). For the
 // parameter tree, branching_factor = 4 (there could be up to 4 children for
-// every node) and max_depth = 6 (to go from a 128x128 block all the way to a
-// 4x4 block).
-constexpr int kDfsStackSize = 24;
+// every node) and max_depth (excluding the root) = 5 (to go from a 128x128
+// block all the way to a 4x4 block). The worst-case stack size is 16, by
+// counting the number of 'o' nodes in the diagram:
+//
+//   |                    128x128  The highest level (corresponding to the
+//   |                             root of the tree) has no node in the stack.
+//   |-----------------+
+//   |     |     |     |
+//   |     o     o     o  64x64
+//   |
+//   |-----------------+
+//   |     |     |     |
+//   |     o     o     o  32x32    Higher levels have three nodes in the stack,
+//   |                             because we pop one node off the stack before
+//   |-----------------+           pushing its four children onto the stack.
+//   |     |     |     |
+//   |     o     o     o  16x16
+//   |
+//   |-----------------+
+//   |     |     |     |
+//   |     o     o     o  8x8
+//   |
+//   |-----------------+
+//   |     |     |     |
+//   o     o     o     o  4x4      Only the lowest level has four nodes in the
+//                                 stack.
+constexpr int kDfsStackSize = 16;
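
The count of 16 can be reproduced with a tiny simulation of the traversal described above: push a root, and on each pop push four children until the 4x4 depth is reached, tracking the peak stack occupancy. This is a standalone sketch, not decoder code:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      constexpr int kMaxDepth = 5;   // 128x128 -> 64x64 -> ... -> 4x4.
      std::vector<int> stack = {0};  // Each entry is the depth of a pending node.
      auto peak = stack.size();
      while (!stack.empty()) {
        const int depth = stack.back();
        stack.pop_back();
        if (depth == kMaxDepth) continue;  // 4x4 nodes have no children.
        for (int i = 0; i < 4; ++i) stack.push_back(depth + 1);
        peak = std::max(peak, stack.size());
      }
      std::printf("peak stack occupancy: %zu\n", peak);  // Prints 16.
      return 0;
    }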
 
 // Mask indicating whether the transform sets contain a particular transform
 // type. If |tx_type| is present in |tx_set|, then the |tx_type|th LSB is set.
-constexpr uint16_t kTransformTypeInSetMask[kNumTransformSets] = {
-    0x1, 0xE0F, 0x20F, 0xFFFF, 0xFFF, 0x201};
+constexpr BitMaskSet kTransformTypeInSetMask[kNumTransformSets] = {
+    BitMaskSet(0x1),    BitMaskSet(0xE0F), BitMaskSet(0x20F),
+    BitMaskSet(0xFFFF), BitMaskSet(0xFFF), BitMaskSet(0x201)};
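
BitMaskSet is assumed here to wrap the same membership test the raw uint16_t masks expressed: bit i is set when transform type i belongs to the set. A minimal sketch of that test (MaskContains is an illustrative name, and treating enumerator 0 as the lone member of the 0x1 DCT-only mask is an inference, not a statement about the enum):

    #include <cstdint>

    constexpr bool MaskContains(uint16_t mask, int tx_type) {
      return ((mask >> tx_type) & 1) != 0;
    }

    // The DCT-only set (mask 0x1) contains exactly transform type 0.
    static_assert(MaskContains(0x1, 0), "type 0 is in the DCT-only set");
    static_assert(!MaskContains(0x1, 1), "no other type is in the DCT-only set");

    int main() { return 0; }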
 
-const PredictionMode
+constexpr PredictionMode
     kFilterIntraModeToIntraPredictor[kNumFilterIntraPredictors] = {
         kPredictionModeDc, kPredictionModeVertical, kPredictionModeHorizontal,
         kPredictionModeD157, kPredictionModeDc};
 
 // This is computed as:
 // min(transform_width_log2, 5) + min(transform_height_log2, 5) - 4.
-const uint8_t kEobMultiSizeLookup[kNumTransformSizes] = {
+constexpr uint8_t kEobMultiSizeLookup[kNumTransformSizes] = {
     0, 1, 2, 1, 2, 3, 4, 2, 3, 4, 5, 5, 4, 5, 6, 6, 5, 6, 6};
 
 /* clang-format off */
-const uint8_t kCoeffBaseContextOffset[kNumTransformSizes][5][5] = {
+constexpr uint8_t kCoeffBaseContextOffset[kNumTransformSizes][5][5] = {
     {{0, 1, 6, 6, 0}, {1, 6, 6, 21, 0}, {6, 6, 21, 21, 0}, {6, 21, 21, 21, 0},
      {0, 0, 0, 0, 0}},
     {{0, 11, 11, 11, 0}, {11, 11, 11, 11, 0}, {6, 6, 21, 21, 0},
@@ -102,9 +132,9 @@
      {6, 21, 21, 21, 21}, {21, 21, 21, 21, 21}}};
 /* clang-format on */
 
-const uint8_t kCoeffBasePositionContextOffset[3] = {26, 31, 36};
+constexpr uint8_t kCoeffBasePositionContextOffset[3] = {26, 31, 36};
 
-const PredictionMode kInterIntraToIntraMode[kNumInterIntraModes] = {
+constexpr PredictionMode kInterIntraToIntraMode[kNumInterIntraModes] = {
     kPredictionModeDc, kPredictionModeVertical, kPredictionModeHorizontal,
     kPredictionModeSmooth};
 
@@ -127,59 +157,100 @@
     {kNumTransformSizes, kNumTransformSizes, kTransformSize64x16,
      kTransformSize64x32, kTransformSize64x64}};
 
+// Defined in section 9.3 of the spec.
+constexpr TransformType kModeToTransformType[kIntraPredictionModesUV] = {
+    kTransformTypeDctDct,   kTransformTypeDctAdst,  kTransformTypeAdstDct,
+    kTransformTypeDctDct,   kTransformTypeAdstAdst, kTransformTypeDctAdst,
+    kTransformTypeAdstDct,  kTransformTypeAdstDct,  kTransformTypeDctAdst,
+    kTransformTypeAdstAdst, kTransformTypeDctAdst,  kTransformTypeAdstDct,
+    kTransformTypeAdstAdst, kTransformTypeDctDct};
+
+// Defined in section 5.11.47 of the spec. This array does not contain an entry
+// for kTransformSetDctOnly, so the first dimension needs to be
+// |kNumTransformSets| - 1.
+constexpr TransformType kInverseTransformTypeBySet[kNumTransformSets - 1][16] =
+    {{kTransformTypeIdentityIdentity, kTransformTypeDctDct,
+      kTransformTypeIdentityDct, kTransformTypeDctIdentity,
+      kTransformTypeAdstAdst, kTransformTypeDctAdst, kTransformTypeAdstDct},
+     {kTransformTypeIdentityIdentity, kTransformTypeDctDct,
+      kTransformTypeAdstAdst, kTransformTypeDctAdst, kTransformTypeAdstDct},
+     {kTransformTypeIdentityIdentity, kTransformTypeIdentityDct,
+      kTransformTypeDctIdentity, kTransformTypeIdentityAdst,
+      kTransformTypeAdstIdentity, kTransformTypeIdentityFlipadst,
+      kTransformTypeFlipadstIdentity, kTransformTypeDctDct,
+      kTransformTypeDctAdst, kTransformTypeAdstDct, kTransformTypeDctFlipadst,
+      kTransformTypeFlipadstDct, kTransformTypeAdstAdst,
+      kTransformTypeFlipadstFlipadst, kTransformTypeFlipadstAdst,
+      kTransformTypeAdstFlipadst},
+     {kTransformTypeIdentityIdentity, kTransformTypeIdentityDct,
+      kTransformTypeDctIdentity, kTransformTypeDctDct, kTransformTypeDctAdst,
+      kTransformTypeAdstDct, kTransformTypeDctFlipadst,
+      kTransformTypeFlipadstDct, kTransformTypeAdstAdst,
+      kTransformTypeFlipadstFlipadst, kTransformTypeFlipadstAdst,
+      kTransformTypeAdstFlipadst},
+     {kTransformTypeIdentityIdentity, kTransformTypeDctDct}};
+
+// Replaces all occurrences of 64x* and *x64 with 32x* and *x32 respectively.
+constexpr TransformSize kAdjustedTransformSize[kNumTransformSizes] = {
+    kTransformSize4x4,   kTransformSize4x8,   kTransformSize4x16,
+    kTransformSize8x4,   kTransformSize8x8,   kTransformSize8x16,
+    kTransformSize8x32,  kTransformSize16x4,  kTransformSize16x8,
+    kTransformSize16x16, kTransformSize16x32, kTransformSize16x32,
+    kTransformSize32x8,  kTransformSize32x16, kTransformSize32x32,
+    kTransformSize32x32, kTransformSize32x16, kTransformSize32x32,
+    kTransformSize32x32};
+
+// ith entry of this array is computed as:
+// DivideBy2(TransformSizeToSquareTransformIndex(kTransformSizeSquareMin[i]) +
+//           TransformSizeToSquareTransformIndex(kTransformSizeSquareMax[i]) +
+//           1)
+constexpr uint8_t kTransformSizeContext[kNumTransformSizes] = {
+    0, 1, 1, 1, 1, 2, 2, 1, 2, 2, 3, 3, 2, 3, 3, 4, 3, 4, 4};
+
+constexpr int8_t kSgrProjDefaultMultiplier[2] = {-32, 31};
+
+constexpr int8_t kWienerDefaultFilter[3] = {3, -7, 15};
+
 // Maps compound prediction modes into single modes. For e.g.
 // kPredictionModeNearestNewMv will map to kPredictionModeNearestMv for index 0
 // and kPredictionModeNewMv for index 1. It is used to simplify the logic in
 // AssignMv (and avoid duplicate code). This is section 5.11.30. in the spec.
-inline PredictionMode GetSinglePredictionMode(int index,
-                                              PredictionMode y_mode) {
-  if (index == 0) {
-    if (y_mode < kPredictionModeNearestNearestMv) {
-      return y_mode;
-    }
-    if (y_mode == kPredictionModeNewNewMv ||
-        y_mode == kPredictionModeNewNearestMv ||
-        y_mode == kPredictionModeNewNearMv) {
-      return kPredictionModeNewMv;
-    }
-    if (y_mode == kPredictionModeNearestNearestMv ||
-        y_mode == kPredictionModeNearestNewMv) {
-      return kPredictionModeNearestMv;
-    }
-    if (y_mode == kPredictionModeNearNearMv ||
-        y_mode == kPredictionModeNearNewMv) {
-      return kPredictionModeNearMv;
-    }
-    return kPredictionModeGlobalMv;
+constexpr PredictionMode
+    kCompoundToSinglePredictionMode[kNumCompoundInterPredictionModes][2] = {
+        {kPredictionModeNearestMv, kPredictionModeNearestMv},
+        {kPredictionModeNearMv, kPredictionModeNearMv},
+        {kPredictionModeNearestMv, kPredictionModeNewMv},
+        {kPredictionModeNewMv, kPredictionModeNearestMv},
+        {kPredictionModeNearMv, kPredictionModeNewMv},
+        {kPredictionModeNewMv, kPredictionModeNearMv},
+        {kPredictionModeGlobalMv, kPredictionModeGlobalMv},
+        {kPredictionModeNewMv, kPredictionModeNewMv},
+};
+PredictionMode GetSinglePredictionMode(int index, PredictionMode y_mode) {
+  if (y_mode < kPredictionModeNearestNearestMv) {
+    return y_mode;
   }
-  if (y_mode == kPredictionModeNewNewMv ||
-      y_mode == kPredictionModeNearestNewMv ||
-      y_mode == kPredictionModeNearNewMv) {
-    return kPredictionModeNewMv;
-  }
-  if (y_mode == kPredictionModeNearestNearestMv ||
-      y_mode == kPredictionModeNewNearestMv) {
-    return kPredictionModeNearestMv;
-  }
-  if (y_mode == kPredictionModeNearNearMv ||
-      y_mode == kPredictionModeNewNearMv) {
-    return kPredictionModeNearMv;
-  }
-  return kPredictionModeGlobalMv;
+  const int lookup_index = y_mode - kPredictionModeNearestNearestMv;
+  assert(lookup_index >= 0);
+  return kCompoundToSinglePredictionMode[lookup_index][index];
 }
 
 // log2(dqDenom) in section 7.12.3 of the spec. We use the log2 value because
 // dqDenom is always a power of two and hence right shift can be used instead of
 // division.
+constexpr BitMaskSet kQuantizationShift2Mask(kTransformSize32x64,
+                                             kTransformSize64x32,
+                                             kTransformSize64x64);
+constexpr BitMaskSet kQuantizationShift1Mask(kTransformSize16x32,
+                                             kTransformSize16x64,
+                                             kTransformSize32x16,
+                                             kTransformSize32x32,
+                                             kTransformSize64x16);
 int GetQuantizationShift(TransformSize tx_size) {
-  const int tx_width = kTransformWidth[tx_size];
-  const int tx_height = kTransformHeight[tx_size];
-  const int max_tx_dimension = std::max(tx_width, tx_height);
-  const int min_tx_dimension = std::min(tx_width, tx_height);
-  if (max_tx_dimension == 64 && min_tx_dimension >= 32) {
+  if (kQuantizationShift2Mask.Contains(tx_size)) {
     return 2;
   }
-  if (max_tx_dimension >= 32 && min_tx_dimension >= 16) {
+  if (kQuantizationShift1Mask.Contains(tx_size)) {
     return 1;
   }
   return 0;
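
The two masks are meant to reproduce the dimension-based rule they replace. A standalone sketch of the old rule, checked against the transform sizes each mask lists (assuming the usual width-by-height reading of the size names):

    #include <algorithm>
    #include <cassert>

    // The old rule: shift 2 for 64-point sizes whose other dimension is >= 32,
    // shift 1 when max(dim) >= 32 and min(dim) >= 16, else 0.
    int OldQuantizationShift(int w, int h) {
      const int max_dim = std::max(w, h);
      const int min_dim = std::min(w, h);
      if (max_dim == 64 && min_dim >= 32) return 2;
      if (max_dim >= 32 && min_dim >= 16) return 1;
      return 0;
    }

    int main() {
      // Sizes listed in kQuantizationShift2Mask.
      assert(OldQuantizationShift(32, 64) == 2);
      assert(OldQuantizationShift(64, 32) == 2);
      assert(OldQuantizationShift(64, 64) == 2);
      // Sizes listed in kQuantizationShift1Mask.
      assert(OldQuantizationShift(16, 32) == 1);
      assert(OldQuantizationShift(16, 64) == 1);
      assert(OldQuantizationShift(32, 16) == 1);
      assert(OldQuantizationShift(32, 32) == 1);
      assert(OldQuantizationShift(64, 16) == 1);
      // Everything else falls through to 0, e.g.:
      assert(OldQuantizationShift(8, 32) == 0);
      assert(OldQuantizationShift(16, 16) == 0);
      return 0;
    }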
@@ -193,6 +264,49 @@
   return index + MultiplyBy4(index >> tx_width_log2);
 }
 
+// Returns the minimum of |length| and |max| - |start|. This is used to clamp
+// array indices when accessing arrays whose bound is equal to |max|.
+int GetNumElements(int length, int start, int max) {
+  return std::min(length, max - start);
+}
+
+// This is the same as the Max_Tx_Size_Rect array in the spec but with *x64 and 64x*
+// transforms replaced with *x32 and 32x* respectively.
+constexpr TransformSize kUVTransformSize[kMaxBlockSizes] = {
+    kTransformSize4x4,   kTransformSize4x8,   kTransformSize4x16,
+    kTransformSize8x4,   kTransformSize8x8,   kTransformSize8x16,
+    kTransformSize8x32,  kTransformSize16x4,  kTransformSize16x8,
+    kTransformSize16x16, kTransformSize16x32, kTransformSize16x32,
+    kTransformSize32x8,  kTransformSize32x16, kTransformSize32x32,
+    kTransformSize32x32, kTransformSize32x16, kTransformSize32x32,
+    kTransformSize32x32, kTransformSize32x32, kTransformSize32x32,
+    kTransformSize32x32};
+
+// 5.11.37.
+TransformSize GetTransformSize(bool lossless, BlockSize block_size, Plane plane,
+                               TransformSize tx_size, int subsampling_x,
+                               int subsampling_y) {
+  if (plane == kPlaneY) return tx_size;
+  // When |lossless| is true, |tx_size| is already kTransformSize4x4 for the Y
+  // plane, so the |lossless| special case below only has to cover the U and V
+  // planes.
+  if (lossless) return kTransformSize4x4;
+  const BlockSize plane_size =
+      kPlaneResidualSize[block_size][subsampling_x][subsampling_y];
+  assert(plane_size != kBlockInvalid);
+  return kUVTransformSize[plane_size];
+}
+
+void SetTransformType(const Tile::Block& block, int x4, int y4, int w4, int h4,
+                      TransformType tx_type,
+                      TransformType transform_types[32][32]) {
+  const int y_offset = y4 - block.row4x4;
+  const int x_offset = x4 - block.column4x4;
+  static_assert(sizeof(transform_types[0][0]) == 1, "");
+  for (int i = 0; i < h4; ++i) {
+    memset(&transform_types[y_offset + i][x_offset], tx_type, w4);
+  }
+}
+
 }  // namespace
 
 Tile::Tile(
@@ -212,11 +326,16 @@
     Array2D<int16_t>* const cdef_index,
     Array2D<TransformSize>* const inter_transform_sizes,
     const dsp::Dsp* const dsp, ThreadPool* const thread_pool,
-    ResidualBufferPool* const residual_buffer_pool)
+    ResidualBufferPool* const residual_buffer_pool,
+    BlockingCounterWithStatus* const pending_tiles)
     : number_(tile_number),
       data_(data),
       size_(size),
       read_deltas_(false),
+      subsampling_x_{0, sequence_header.color_config.subsampling_x,
+                     sequence_header.color_config.subsampling_x},
+      subsampling_y_{0, sequence_header.color_config.subsampling_y,
+                     sequence_header.color_config.subsampling_y},
       current_quantizer_index_(frame_header.quantizer.base_index),
       sequence_header_(sequence_header),
       frame_header_(frame_header),
@@ -245,21 +364,15 @@
       cdef_index_(*cdef_index),
       inter_transform_sizes_(*inter_transform_sizes),
       thread_pool_(thread_pool),
-      residual_buffer_pool_(residual_buffer_pool) {
+      residual_buffer_pool_(residual_buffer_pool),
+      pending_tiles_(pending_tiles),
+      build_bit_mask_when_parsing_(false) {
   row_ = number_ / frame_header.tile_info.tile_columns;
   column_ = number_ % frame_header.tile_info.tile_columns;
   row4x4_start_ = frame_header.tile_info.tile_row_start[row_];
   row4x4_end_ = frame_header.tile_info.tile_row_start[row_ + 1];
   column4x4_start_ = frame_header.tile_info.tile_column_start[column_];
   column4x4_end_ = frame_header.tile_info.tile_column_start[column_ + 1];
-  for (size_t i = 0; i < entropy_contexts_.size(); ++i) {
-    const int contexts_per_plane = (i == EntropyContext::kLeft)
-                                       ? frame_header_.rows4x4
-                                       : frame_header_.columns4x4;
-    if (!entropy_contexts_[i].Reset(PlaneCount(), contexts_per_plane)) {
-      LIBGAV1_DLOG(ERROR, "entropy_contexts_[%zu].Reset() failed.", i);
-    }
-  }
   const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
   const int block_width4x4_log2 = k4x4HeightLog2[SuperBlockSize()];
   superblock_rows_ =
@@ -271,17 +384,8 @@
   // superblock columns as |intra_block_copy_lag_|.
   split_parse_and_decode_ =
       thread_pool_ != nullptr && superblock_columns_ > intra_block_copy_lag_;
-  if (split_parse_and_decode_) {
-    assert(residual_buffer_pool != nullptr);
-    if (!residual_buffer_threaded_.Reset(superblock_rows_, superblock_columns_,
-                                         /*zero_initialize=*/false)) {
-      LIBGAV1_DLOG(ERROR, "residual_buffer_threaded_.Reset() failed.");
-    }
-  } else {
-    residual_buffer_ = MakeAlignedUniquePtr<uint8_t>(32, 4096 * residual_size_);
-    prediction_parameters_.reset(new (std::nothrow) PredictionParameters());
-  }
   memset(delta_lf_, 0, sizeof(delta_lf_));
+  delta_lf_all_zero_ = true;
   YuvBuffer* const buffer = current_frame->buffer();
   for (int plane = 0; plane < PlaneCount(); ++plane) {
     buffer_[plane].Reset(buffer->height(plane) + buffer->bottom_border(plane),
@@ -289,13 +393,58 @@
   }
 }
 
-bool Tile::Decode() {
+bool Tile::Init() {
+  assert(coefficient_levels_.size() == dc_categories_.size());
+  for (size_t i = 0; i < coefficient_levels_.size(); ++i) {
+    const int contexts_per_plane = (i == kEntropyContextLeft)
+                                       ? frame_header_.rows4x4
+                                       : frame_header_.columns4x4;
+    if (!coefficient_levels_[i].Reset(PlaneCount(), contexts_per_plane)) {
+      LIBGAV1_DLOG(ERROR, "coefficient_levels_[%zu].Reset() failed.", i);
+      return false;
+    }
+    if (!dc_categories_[i].Reset(PlaneCount(), contexts_per_plane)) {
+      LIBGAV1_DLOG(ERROR, "dc_categories_[%zu].Reset() failed.", i);
+      return false;
+    }
+  }
+  if (split_parse_and_decode_) {
+    assert(residual_buffer_pool_ != nullptr);
+    if (!residual_buffer_threaded_.Reset(superblock_rows_, superblock_columns_,
+                                         /*zero_initialize=*/false)) {
+      LIBGAV1_DLOG(ERROR, "residual_buffer_threaded_.Reset() failed.");
+      return false;
+    }
+  } else {
+    residual_buffer_ = MakeAlignedUniquePtr<uint8_t>(32, 4096 * residual_size_);
+    if (residual_buffer_ == nullptr) {
+      LIBGAV1_DLOG(ERROR, "Allocation of residual_buffer_ failed.");
+      return false;
+    }
+    prediction_parameters_.reset(new (std::nothrow) PredictionParameters());
+    if (prediction_parameters_ == nullptr) {
+      LIBGAV1_DLOG(ERROR, "Allocation of prediction_parameters_ failed.");
+      return false;
+    }
+  }
+  return true;
+}
+
+bool Tile::Decode(bool is_main_thread) {
+  if (!Init()) {
+    pending_tiles_->Decrement(false);
+    return false;
+  }
   if (frame_header_.use_ref_frame_mvs) {
     SetupMotionField(sequence_header_, frame_header_, current_frame_,
                      reference_frames_, motion_field_mv_, row4x4_start_,
                      row4x4_end_, column4x4_start_, column4x4_end_);
   }
   ResetLoopRestorationParams();
+  // If this is the main thread, we build the loop filter bit masks when parsing
+  // so that it happens in the current thread. This ensures that the main thread
+  // does as much work as possible.
+  build_bit_mask_when_parsing_ = is_main_thread;
   if (split_parse_and_decode_) {
     if (!ThreadedDecode()) return false;
   } else {
@@ -307,6 +456,7 @@
            column4x4 += block_width4x4) {
         if (!ProcessSuperBlock(row4x4, column4x4, block_width4x4, &sb_buffer,
                                kProcessingModeParseAndDecode)) {
+          pending_tiles_->Decrement(false);
           LIBGAV1_DLOG(ERROR, "Error decoding super block row: %d column: %d",
                        row4x4, column4x4);
           return false;
@@ -318,17 +468,22 @@
       number_ == frame_header_.tile_info.context_update_id) {
     *saved_symbol_decoder_context_ = symbol_decoder_context_;
   }
+  if (!split_parse_and_decode_) {
+    pending_tiles_->Decrement(true);
+  }
   return true;
 }
 
 bool Tile::ThreadedDecode() {
-  ThreadingParameters threading;
   {
-    std::lock_guard<std::mutex> lock(threading.mutex);
-    if (!threading.sb_state.Reset(superblock_rows_, superblock_columns_)) {
+    std::lock_guard<std::mutex> lock(threading_.mutex);
+    if (!threading_.sb_state.Reset(superblock_rows_, superblock_columns_)) {
+      pending_tiles_->Decrement(false);
       LIBGAV1_DLOG(ERROR, "threading.sb_state.Reset() failed.");
       return false;
     }
+    // Account for the parsing job.
+    ++threading_.pending_jobs;
   }
 
   const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
@@ -342,47 +497,58 @@
          column4x4 += block_width4x4, ++column_index) {
       if (!ProcessSuperBlock(row4x4, column4x4, block_width4x4, &sb_buffer,
                              kProcessingModeParseOnly)) {
-        std::lock_guard<std::mutex> lock(threading.mutex);
-        threading.abort = true;
+        std::lock_guard<std::mutex> lock(threading_.mutex);
+        threading_.abort = true;
         break;
       }
-      std::lock_guard<std::mutex> lock(threading.mutex);
-      if (threading.abort) break;
-      threading.sb_state[row_index][column_index] = kSuperBlockStateParsed;
+      std::unique_lock<std::mutex> lock(threading_.mutex);
+      if (threading_.abort) break;
+      threading_.sb_state[row_index][column_index] = kSuperBlockStateParsed;
       // Schedule the decoding of this superblock if it is allowed.
-      if (CanDecode(row_index, column_index, threading.sb_state)) {
-        ++threading.pending_jobs;
-        threading.sb_state[row_index][column_index] = kSuperBlockStateScheduled;
-        thread_pool_->Schedule([this, row_index, column_index, block_width4x4,
-                                &threading]() {
-          DecodeSuperBlock(row_index, column_index, block_width4x4, &threading);
-        });
+      if (CanDecode(row_index, column_index)) {
+        ++threading_.pending_jobs;
+        threading_.sb_state[row_index][column_index] =
+            kSuperBlockStateScheduled;
+        lock.unlock();
+        thread_pool_->Schedule(
+            [this, row_index, column_index, block_width4x4]() {
+              DecodeSuperBlock(row_index, column_index, block_width4x4);
+            });
       }
     }
-    std::lock_guard<std::mutex> lock(threading.mutex);
-    if (threading.abort) break;
+    std::lock_guard<std::mutex> lock(threading_.mutex);
+    if (threading_.abort) break;
   }
 
-  // Wait for the decode jobs to finish.
-  std::unique_lock<std::mutex> lock(threading.mutex);
-  while (threading.pending_jobs != 0) {
-    threading.pending_jobs_zero_condvar.wait(lock);
+  // We are done parsing. We can return here since the calling thread will
+  // wait for all the superblocks to be decoded.
+  //
+  // Finish using |threading_| before |pending_tiles_->Decrement()| because the
+  // Tile object could go out of scope as soon as |pending_tiles_->Decrement()|
+  // is called.
+  threading_.mutex.lock();
+  const bool no_pending_jobs = (--threading_.pending_jobs == 0);
+  const bool job_succeeded = !threading_.abort;
+  threading_.mutex.unlock();
+  if (no_pending_jobs) {
+    // We are done parsing and decoding this tile.
+    pending_tiles_->Decrement(job_succeeded);
   }
-
-  return !threading.abort;
+  return job_succeeded;
 }
 
-bool Tile::CanDecode(int row_index, int column_index,
-                     const Array2D<SuperBlockState>& sb_state) {
-  // If |sb_state| is not equal to kSuperBlockStateParsed, then return false.
-  // This is ok because if |sb_state| is equal to:
+bool Tile::CanDecode(int row_index, int column_index) const {
+  assert(row_index >= 0);
+  assert(column_index >= 0);
+  // If |threading_.sb_state[row_index][column_index]| is not equal to
+  // kSuperBlockStateParsed, then return false. This is ok because if
+  // |threading_.sb_state[row_index][column_index]| is equal to:
   //   kSuperBlockStateNone - then the superblock is not yet parsed.
   //   kSuperBlockStateScheduled - then the superblock is already scheduled for
   //                               decode.
   //   kSuperBlockStateDecoded - then the superblock has already been decoded.
-  if (row_index < 0 || column_index < 0 || row_index >= superblock_rows_ ||
-      column_index >= superblock_columns_ ||
-      sb_state[row_index][column_index] != kSuperBlockStateParsed) {
+  if (row_index >= superblock_rows_ || column_index >= superblock_columns_ ||
+      threading_.sb_state[row_index][column_index] != kSuperBlockStateParsed) {
     return false;
   }
   // First superblock has no dependencies.
@@ -392,29 +558,30 @@
   // Superblocks in the first row only depend on the superblock to the left of
   // it.
   if (row_index == 0) {
-    return sb_state[0][column_index - 1] == kSuperBlockStateDecoded;
+    return threading_.sb_state[0][column_index - 1] == kSuperBlockStateDecoded;
   }
   // All other superblocks depend on the superblock to their left (if one
   // exists) and the superblock to the top right with a lag of
   // |intra_block_copy_lag_| (if one exists).
   const int top_right_column_index =
       std::min(column_index + intra_block_copy_lag_, superblock_columns_ - 1);
-  return sb_state[row_index - 1][top_right_column_index] ==
+  return threading_.sb_state[row_index - 1][top_right_column_index] ==
              kSuperBlockStateDecoded &&
          (column_index == 0 ||
-          sb_state[row_index][column_index - 1] == kSuperBlockStateDecoded);
+          threading_.sb_state[row_index][column_index - 1] ==
+              kSuperBlockStateDecoded);
 }
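
The dependency check above implements a wavefront: a parsed superblock can be decoded once the superblock to its left and the superblock one row up, |intra_block_copy_lag_| columns to the right, are both decoded. A minimal standalone sketch of the same rule, with illustrative names rather than the library's types:

  #include <algorithm>
  #include <vector>

  enum State { kNone, kParsed, kScheduled, kDecoded };

  // Returns true if the superblock at (row, col) may be decoded given a grid
  // of per-superblock states. |lag| plays the role of intra_block_copy_lag_.
  bool CanDecodeSketch(const std::vector<std::vector<State>>& sb, int row,
                       int col, int lag) {
    const int rows = static_cast<int>(sb.size());
    const int cols = static_cast<int>(sb[0].size());
    if (row >= rows || col >= cols || sb[row][col] != kParsed) return false;
    if (row == 0 && col == 0) return true;  // First superblock: no dependency.
    if (row == 0) return sb[0][col - 1] == kDecoded;
    const int top_right = std::min(col + lag, cols - 1);
    return sb[row - 1][top_right] == kDecoded &&
           (col == 0 || sb[row][col - 1] == kDecoded);
  }
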
 
-void Tile::DecodeSuperBlock(int row_index, int column_index, int block_width4x4,
-                            ThreadingParameters* const threading) {
+void Tile::DecodeSuperBlock(int row_index, int column_index,
+                            int block_width4x4) {
   SuperBlockBuffer sb_buffer;
   const int row4x4 = row4x4_start_ + (row_index * block_width4x4);
   const int column4x4 = column4x4_start_ + (column_index * block_width4x4);
   const bool ok = ProcessSuperBlock(row4x4, column4x4, block_width4x4,
                                     &sb_buffer, kProcessingModeDecodeOnly);
-  std::lock_guard<std::mutex> lock(threading->mutex);
+  std::unique_lock<std::mutex> lock(threading_.mutex);
   if (ok) {
-    threading->sb_state[row_index][column_index] = kSuperBlockStateDecoded;
+    threading_.sb_state[row_index][column_index] = kSuperBlockStateDecoded;
     // Candidate rows and columns whose superblocks we could potentially begin
     // decoding (if allowed to do so). The candidates are:
     //   1) The superblock to the bottom-left of the current superblock with a
@@ -429,26 +596,29 @@
          ++i) {
       const int candidate_row_index = candidate_row_indices[i];
       const int candidate_column_index = candidate_column_indices[i];
-      if (!CanDecode(candidate_row_index, candidate_column_index,
-                     threading->sb_state)) {
+      if (!CanDecode(candidate_row_index, candidate_column_index)) {
         continue;
       }
-      ++threading->pending_jobs;
-      threading->sb_state[candidate_row_index][candidate_column_index] =
+      ++threading_.pending_jobs;
+      threading_.sb_state[candidate_row_index][candidate_column_index] =
           kSuperBlockStateScheduled;
+      lock.unlock();
       thread_pool_->Schedule([this, candidate_row_index, candidate_column_index,
-                              block_width4x4, threading]() {
+                              block_width4x4]() {
         DecodeSuperBlock(candidate_row_index, candidate_column_index,
-                         block_width4x4, threading);
+                         block_width4x4);
       });
+      lock.lock();
     }
   } else {
-    threading->abort = true;
+    threading_.abort = true;
   }
-  if (--threading->pending_jobs == 0) {
-    // TODO(jzern): the mutex doesn't need to be locked to signal the
-    // condition.
-    threading->pending_jobs_zero_condvar.notify_one();
+  if (--threading_.pending_jobs == 0) {
+    // The lock needs to be unlocked here in this case because the Tile object
+    // could go out of scope as soon as |pending_tiles_->Decrement()| is called.
+    lock.unlock();
+    // We are done parsing and decoding this tile.
+    pending_tiles_->Decrement(!threading_.abort);
   }
 }
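
The accounting above relies on releasing the tile's own mutex before the final |pending_tiles_->Decrement()| call, since that call can end the Tile's lifetime. A minimal sketch of the ordering, assuming a simple stand-in for the pending-tiles counter (names are illustrative):

  #include <mutex>

  struct DoneCounter {  // Stand-in for the pending-tiles counter.
    void Decrement(bool /*success*/) { /* may destroy the Worker below */ }
  };

  class Worker {
   public:
    explicit Worker(DoneCounter* done) : done_(done) {}

    void FinishJob(bool ok) {
      std::unique_lock<std::mutex> lock(mutex_);
      if (!ok) abort_ = true;
      const bool last_job = (--pending_jobs_ == 0);
      const bool succeeded = !abort_;
      // Unlock before Decrement(): once it runs, |this| (and with it |mutex_|)
      // may no longer exist.
      lock.unlock();
      if (last_job) done_->Decrement(succeeded);
    }

   private:
    std::mutex mutex_;
    int pending_jobs_ = 1;
    bool abort_ = false;
    DoneCounter* done_;
  };
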
 
@@ -460,54 +630,55 @@
 int Tile::GetTransformAllZeroContext(const Block& block, Plane plane,
                                      TransformSize tx_size, int x4, int y4,
                                      int w4, int h4) {
-  const int max_x4x4 = frame_header_.columns4x4 >> SubsamplingX(plane);
-  const int max_y4x4 = frame_header_.rows4x4 >> SubsamplingY(plane);
+  const int max_x4x4 = frame_header_.columns4x4 >> subsampling_x_[plane];
+  const int max_y4x4 = frame_header_.rows4x4 >> subsampling_y_[plane];
 
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
   const BlockSize plane_block_size =
-      kPlaneResidualSize[block.size][SubsamplingX(plane)][SubsamplingY(plane)];
+      kPlaneResidualSize[block.size][subsampling_x_[plane]]
+                        [subsampling_y_[plane]];
   assert(plane_block_size != kBlockInvalid);
   const int block_width = kBlockWidthPixels[plane_block_size];
   const int block_height = kBlockHeightPixels[plane_block_size];
 
   int top = 0;
   int left = 0;
+  const int num_top_elements = GetNumElements(w4, x4, max_x4x4);
+  const int num_left_elements = GetNumElements(h4, y4, max_y4x4);
   if (plane == kPlaneY) {
     if (block_width == tx_width && block_height == tx_height) return 0;
-    for (int i = 0; i < w4 && (i + x4 < max_x4x4); ++i) {
-      top = std::max(top,
-                     static_cast<int>(
-                         entropy_contexts_[EntropyContext::kTop][plane][x4 + i]
-                             .coefficient_level));
+    const uint8_t* coefficient_levels =
+        &coefficient_levels_[kEntropyContextTop][plane][x4];
+    for (int i = 0; i < num_top_elements; ++i) {
+      top = std::max(top, static_cast<int>(coefficient_levels[i]));
     }
-    for (int i = 0; i < h4 && (i + y4 < max_y4x4); ++i) {
-      left = std::max(
-          left, static_cast<int>(
-                    entropy_contexts_[EntropyContext::kLeft][plane][y4 + i]
-                        .coefficient_level));
+    coefficient_levels = &coefficient_levels_[kEntropyContextLeft][plane][y4];
+    for (int i = 0; i < num_left_elements; ++i) {
+      left = std::max(left, static_cast<int>(coefficient_levels[i]));
     }
-    top = std::min(top, 255);
-    left = std::min(left, 255);
-    const int min = std::min({top, left, 4});
-    const int max = std::min(std::max(top, left), 4);
-    // kAllZeroContextsByMinMax is pre-computed based on the logic in the spec
+    assert(top <= 4);
+    assert(left <= 4);
+    // kAllZeroContextsByTopLeft is pre-computed based on the logic in the spec
     // for top and left.
-    return kAllZeroContextsByMinMax[min][max];
+    return kAllZeroContextsByTopLeft[top][left];
   }
-  for (int i = 0; i < w4 && (i + x4 < max_x4x4); ++i) {
-    top |= entropy_contexts_[EntropyContext::kTop][plane][x4 + i]
-               .coefficient_level;
-    top |= entropy_contexts_[EntropyContext::kTop][plane][x4 + i].dc_category;
+  const uint8_t* coefficient_levels =
+      &coefficient_levels_[kEntropyContextTop][plane][x4];
+  const int8_t* dc_categories = &dc_categories_[kEntropyContextTop][plane][x4];
+  for (int i = 0; i < num_top_elements; ++i) {
+    top |= coefficient_levels[i];
+    top |= dc_categories[i];
   }
-  for (int i = 0; i < h4 && (i + y4 < max_y4x4); ++i) {
-    left |= entropy_contexts_[EntropyContext::kLeft][plane][y4 + i]
-                .coefficient_level;
-    left |= entropy_contexts_[EntropyContext::kLeft][plane][y4 + i].dc_category;
+  coefficient_levels = &coefficient_levels_[kEntropyContextLeft][plane][y4];
+  dc_categories = &dc_categories_[kEntropyContextLeft][plane][y4];
+  for (int i = 0; i < num_left_elements; ++i) {
+    left |= coefficient_levels[i];
+    left |= dc_categories[i];
   }
-  int context = static_cast<int>(top != 0) + static_cast<int>(left != 0) + 7;
-  if (block_width * block_height > tx_width * tx_height) context += 3;
-  return context;
+  return static_cast<int>(top != 0) + static_cast<int>(left != 0) + 7 +
+         3 * static_cast<int>(block_width * block_height >
+                              tx_width * tx_height);
 }
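
The single return expression for the chroma case above folds in the old incremental form (start at 7, add 1 per non-zero neighbor, add 3 when the block is larger than the transform). A quick standalone check of the equivalence:

  #include <cassert>

  // Chroma all-zero context, written as the single expression used above.
  int ChromaAllZeroContext(int top, int left, int block_area, int tx_area) {
    return static_cast<int>(top != 0) + static_cast<int>(left != 0) + 7 +
           3 * static_cast<int>(block_area > tx_area);
  }

  int main() {
    for (int top = 0; top < 2; ++top) {
      for (int left = 0; left < 2; ++left) {
        for (int larger = 0; larger < 2; ++larger) {
          // The incremental form this expression replaces.
          int expected = (top != 0) + (left != 0) + 7;
          if (larger != 0) expected += 3;
          assert(ChromaAllZeroContext(top, left, /*block_area=*/1 + larger,
                                      /*tx_area=*/1) == expected);
        }
      }
    }
    return 0;
  }
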
 
 TransformSet Tile::GetTransformSet(TransformSize tx_size, bool is_inter) const {
@@ -546,15 +717,13 @@
   TransformType tx_type;
   if (bp.is_inter) {
     const int x4 =
-        std::max(block.column4x4,
-                 block_x << sequence_header_.color_config.subsampling_x);
-    const int y4 = std::max(
-        block.row4x4, block_y << sequence_header_.color_config.subsampling_y);
+        std::max(block.column4x4, block_x << subsampling_x_[kPlaneU]);
+    const int y4 = std::max(block.row4x4, block_y << subsampling_y_[kPlaneU]);
     tx_type = transform_types_[y4 - block.row4x4][x4 - block.column4x4];
   } else {
     tx_type = kModeToTransformType[bp.uv_mode];
   }
-  return static_cast<bool>((kTransformTypeInSetMask[tx_set] >> tx_type) & 1)
+  return kTransformTypeInSetMask[tx_set].Contains(tx_type)
              ? tx_type
              : kTransformTypeDctDct;
 }
@@ -590,13 +759,8 @@
     // first dimension needs to be offset by 1.
     tx_type = kInverseTransformTypeBySet[tx_set - 1][tx_type];
   }
-  transform_types_[y4 - block.row4x4][x4 - block.column4x4] = tx_type;
-  for (int i = 0; i < DivideBy4(kTransformWidth[tx_size]); ++i) {
-    for (int j = 0; j < DivideBy4(kTransformHeight[tx_size]); ++j) {
-      transform_types_[y4 + j - block.row4x4][x4 + i - block.column4x4] =
-          tx_type;
-    }
-  }
+  SetTransformType(block, x4, y4, kTransformWidth4x4[tx_size],
+                   kTransformHeight4x4[tx_size], tx_type, transform_types_);
 }
 
 // Section 8.3.2 in the spec, under coeff_base_eob.
@@ -734,48 +898,36 @@
 }
 
 int Tile::GetDcSignContext(int x4, int y4, int w4, int h4, Plane plane) {
-  const int max_x4x4 = frame_header_.columns4x4 >> SubsamplingX(plane);
-  const int max_y4x4 = frame_header_.rows4x4 >> SubsamplingY(plane);
-  int dc_sign = 0;
-  for (int i = 0; i < w4 && (i + x4 < max_x4x4); ++i) {
-    const int sign =
-        entropy_contexts_[EntropyContext::kTop][plane][x4 + i].dc_category;
-    if (sign == 1) {
-      dc_sign--;
-    } else if (sign == 2) {
-      dc_sign++;
-    }
-  }
-  for (int i = 0; i < h4 && (i + y4 < max_y4x4); ++i) {
-    const int sign =
-        entropy_contexts_[EntropyContext::kLeft][plane][y4 + i].dc_category;
-    if (sign == 1) {
-      dc_sign--;
-    } else if (sign == 2) {
-      dc_sign++;
-    }
-  }
-  if (dc_sign < 0) return 1;
-  if (dc_sign > 0) return 2;
-  return 0;
+  const int max_x4x4 = frame_header_.columns4x4 >> subsampling_x_[plane];
+  const int8_t* dc_categories = &dc_categories_[kEntropyContextTop][plane][x4];
+  int dc_sign = std::accumulate(
+      dc_categories, dc_categories + GetNumElements(w4, x4, max_x4x4), 0);
+  const int max_y4x4 = frame_header_.rows4x4 >> subsampling_y_[plane];
+  dc_categories = &dc_categories_[kEntropyContextLeft][plane][y4];
+  dc_sign = std::accumulate(
+      dc_categories, dc_categories + GetNumElements(h4, y4, max_y4x4), dc_sign);
+  // This return statement is equivalent to:
+  //   if (dc_sign < 0) return 1;
+  //   if (dc_sign > 0) return 2;
+  //   return 0;
+  return static_cast<int>(dc_sign < 0) +
+         MultiplyBy2(static_cast<int>(dc_sign > 0));
 }
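
The branchless return above maps the accumulated sign total to the contexts 0, 1 and 2. A small standalone check of the equivalence spelled out in the comment:

  #include <cassert>

  // 1 if the accumulated dc sign is negative, 2 if positive, 0 otherwise;
  // the same mapping as the return statement above.
  int DcSignContext(int dc_sign) {
    return static_cast<int>(dc_sign < 0) + 2 * static_cast<int>(dc_sign > 0);
  }

  int main() {
    for (int s = -8; s <= 8; ++s) {
      const int expected = (s < 0) ? 1 : (s > 0) ? 2 : 0;
      assert(DcSignContext(s) == expected);
    }
    return 0;
  }
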
 
 void Tile::SetEntropyContexts(int x4, int y4, int w4, int h4, Plane plane,
-                              uint8_t coefficient_level, uint8_t dc_category) {
-  const int max_x4x4 = frame_header_.columns4x4 >> SubsamplingX(plane);
-  const int max_y4x4 = frame_header_.rows4x4 >> SubsamplingY(plane);
-  for (int i = 0; i < w4 && (i + x4 < max_x4x4); ++i) {
-    entropy_contexts_[EntropyContext::kTop][plane][x4 + i].coefficient_level =
-        coefficient_level;
-    entropy_contexts_[EntropyContext::kTop][plane][x4 + i].dc_category =
-        dc_category;
-  }
-  for (int i = 0; i < h4 && (i + y4 < max_y4x4); ++i) {
-    entropy_contexts_[EntropyContext::kLeft][plane][y4 + i].coefficient_level =
-        coefficient_level;
-    entropy_contexts_[EntropyContext::kLeft][plane][y4 + i].dc_category =
-        dc_category;
-  }
+                              uint8_t coefficient_level, int8_t dc_category) {
+  const int max_x4x4 = frame_header_.columns4x4 >> subsampling_x_[plane];
+  const int num_top_elements = GetNumElements(w4, x4, max_x4x4);
+  memset(&coefficient_levels_[kEntropyContextTop][plane][x4], coefficient_level,
+         num_top_elements);
+  memset(&dc_categories_[kEntropyContextTop][plane][x4], dc_category,
+         num_top_elements);
+  const int max_y4x4 = frame_header_.rows4x4 >> subsampling_y_[plane];
+  const int num_left_elements = GetNumElements(h4, y4, max_y4x4);
+  memset(&coefficient_levels_[kEntropyContextLeft][plane][y4],
+         coefficient_level, num_left_elements);
+  memset(&dc_categories_[kEntropyContextLeft][plane][y4], dc_category,
+         num_left_elements);
 }
 
 void Tile::ScaleMotionVector(const MotionVector& mv, const Plane plane,
@@ -798,8 +950,8 @@
   const bool is_scaled_x = reference_upscaled_width != frame_header_.width;
   const bool is_scaled_y = reference_height != frame_header_.height;
   const int half_sample = 1 << (kSubPixelBits - 1);
-  int orig_x = (x << kSubPixelBits) + ((2 * mv.mv[1]) >> SubsamplingX(plane));
-  int orig_y = (y << kSubPixelBits) + ((2 * mv.mv[0]) >> SubsamplingY(plane));
+  int orig_x = (x << kSubPixelBits) + ((2 * mv.mv[1]) >> subsampling_x_[plane]);
+  int orig_y = (y << kSubPixelBits) + ((2 * mv.mv[0]) >> subsampling_y_[plane]);
   const int rounding_offset =
       DivideBy2(1 << (kScaleSubPixelBits - kSubPixelBits));
   if (is_scaled_x) {
@@ -840,33 +992,112 @@
   }
 }
 
+template <bool is_dc_coefficient>
+bool Tile::ReadSignAndApplyDequantization(
+    const Block& block, const uint16_t* const scan, int i,
+    int adjusted_tx_width_log2, int tx_width, int q_value,
+    const uint8_t* const quantizer_matrix, int shift, int min_value,
+    int max_value, uint16_t* const dc_sign_cdf, int8_t* const dc_category,
+    int* const coefficient_level) {
+  int pos = is_dc_coefficient ? 0 : scan[i];
+  const int pos_index =
+      is_dc_coefficient ? 0 : PaddedIndex(pos, adjusted_tx_width_log2);
+  const bool sign = quantized_[pos_index] != 0 &&
+                    (is_dc_coefficient ? reader_.ReadSymbol(dc_sign_cdf)
+                                       : static_cast<bool>(reader_.ReadBit()));
+  if (quantized_[pos_index] >
+      kNumQuantizerBaseLevels + kQuantizerCoefficientBaseRange) {
+    int length = 0;
+    bool golomb_length_bit = false;
+    do {
+      golomb_length_bit = static_cast<bool>(reader_.ReadBit());
+      ++length;
+      if (length > 20) {
+        LIBGAV1_DLOG(ERROR, "Invalid golomb_length %d", length);
+        return false;
+      }
+    } while (!golomb_length_bit);
+    int x = 1;
+    for (int i = length - 2; i >= 0; --i) {
+      x = (x << 1) | reader_.ReadBit();
+    }
+    quantized_[pos_index] += x - 1;
+  }
+  if (is_dc_coefficient && quantized_[0] > 0) {
+    *dc_category = sign ? -1 : 1;
+  }
+  quantized_[pos_index] &= 0xfffff;
+  *coefficient_level += quantized_[pos_index];
+  // Apply dequantization. Step 1 of section 7.12.3 in the spec.
+  int q = q_value;
+  if (quantizer_matrix != nullptr) {
+    q = RightShiftWithRounding(q * quantizer_matrix[pos], 5);
+  }
+  // The intermediate multiplication can exceed 32 bits, so it has to be
+  // performed by promoting one of the values to int64_t.
+  int32_t dequantized_value =
+      (static_cast<int64_t>(q) * quantized_[pos_index]) & 0xffffff;
+  dequantized_value >>= shift;
+  if (sign) {
+    dequantized_value = -dequantized_value;
+  }
+  // The inverse transform process assumes that the quantized coefficients are
+  // stored as a virtual 2d array of size |tx_width| x |tx_height|. If the
+  // transform width is 64, this assumption is broken because the scan order
+  // used to populate the coefficients for such transforms is the same as the
+  // one used for the corresponding transform with width 32 (e.g. the scan
+  // order used for 64x16 is the same as the one used for 32x16). So we have
+  // to recompute the value of pos so that it reflects the index into the 2d
+  // array of size 64 x |tx_height|.
+  if (!is_dc_coefficient && tx_width == 64) {
+    const int row_index = DivideBy32(pos);
+    const int column_index = Mod32(pos);
+    pos = MultiplyBy64(row_index) + column_index;
+  }
+  if (sequence_header_.color_config.bitdepth == 8) {
+    auto* const residual_buffer =
+        reinterpret_cast<int16_t*>(block.sb_buffer->residual);
+    residual_buffer[pos] = Clip3(dequantized_value, min_value, max_value);
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  } else {
+    auto* const residual_buffer =
+        reinterpret_cast<int32_t*>(block.sb_buffer->residual);
+    residual_buffer[pos] = Clip3(dequantized_value, min_value, max_value);
+#endif
+  }
+  return true;
+}
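
The loop above reads the coefficient remainder as an Exp-Golomb style code: a run of zero bits terminated by a one gives the length, and the following length - 1 bits give the value. A self-contained sketch of that read, assuming a caller-supplied bit source in place of the symbol reader:

  #include <functional>

  // Reads the Golomb-coded remainder used above. Returns -1 if the length
  // exceeds 20, mirroring the error check in the decoder.
  int ReadGolombRemainder(const std::function<int()>& read_bit) {
    int length = 0;
    bool golomb_length_bit = false;
    do {
      golomb_length_bit = (read_bit() != 0);
      ++length;
      if (length > 20) return -1;  // Invalid golomb length.
    } while (!golomb_length_bit);
    int x = 1;
    for (int i = length - 2; i >= 0; --i) {
      x = (x << 1) | read_bit();
    }
    return x - 1;  // The decoder adds this to quantized_[pos_index].
  }
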
+
+int Tile::ReadCoeffBaseRange(int clamped_tx_size_context, int cdf_context,
+                             int plane_type) {
+  int level = 0;
+  for (int j = 0; j < kCoeffBaseRangeMaxIterations; ++j) {
+    const int coeff_base_range = reader_.ReadSymbol<kCoeffBaseRangeSymbolCount>(
+        symbol_decoder_context_.coeff_base_range_cdf[clamped_tx_size_context]
+                                                    [plane_type][cdf_context]);
+    level += coeff_base_range;
+    if (coeff_base_range < (kCoeffBaseRangeSymbolCount - 1)) break;
+  }
+  return level;
+}
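
ReadCoeffBaseRange() accumulates the range in chunks: every symbol adds up to kCoeffBaseRangeSymbolCount - 1, a smaller symbol ends the loop, and at most kCoeffBaseRangeMaxIterations symbols are read. A standalone sketch of that accumulation with a stand-in symbol source:

  #include <functional>

  // Accumulates a coefficient base range: each symbol adds up to
  // |symbol_count - 1|; a smaller symbol stops the loop; at most
  // |max_iterations| symbols are read.
  int AccumulateBaseRange(const std::function<int()>& read_symbol,
                          int symbol_count, int max_iterations) {
    int level = 0;
    for (int j = 0; j < max_iterations; ++j) {
      const int value = read_symbol();
      level += value;
      if (value < symbol_count - 1) break;
    }
    return level;
  }
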
+
 int16_t Tile::ReadTransformCoefficients(const Block& block, Plane plane,
                                         int start_x, int start_y,
                                         TransformSize tx_size,
                                         TransformType* const tx_type) {
   const int x4 = DivideBy4(start_x);
   const int y4 = DivideBy4(start_y);
-  const int w4 = DivideBy4(kTransformWidth[tx_size]);
-  const int h4 = DivideBy4(kTransformHeight[tx_size]);
-
-  const int tx_size_square_min =
-      TransformSizeToSquareTransformIndex(kTransformSizeSquareMin[tx_size]);
-  const int tx_size_square_max =
-      TransformSizeToSquareTransformIndex(kTransformSizeSquareMax[tx_size]);
-  const int tx_size_context =
-      DivideBy2(tx_size_square_min + tx_size_square_max + 1);
+  const int w4 = kTransformWidth4x4[tx_size];
+  const int h4 = kTransformHeight4x4[tx_size];
+  const int tx_size_context = kTransformSizeContext[tx_size];
   int context =
       GetTransformAllZeroContext(block, plane, tx_size, x4, y4, w4, h4);
   const bool all_zero = reader_.ReadSymbol(
       symbol_decoder_context_.all_zero_cdf[tx_size_context][context]);
   if (all_zero) {
     if (plane == kPlaneY) {
-      for (int i = 0; i < w4; ++i) {
-        for (int j = 0; j < h4; ++j) {
-          transform_types_[y4 + j - block.row4x4][x4 + i - block.column4x4] =
-              kTransformTypeDctDct;
-        }
-      }
+      SetTransformType(block, x4, y4, w4, h4, kTransformTypeDctDct,
+                       transform_types_);
     }
     SetEntropyContexts(x4, y4, w4, h4, plane, 0, 0);
     // This is not used in this case, so it can be set to any value.
@@ -921,10 +1152,8 @@
       break;
   }
   const int16_t eob_pt =
-      1 + reader_.ReadSymbol(cdf, kEobPtSymbolCount[eob_multi_size]);
+      1 + reader_.ReadSymbol(cdf, kEobPt16SymbolCount + eob_multi_size);
   int16_t eob = (eob_pt < 2) ? eob_pt : ((1 << (eob_pt - 2)) + 1);
-  int coefficient_level = 0;
-  uint8_t dc_category = 0;
   if (eob_pt >= 3) {
     context = eob_pt - 3;
     const bool eob_extra = reader_.ReadSymbol(
@@ -933,16 +1162,16 @@
     if (eob_extra) eob += 1 << (eob_pt - 3);
     for (int i = 1; i < eob_pt - 2; ++i) {
       assert(eob_pt - i >= 3);
-      assert(eob_pt <= kEobPtSymbolCount[6]);
+      assert(eob_pt <= kEobPt1024SymbolCount);
       if (static_cast<bool>(reader_.ReadBit())) {
         eob += 1 << (eob_pt - i - 3);
       }
     }
   }
-  const uint16_t* scan = GetScan(tx_size, *tx_type);
+  const TransformClass tx_class = GetTransformClass(*tx_type);
+  const uint16_t* scan = kScan[tx_class][tx_size];
   const TransformSize adjusted_tx_size = kAdjustedTransformSize[tx_size];
   const int adjusted_tx_width_log2 = kTransformWidthLog2[adjusted_tx_size];
-  const TransformClass tx_class = GetTransformClass(*tx_type);
   // Lookup used to call the right variant of GetCoeffBaseContext*() based on
   // the transform class.
   static constexpr int (Tile::*kGetCoeffBaseContextFunc[])(
@@ -958,36 +1187,37 @@
       &Tile::GetCoeffBaseRangeContextVertical};
   auto get_coeff_base_range_context_func =
       kGetCoeffBaseRangeContextFunc[tx_class];
-  for (int i = eob - 1; i >= 0; --i) {
-    const uint16_t pos = scan[i];
-    int level;
-    int symbol_count;
-    if (i == eob - 1) {
-      level = 1;
-      context = GetCoeffBaseContextEob(tx_size, i);
-      cdf = symbol_decoder_context_
-                .coeff_base_eob_cdf[tx_size_context][plane_type][context];
-      symbol_count = kCoeffBaseEobSymbolCount;
-    } else {
-      level = 0;
-      context = (this->*get_coeff_base_context_func)(
-          tx_size, adjusted_tx_width_log2, pos);
-      cdf = symbol_decoder_context_
-                .coeff_base_cdf[tx_size_context][plane_type][context];
-      symbol_count = kCoeffBaseSymbolCount;
-    }
-    level += reader_.ReadSymbol(cdf, symbol_count);
+  const int clamped_tx_size_context = std::min(tx_size_context, 3);
+  // Read the last coefficient.
+  {
+    context = GetCoeffBaseContextEob(tx_size, eob - 1);
+    const uint16_t pos = scan[eob - 1];
+    int level =
+        1 + reader_.ReadSymbol(
+                symbol_decoder_context_
+                    .coeff_base_eob_cdf[tx_size_context][plane_type][context],
+                kCoeffBaseEobSymbolCount);
     if (level > kNumQuantizerBaseLevels) {
-      context = (this->*get_coeff_base_range_context_func)(
-          adjusted_tx_width_log2, pos);
-      for (int j = 0; j < kCoeffBaseRangeMaxIterations; ++j) {
-        const int coeff_base_range = reader_.ReadSymbol(
-            symbol_decoder_context_.coeff_base_range_cdf[std::min(
-                tx_size_context, 3)][plane_type][context],
-            kCoeffBaseRangeSymbolCount);
-        level += coeff_base_range;
-        if (coeff_base_range < (kCoeffBaseRangeSymbolCount - 1)) break;
-      }
+      level += ReadCoeffBaseRange(clamped_tx_size_context,
+                                  (this->*get_coeff_base_range_context_func)(
+                                      adjusted_tx_width_log2, pos),
+                                  plane_type);
+    }
+    quantized_[PaddedIndex(pos, adjusted_tx_width_log2)] = level;
+  }
+  // Read all the other coefficients.
+  for (int i = eob - 2; i >= 0; --i) {
+    const uint16_t pos = scan[i];
+    context = (this->*get_coeff_base_context_func)(tx_size,
+                                                   adjusted_tx_width_log2, pos);
+    int level = reader_.ReadSymbol<kCoeffBaseSymbolCount>(
+        symbol_decoder_context_
+            .coeff_base_cdf[tx_size_context][plane_type][context]);
+    if (level > kNumQuantizerBaseLevels) {
+      level += ReadCoeffBaseRange(clamped_tx_size_context,
+                                  (this->*get_coeff_base_range_context_func)(
+                                      adjusted_tx_width_log2, pos),
+                                  plane_type);
     }
     quantized_[PaddedIndex(pos, adjusted_tx_width_log2)] = level;
   }
@@ -998,86 +1228,37 @@
   const int dc_q_value = quantizer_.GetDcValue(plane, current_quantizer_index);
   const int ac_q_value = quantizer_.GetAcValue(plane, current_quantizer_index);
   const int shift = GetQuantizationShift(tx_size);
-  for (int i = 0; i < eob; ++i) {
-    int pos = scan[i];
-    const int pos_index = PaddedIndex(pos, adjusted_tx_width_log2);
-    bool sign = false;
-    if (quantized_[pos_index] != 0) {
-      if (i == 0) {
-        context = GetDcSignContext(x4, y4, w4, h4, plane);
-        sign = reader_.ReadSymbol(
-            symbol_decoder_context_.dc_sign_cdf[plane_type][context]);
-      } else {
-        sign = static_cast<bool>(reader_.ReadBit());
-      }
-    }
-    if (quantized_[pos_index] >
-        kNumQuantizerBaseLevels + kQuantizerCoefficientBaseRange) {
-      int length = 0;
-      bool golomb_length_bit = false;
-      do {
-        golomb_length_bit = static_cast<bool>(reader_.ReadBit());
-        ++length;
-        if (length > 20) {
-          LIBGAV1_DLOG(ERROR, "Invalid golomb_length %d", length);
-          return -1;
-        }
-      } while (!golomb_length_bit);
-      int x = 1;
-      for (int i = length - 2; i >= 0; --i) {
-        x = (x << 1) | reader_.ReadBit();
-      }
-      quantized_[pos_index] += x - 1;
-    }
-    if (pos == 0 && quantized_[pos_index] > 0) {
-      dc_category = sign ? 1 : 2;
-    }
-    quantized_[pos_index] &= 0xfffff;
-    coefficient_level += quantized_[pos_index];
-    // Apply dequantization. Step 1 of section 7.12.3 in the spec.
-    int q = (pos == 0) ? dc_q_value : ac_q_value;
-    if (frame_header_.quantizer.use_matrix &&
-        *tx_type < kTransformTypeIdentityIdentity &&
-        !frame_header_.segmentation.lossless[bp.segment_id] &&
-        frame_header_.quantizer.matrix_level[plane] < 15) {
-      q *= kQuantizerMatrix[frame_header_.quantizer.matrix_level[plane]]
-                           [plane_type][kQuantizerMatrixOffset[tx_size] + pos];
-      q = RightShiftWithRounding(q, 5);
-    }
-    // The intermediate multiplication can exceed 32 bits, so it has to be
-    // performed by promoting one of the values to int64_t.
-    int32_t dequantized_value =
-        (static_cast<int64_t>(q) * quantized_[pos_index]) & 0xffffff;
-    dequantized_value >>= shift;
-    if (sign) {
-      dequantized_value = -dequantized_value;
-    }
-    // Inverse transform process assumes that the quantized coefficients are
-    // stored as a virtual 2d array of size |tx_width| x |tx_height|. If
-    // transform width is 64, then this assumption is broken because the scan
-    // order used for populating the coefficients for such transforms is the
-    // same as the one used for corresponding transform with width 32 (e.g. the
-    // scan order used for 64x16 is the same as the one used for 32x16). So we
-    // have to recompute the value of pos so that it reflects the index of the
-    // 2d array of size 64 x |tx_height|.
-    if (tx_width == 64) {
-      const int row_index = DivideBy32(pos);
-      const int column_index = Mod32(pos);
-      pos = MultiplyBy64(row_index) + column_index;
-    }
-    if (sequence_header_.color_config.bitdepth == 8) {
-      auto* const residual_buffer =
-          reinterpret_cast<int16_t*>(block.sb_buffer->residual);
-      residual_buffer[pos] = Clip3(dequantized_value, min_value, max_value);
-#if LIBGAV1_MAX_BITDEPTH >= 10
-    } else {
-      auto* const residual_buffer =
-          reinterpret_cast<int32_t*>(block.sb_buffer->residual);
-      residual_buffer[pos] = Clip3(dequantized_value, min_value, max_value);
-#endif
+  const uint8_t* const quantizer_matrix =
+      (frame_header_.quantizer.use_matrix &&
+       *tx_type < kTransformTypeIdentityIdentity &&
+       !frame_header_.segmentation.lossless[bp.segment_id] &&
+       frame_header_.quantizer.matrix_level[plane] < 15)
+          ? &kQuantizerMatrix[frame_header_.quantizer.matrix_level[plane]]
+                             [plane_type][kQuantizerMatrixOffset[tx_size]]
+          : nullptr;
+  int coefficient_level = 0;
+  int8_t dc_category = 0;
+  uint16_t* const dc_sign_cdf =
+      (quantized_[0] != 0)
+          ? symbol_decoder_context_.dc_sign_cdf[plane_type][GetDcSignContext(
+                x4, y4, w4, h4, plane)]
+          : nullptr;
+  assert(scan[0] == 0);
+  if (!ReadSignAndApplyDequantization</*is_dc_coefficient=*/true>(
+          block, scan, 0, adjusted_tx_width_log2, tx_width, dc_q_value,
+          quantizer_matrix, shift, min_value, max_value, dc_sign_cdf,
+          &dc_category, &coefficient_level)) {
+    return -1;
+  }
+  for (int i = 1; i < eob; ++i) {
+    if (!ReadSignAndApplyDequantization</*is_dc_coefficient=*/false>(
+            block, scan, i, adjusted_tx_width_log2, tx_width, ac_q_value,
+            quantizer_matrix, shift, min_value, max_value, nullptr, nullptr,
+            &coefficient_level)) {
+      return -1;
     }
   }
-  SetEntropyContexts(x4, y4, w4, h4, plane, std::min(63, coefficient_level),
+  SetEntropyContexts(x4, y4, w4, h4, plane, std::min(4, coefficient_level),
                      dc_category);
   if (split_parse_and_decode_) {
     block.sb_buffer->residual += tx_width * tx_height * residual_size_;
@@ -1089,8 +1270,8 @@
                           int base_y, TransformSize tx_size, int x, int y,
                           ProcessingMode mode) {
   BlockParameters& bp = *block.bp;
-  const int subsampling_x = SubsamplingX(plane);
-  const int subsampling_y = SubsamplingY(plane);
+  const int subsampling_x = subsampling_x_[plane];
+  const int subsampling_y = subsampling_y_[plane];
   const int start_x = base_x + MultiplyBy4(x);
   const int start_y = base_y + MultiplyBy4(y);
   const int max_x = MultiplyBy4(frame_header_.columns4x4) >> subsampling_x;
@@ -1101,8 +1282,8 @@
   const int mask = sequence_header_.use_128x128_superblock ? 31 : 15;
   const int sub_block_row4x4 = row & mask;
   const int sub_block_column4x4 = column & mask;
-  const int step_x = DivideBy4(kTransformWidth[tx_size]);
-  const int step_y = DivideBy4(kTransformHeight[tx_size]);
+  const int step_x = kTransformWidth4x4[tx_size];
+  const int step_y = kTransformHeight4x4[tx_size];
   const bool do_decode = mode == kProcessingModeDecodeOnly ||
                          mode == kProcessingModeParseAndDecode;
   if (do_decode && !bp.is_inter) {
@@ -1122,10 +1303,11 @@
               ? bp.y_mode
               : (bp.uv_mode == kPredictionModeChromaFromLuma ? kPredictionModeDc
                                                              : bp.uv_mode);
-      const int tr_row4x4 = (sub_block_row4x4 >> subsampling_y) - 1;
-      const int tr_column4x4 = (sub_block_column4x4 >> subsampling_x) + step_x;
-      const int bl_row4x4 = (sub_block_row4x4 >> subsampling_y) + step_y;
-      const int bl_column4x4 = (sub_block_column4x4 >> subsampling_x) - 1;
+      const int tr_row4x4 = (sub_block_row4x4 >> subsampling_y);
+      const int tr_column4x4 =
+          (sub_block_column4x4 >> subsampling_x) + step_x + 1;
+      const int bl_row4x4 = (sub_block_row4x4 >> subsampling_y) + step_y + 1;
+      const int bl_column4x4 = (sub_block_column4x4 >> subsampling_x);
       const bool has_left =
           x > 0 || (plane == kPlaneY ? block.left_available
                                      : block.LeftAvailableChroma());
@@ -1135,16 +1317,16 @@
       if (sequence_header_.color_config.bitdepth == 8) {
         IntraPrediction<uint8_t>(
             block, plane, start_x, start_y, has_left, has_top,
-            BlockDecoded(block, plane, tr_row4x4, tr_column4x4, has_top),
-            BlockDecoded(block, plane, bl_row4x4, bl_column4x4, has_left), mode,
-            tx_size);
+            block.sb_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+            block.sb_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
+            mode, tx_size);
 #if LIBGAV1_MAX_BITDEPTH >= 10
       } else {
         IntraPrediction<uint16_t>(
             block, plane, start_x, start_y, has_left, has_top,
-            BlockDecoded(block, plane, tr_row4x4, tr_column4x4, has_top),
-            BlockDecoded(block, plane, bl_row4x4, bl_column4x4, has_left), mode,
-            tx_size);
+            block.sb_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+            block.sb_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
+            mode, tx_size);
 #endif
       }
       if (plane != kPlaneY && bp.uv_mode == kPredictionModeChromaFromLuma) {
@@ -1198,58 +1380,74 @@
     }
   }
   if (do_decode) {
+    bool* block_decoded =
+        &block.sb_buffer
+             ->block_decoded[plane][(sub_block_row4x4 >> subsampling_y) + 1]
+                            [(sub_block_column4x4 >> subsampling_x) + 1];
     for (int i = 0; i < step_y; ++i) {
-      for (int j = 0; j < step_x; ++j) {
-        block.sb_buffer
-            ->block_decoded[plane][(sub_block_row4x4 >> subsampling_y) + i]
-                           [(sub_block_column4x4 >> subsampling_x) + j] = true;
-      }
+      static_assert(sizeof(bool) == 1, "");
+      memset(block_decoded, 1, step_x);
+      block_decoded += kBlockDecodedStride;
     }
   }
   return true;
 }
 
 bool Tile::TransformTree(const Block& block, int start_x, int start_y,
-                         int width, int height, ProcessingMode mode) {
-  const int row = DivideBy4(start_y);
-  const int column = DivideBy4(start_x);
-  if (row >= frame_header_.rows4x4 || column >= frame_header_.columns4x4) {
-    return true;
-  }
-  const TransformSize inter_tx_size = inter_transform_sizes_[row][column];
-  if (width <= kTransformWidth[inter_tx_size] &&
-      height <= kTransformHeight[inter_tx_size]) {
-    TransformSize tx_size = kNumTransformSizes;
-    for (int i = 0; i < kNumTransformSizes; ++i) {
-      if (kTransformWidth[i] == width && kTransformHeight[i] == height) {
-        tx_size = static_cast<TransformSize>(i);
-        break;
-      }
+                         BlockSize plane_size, ProcessingMode mode) {
+  assert(plane_size <= kBlock64x64);
+  // The branching factor is 4 and the maximum depth is 4, so the maximum
+  // stack size required is (4 - 1) * 4 + 1 = 13.
+  Stack<TransformTreeNode, 13> stack;
+  // It is okay to cast BlockSize to TransformSize here since the enums are
+  // equivalent for all BlockSize values <= kBlock64x64.
+  stack.Push(TransformTreeNode(start_x, start_y,
+                               static_cast<TransformSize>(plane_size)));
+
+  while (!stack.Empty()) {
+    TransformTreeNode node = stack.Pop();
+    const int row = DivideBy4(node.y);
+    const int column = DivideBy4(node.x);
+    if (row >= frame_header_.rows4x4 || column >= frame_header_.columns4x4) {
+      continue;
     }
-    assert(tx_size < kNumTransformSizes);
-    return TransformBlock(block, kPlaneY, start_x, start_y, tx_size, 0, 0,
-                          mode);
+    const TransformSize inter_tx_size = inter_transform_sizes_[row][column];
+    const int width = kTransformWidth[node.tx_size];
+    const int height = kTransformHeight[node.tx_size];
+    if (width <= kTransformWidth[inter_tx_size] &&
+        height <= kTransformHeight[inter_tx_size]) {
+      if (!TransformBlock(block, kPlaneY, node.x, node.y, node.tx_size, 0, 0,
+                          mode)) {
+        return false;
+      }
+      continue;
+    }
+    // The split transform size lookup gives the right transform size that we
+    // should push onto the stack.
+    //   if (width > height) => transform size whose width is half.
+    //   if (width < height) => transform size whose height is half.
+    //   if (width == height) => transform size whose width and height are half.
+    const TransformSize split_tx_size = kSplitTransformSize[node.tx_size];
+    const int half_width = DivideBy2(width);
+    if (width > height) {
+      stack.Push(TransformTreeNode(node.x + half_width, node.y, split_tx_size));
+      stack.Push(TransformTreeNode(node.x, node.y, split_tx_size));
+      continue;
+    }
+    const int half_height = DivideBy2(height);
+    if (width < height) {
+      stack.Push(
+          TransformTreeNode(node.x, node.y + half_height, split_tx_size));
+      stack.Push(TransformTreeNode(node.x, node.y, split_tx_size));
+      continue;
+    }
+    stack.Push(TransformTreeNode(node.x + half_width, node.y + half_height,
+                                 split_tx_size));
+    stack.Push(TransformTreeNode(node.x, node.y + half_height, split_tx_size));
+    stack.Push(TransformTreeNode(node.x + half_width, node.y, split_tx_size));
+    stack.Push(TransformTreeNode(node.x, node.y, split_tx_size));
   }
-  const int half_width = DivideBy2(width);
-  const int half_height = DivideBy2(height);
-  if (width > height) {
-    return TransformTree(block, start_x, start_y, half_width, height, mode) &&
-           TransformTree(block, start_x + half_width, start_y, half_width,
-                         height, mode);
-  }
-  if (width < height) {
-    return TransformTree(block, start_x, start_y, width, half_height, mode) &&
-           TransformTree(block, start_x, start_y + half_height, width,
-                         half_height, mode);
-  }
-  return TransformTree(block, start_x, start_y, half_width, half_height,
-                       mode) &&
-         TransformTree(block, start_x + half_width, start_y, half_width,
-                       half_height, mode) &&
-         TransformTree(block, start_x, start_y + half_height, half_width,
-                       half_height, mode) &&
-         TransformTree(block, start_x + half_width, start_y + half_height,
-                       half_width, half_height, mode);
+  return true;
 }
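
The fixed capacity of 13 comes from the worst case of the iterative walk: each split pops one node and pushes at most four (a net growth of three), and at most four split levels separate a 64x64 transform from a 4x4 one, so 1 + 3 * 4 = 13. A tiny fixed-capacity stack of the kind assumed here could look like this (illustrative, not the library's Stack):

  #include <cassert>

  // A fixed-capacity LIFO stack, similar in spirit to the Stack<T, N> used in
  // the iterative TransformTree walk above.
  template <typename T, int capacity>
  class FixedStack {
   public:
    void Push(const T& value) {
      assert(size_ < capacity);
      elements_[size_++] = value;
    }
    T Pop() {
      assert(size_ > 0);
      return elements_[--size_];
    }
    bool Empty() const { return size_ == 0; }

   private:
    T elements_[capacity];
    int size_ = 0;
  };

With a branching factor of 4 and at most 4 splits along any path, Push() above is never called on a full stack when the capacity is 13.
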
 
 void Tile::ReconstructBlock(const Block& block, Plane plane, int start_x,
@@ -1260,7 +1458,7 @@
   if (non_zero_coeff_count == 0) return;
   // Reconstruction process. Steps 2 and 3 of Section 7.12.3 in the spec.
   if (sequence_header_.color_config.bitdepth == 8) {
-    Reconstruct(dsp_, tx_type, tx_size, sequence_header_.color_config.bitdepth,
+    Reconstruct(dsp_, tx_type, tx_size,
                 frame_header_.segmentation.lossless[block.bp->segment_id],
                 reinterpret_cast<int16_t*>(block.sb_buffer->residual), start_x,
                 start_y, &buffer_[plane], non_zero_coeff_count);
@@ -1269,7 +1467,7 @@
     Array2DView<uint16_t> buffer(
         buffer_[plane].rows(), buffer_[plane].columns() / sizeof(uint16_t),
         reinterpret_cast<uint16_t*>(&buffer_[plane][0][0]));
-    Reconstruct(dsp_, tx_type, tx_size, sequence_header_.color_config.bitdepth,
+    Reconstruct(dsp_, tx_type, tx_size,
                 frame_header_.segmentation.lossless[block.bp->segment_id],
                 reinterpret_cast<int32_t*>(block.sb_buffer->residual), start_x,
                 start_y, &buffer, non_zero_coeff_count);
@@ -1287,17 +1485,17 @@
   const BlockSize size_chunk4x4 =
       (width_chunks > 1 || height_chunks > 1) ? kBlock64x64 : block.size;
   const BlockParameters& bp = *block.bp;
-
+  const TransformSize y_tx_size =
+      GetTransformSize(frame_header_.segmentation.lossless[bp.segment_id],
+                       block.size, kPlaneY, bp.transform_size, 0, 0);
   for (int chunk_y = 0; chunk_y < height_chunks; ++chunk_y) {
     for (int chunk_x = 0; chunk_x < width_chunks; ++chunk_x) {
       for (int plane = 0; plane < (block.HasChroma() ? PlaneCount() : 1);
            ++plane) {
-        const int subsampling_x = SubsamplingX(static_cast<Plane>(plane));
-        const int subsampling_y = SubsamplingY(static_cast<Plane>(plane));
+        const int subsampling_x = subsampling_x_[plane];
+        const int subsampling_y = subsampling_y_[plane];
         const TransformSize tx_size =
-            GetTransformSize(frame_header_.segmentation.lossless[bp.segment_id],
-                             block.size, static_cast<Plane>(plane),
-                             bp.transform_size, subsampling_x, subsampling_y);
+            (plane == kPlaneY) ? y_tx_size : bp.uv_transform_size;
         const BlockSize plane_size =
             kPlaneResidualSize[size_chunk4x4][subsampling_x][subsampling_y];
         assert(plane_size != kBlockInvalid);
@@ -1308,16 +1506,14 @@
           const int column_chunk4x4 = block.column4x4 + MultiplyBy16(chunk_x);
           const int base_x = MultiplyBy4(column_chunk4x4 >> subsampling_x);
           const int base_y = MultiplyBy4(row_chunk4x4 >> subsampling_y);
-          if (!TransformTree(block, base_x, base_y,
-                             kBlockWidthPixels[plane_size],
-                             kBlockHeightPixels[plane_size], mode)) {
+          if (!TransformTree(block, base_x, base_y, plane_size, mode)) {
             return false;
           }
         } else {
           const int base_x = MultiplyBy4(block.column4x4 >> subsampling_x);
           const int base_y = MultiplyBy4(block.row4x4 >> subsampling_y);
-          const int step_x = DivideBy4(kTransformWidth[tx_size]);
-          const int step_y = DivideBy4(kTransformHeight[tx_size]);
+          const int step_x = kTransformWidth4x4[tx_size];
+          const int step_y = kTransformHeight4x4[tx_size];
           const int num4x4_wide = kNum4x4BlocksWide[plane_size];
           const int num4x4_high = kNum4x4BlocksHigh[plane_size];
           for (int y = 0; y < num4x4_high; y += step_y) {
@@ -1366,10 +1562,10 @@
   const int src_bottom_edge = src_top_edge + block_height;
   const int src_right_edge = src_left_edge + block_width;
   if (block.HasChroma()) {
-    if (block_width < 8 && sequence_header_.color_config.subsampling_x != 0) {
+    if (block_width < 8 && subsampling_x_[kPlaneU] != 0) {
       src_left_edge -= 4;
     }
-    if (block_height < 8 && sequence_header_.color_config.subsampling_y != 0) {
+    if (block_height < 8 && subsampling_y_[kPlaneU] != 0) {
       src_top_edge -= 4;
     }
   }
@@ -1455,22 +1651,24 @@
   const int block_width4x4 = kNum4x4BlocksWide[block.size];
   const int block_height4x4 = kNum4x4BlocksHigh[block.size];
   for (int plane = 0; plane < (block.HasChroma() ? PlaneCount() : 1); ++plane) {
-    const int subsampling_x = SubsamplingX(static_cast<Plane>(plane));
+    const int subsampling_x = subsampling_x_[plane];
     const int start_x = block.column4x4 >> subsampling_x;
     const int end_x =
         std::min((block.column4x4 + block_width4x4) >> subsampling_x,
                  frame_header_.columns4x4);
-    for (int x = start_x; x < end_x; ++x) {
-      entropy_contexts_[EntropyContext::kTop][plane][x] = {};
-    }
-    const int subsampling_y = SubsamplingY(static_cast<Plane>(plane));
+    memset(&coefficient_levels_[kEntropyContextTop][plane][start_x], 0,
+           end_x - start_x);
+    memset(&dc_categories_[kEntropyContextTop][plane][start_x], 0,
+           end_x - start_x);
+    const int subsampling_y = subsampling_y_[plane];
     const int start_y = block.row4x4 >> subsampling_y;
     const int end_y =
         std::min((block.row4x4 + block_height4x4) >> subsampling_y,
                  frame_header_.rows4x4);
-    for (int y = start_y; y < end_y; ++y) {
-      entropy_contexts_[EntropyContext::kLeft][plane][y] = {};
-    }
+    memset(&coefficient_levels_[kEntropyContextLeft][plane][start_y], 0,
+           end_y - start_y);
+    memset(&dc_categories_[kEntropyContextLeft][plane][start_y], 0,
+           end_y - start_y);
   }
 }
 
@@ -1488,25 +1686,25 @@
   // Local warping parameters, similar usage as is_local_valid.
   GlobalMotion local_warp_params;
   for (int plane = 0; plane < (block.HasChroma() ? PlaneCount() : 1); ++plane) {
-    const int8_t subsampling_x = SubsamplingX(static_cast<Plane>(plane));
-    const int8_t subsampling_y = SubsamplingY(static_cast<Plane>(plane));
+    const int8_t subsampling_x = subsampling_x_[plane];
+    const int8_t subsampling_y = subsampling_y_[plane];
     const BlockSize plane_size =
         kPlaneResidualSize[block.size][subsampling_x][subsampling_y];
     assert(plane_size != kBlockInvalid);
     const int block_width4x4 = kNum4x4BlocksWide[plane_size];
     const int block_height4x4 = kNum4x4BlocksHigh[plane_size];
-    const int block_width = kBlockWidthPixels[plane_size];
-    const int block_height = kBlockHeightPixels[plane_size];
+    const int block_width = MultiplyBy4(block_width4x4);
+    const int block_height = MultiplyBy4(block_height4x4);
     const int base_x = MultiplyBy4(block.column4x4 >> subsampling_x);
     const int base_y = MultiplyBy4(block.row4x4 >> subsampling_y);
     const BlockParameters& bp = *block.bp;
     if (bp.is_inter && bp.reference_frame[1] == kReferenceFrameIntra) {
-      const int tr_row4x4 = (sub_block_row4x4 >> subsampling_y) - 1;
+      const int tr_row4x4 = (sub_block_row4x4 >> subsampling_y);
       const int tr_column4x4 =
-          (sub_block_column4x4 >> subsampling_x) + block_width4x4;
+          (sub_block_column4x4 >> subsampling_x) + block_width4x4 + 1;
       const int bl_row4x4 =
           (sub_block_row4x4 >> subsampling_y) + block_height4x4;
-      const int bl_column4x4 = (sub_block_column4x4 >> subsampling_x) - 1;
+      const int bl_column4x4 = (sub_block_column4x4 >> subsampling_x) + 1;
       const TransformSize tx_size =
           k4x4SizeToTransformSize[k4x4WidthLog2[plane_size]]
                                  [k4x4HeightLog2[plane_size]];
@@ -1517,10 +1715,8 @@
       if (sequence_header_.color_config.bitdepth == 8) {
         IntraPrediction<uint8_t>(
             block, static_cast<Plane>(plane), base_x, base_y, has_left, has_top,
-            BlockDecoded(block, static_cast<Plane>(plane), tr_row4x4,
-                         tr_column4x4, has_top),
-            BlockDecoded(block, static_cast<Plane>(plane), bl_row4x4,
-                         bl_column4x4, has_left),
+            block.sb_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+            block.sb_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
             kInterIntraToIntraMode[block.bp->prediction_parameters
                                        ->inter_intra_mode],
             tx_size);
@@ -1528,10 +1724,8 @@
       } else {
         IntraPrediction<uint16_t>(
             block, static_cast<Plane>(plane), base_x, base_y, has_left, has_top,
-            BlockDecoded(block, static_cast<Plane>(plane), tr_row4x4,
-                         tr_column4x4, has_top),
-            BlockDecoded(block, static_cast<Plane>(plane), bl_row4x4,
-                         bl_column4x4, has_left),
+            block.sb_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+            block.sb_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
             kInterIntraToIntraMode[block.bp->prediction_parameters
                                        ->inter_intra_mode],
             tx_size);
@@ -1580,12 +1774,18 @@
   return true;
 }
 
-void Tile::ComputeDeblockFilterLevel(const Block& block) {
+void Tile::PopulateDeblockFilterLevel(const Block& block) {
+  if (!post_filter_.DoDeblock()) return;
   BlockParameters& bp = *block.bp;
-  for (int plane = kPlaneY; plane < PlaneCount(); ++plane) {
-    for (int i = kLoopFilterTypeVertical; i < kNumLoopFilterTypes; ++i) {
-      bp.deblock_filter_level[plane][i] = LoopFilterMask::GetDeblockFilterLevel(
-          frame_header_, bp, static_cast<Plane>(plane), i, delta_lf_);
+  for (int i = 0; i < kFrameLfCount; ++i) {
+    if (delta_lf_all_zero_) {
+      bp.deblock_filter_level[i] = post_filter_.GetZeroDeltaDeblockFilterLevel(
+          bp.segment_id, i, bp.reference_frame[0],
+          LoopFilterMask::GetModeId(bp.y_mode));
+    } else {
+      bp.deblock_filter_level[i] =
+          deblock_filter_levels_[bp.segment_id][i][bp.reference_frame[0]]
+                                [LoopFilterMask::GetModeId(bp.y_mode)];
     }
   }
 }
@@ -1612,10 +1812,13 @@
                               : std::move(prediction_parameters_);
   if (block.bp->prediction_parameters == nullptr) return false;
   if (!DecodeModeInfo(block)) return false;
-  ComputeDeblockFilterLevel(block);
-  ReadPaletteTokens(block);
+  PopulateDeblockFilterLevel(block);
+  if (!ReadPaletteTokens(block)) return false;
   DecodeTransformSize(block);
-  const BlockParameters& bp = *block.bp;
+  BlockParameters& bp = *block.bp;
+  bp.uv_transform_size = GetTransformSize(
+      frame_header_.segmentation.lossless[bp.segment_id], block.size, kPlaneU,
+      bp.transform_size, subsampling_x_[kPlaneU], subsampling_y_[kPlaneU]);
   if (bp.skip) ResetEntropyContext(block);
   const int block_width4x4 = kNum4x4BlocksWide[block_size];
   const int block_height4x4 = kNum4x4BlocksHigh[block_size];
@@ -1644,8 +1847,10 @@
     current_frame_.segmentation_map()->FillBlock(row4x4, column4x4, x_limit,
                                                  y_limit, bp.segment_id);
   }
-  if (!split_parse_and_decode_) {
+  if (build_bit_mask_when_parsing_ || !split_parse_and_decode_) {
     BuildBitMask(row4x4, column4x4, block_size);
+  }
+  if (!split_parse_and_decode_) {
     StoreMotionFieldMvsIntoCurrentFrame(block);
     prediction_parameters_ = std::move(block.bp->prediction_parameters);
   }
@@ -1667,7 +1872,9 @@
       !Residual(block, kProcessingModeDecodeOnly)) {
     return false;
   }
-  BuildBitMask(row4x4, column4x4, block_size);
+  if (!build_bit_mask_when_parsing_) {
+    BuildBitMask(row4x4, column4x4, block_size);
+  }
   StoreMotionFieldMvsIntoCurrentFrame(block);
   block.bp->prediction_parameters.reset(nullptr);
   return true;
@@ -1676,8 +1883,7 @@
 bool Tile::ProcessPartition(int row4x4_start, int column4x4_start,
                             ParameterTree* const root,
                             SuperBlockBuffer* const sb_buffer) {
-  std::vector<ParameterTree*> stack;
-  stack.reserve(kDfsStackSize);
+  Stack<ParameterTree*, kDfsStackSize> stack;
 
   // Set up the first iteration.
   ParameterTree* node = root;
@@ -1688,10 +1894,9 @@
   // DFS loop. If it sees a terminal node (leaf node), ProcessBlock is invoked.
   // Otherwise, the children are pushed into the stack for future processing.
   do {
-    if (!stack.empty()) {
+    if (!stack.Empty()) {
       // Set up subsequent iterations.
-      node = stack.back();
-      stack.pop_back();
+      node = stack.Pop();
       row4x4 = node->row4x4();
       column4x4 = node->column4x4();
       block_size = node->block_size();
@@ -1731,7 +1936,10 @@
           sequence_header_.color_config.subsampling_y);
       return false;
     }
-    node->SetPartitionType(partition);
+    if (!node->SetPartitionType(partition)) {
+      LIBGAV1_DLOG(ERROR, "node->SetPartitionType() failed.");
+      return false;
+    }
     switch (partition) {
       case kPartitionNone:
         if (!ProcessBlock(row4x4, column4x4, sub_size, node, sb_buffer)) {
@@ -1744,7 +1952,7 @@
         for (int i = 3; i >= 0; --i) {
           ParameterTree* const child = node->children(i);
           assert(child != nullptr);
-          stack.push_back(child);
+          stack.Push(child);
         }
         break;
       case kPartitionHorizontal:
@@ -1767,7 +1975,7 @@
         }
         break;
     }
-  } while (!stack.empty());
+  } while (!stack.Empty());
   return true;
 }
 
@@ -1799,6 +2007,40 @@
   }
 }
 
+void Tile::ClearBlockDecoded(SuperBlockBuffer* const sb_buffer, int row4x4,
+                             int column4x4) {
+  // Set everything to false.
+  memset(sb_buffer->block_decoded, 0, sizeof(sb_buffer->block_decoded));
+  // Set specific edge cases to true.
+  const int sb_size4 = sequence_header_.use_128x128_superblock ? 32 : 16;
+  for (int plane = 0; plane < PlaneCount(); ++plane) {
+    const int subsampling_x = subsampling_x_[plane];
+    const int subsampling_y = subsampling_y_[plane];
+    const int sb_width4 = (column4x4_end_ - column4x4) >> subsampling_x;
+    const int sb_height4 = (row4x4_end_ - row4x4) >> subsampling_y;
+    // The memset is equivalent to the following lines in the spec:
+    // for ( x = -1; x <= ( sbSize4 >> subX ); x++ ) {
+    //   if ( y < 0 && x < sbWidth4 ) {
+    //     BlockDecoded[plane][y][x] = 1
+    //   }
+    // }
+    const int num_elements =
+        std::min((sb_size4 >> subsampling_x_[plane]) + 1, sb_width4) + 1;
+    memset(&sb_buffer->block_decoded[plane][0][0], 1, num_elements);
+    // The for loop is equivalent to the following lines in the spec:
+    // for ( y = -1; y <= ( sbSize4 >> subY ); y++ ) {
+    //   if ( x < 0 && y < sbHeight4 ) {
+    //     BlockDecoded[plane][y][x] = 1
+    //   }
+    // }
+    // BlockDecoded[plane][sbSize4 >> subY][-1] = 0
+    for (int y = -1; y < std::min((sb_size4 >> subsampling_y), sb_height4);
+         ++y) {
+      sb_buffer->block_decoded[plane][y + 1][0] = true;
+    }
+  }
+}
+
 bool Tile::ProcessSuperBlock(int row4x4, int column4x4, int block_width4x4,
                              SuperBlockBuffer* const sb_buffer,
                              ProcessingMode mode) {
@@ -1811,10 +2053,7 @@
     ResetCdef(row4x4, column4x4);
   }
   if (decoding) {
-    memset(sb_buffer->block_decoded, 0,
-           sizeof(sb_buffer->block_decoded));  // Section 5.11.3.
-    sb_buffer->block_decoded_width_threshold = column4x4_end_ - column4x4;
-    sb_buffer->block_decoded_height_threshold = row4x4_end_ - row4x4;
+    ClearBlockDecoded(sb_buffer, row4x4, column4x4);
   }
   const BlockSize block_size = SuperBlockSize();
   if (parsing) {
@@ -1843,10 +2082,10 @@
       return false;
     }
     sb_buffer->residual =
-        residual_buffer_threaded_[sb_row_index][sb_column_index]->buffer.get();
+        residual_buffer_threaded_[sb_row_index][sb_column_index]->buffer();
     sb_buffer->transform_parameters =
-        &residual_buffer_threaded_[sb_row_index][sb_column_index]
-             ->transform_parameters;
+        residual_buffer_threaded_[sb_row_index][sb_column_index]
+            ->transform_parameters();
     if (!ProcessPartition(row4x4, column4x4,
                           block_parameters_holder_.Tree(row, column),
                           sb_buffer)) {
@@ -1856,10 +2095,10 @@
     }
   } else {
     sb_buffer->residual =
-        residual_buffer_threaded_[sb_row_index][sb_column_index]->buffer.get();
+        residual_buffer_threaded_[sb_row_index][sb_column_index]->buffer();
     sb_buffer->transform_parameters =
-        &residual_buffer_threaded_[sb_row_index][sb_column_index]
-             ->transform_parameters;
+        residual_buffer_threaded_[sb_row_index][sb_column_index]
+            ->transform_parameters();
     if (!DecodeSuperBlock(block_parameters_holder_.Tree(row, column),
                           sb_buffer)) {
       LIBGAV1_DLOG(ERROR, "Error decoding superblock row: %d column: %d",
@@ -1874,16 +2113,14 @@
 
 bool Tile::DecodeSuperBlock(ParameterTree* const tree,
                             SuperBlockBuffer* const sb_buffer) {
-  std::vector<ParameterTree*> stack;
-  stack.reserve(kDfsStackSize);
-  stack.push_back(tree);
-  while (!stack.empty()) {
-    ParameterTree* const node = stack.back();
-    stack.pop_back();
+  Stack<ParameterTree*, kDfsStackSize> stack;
+  stack.Push(tree);
+  while (!stack.Empty()) {
+    ParameterTree* const node = stack.Pop();
     if (node->partition() != kPartitionNone) {
       for (int i = 3; i >= 0; --i) {
         if (node->children(i) == nullptr) continue;
-        stack.push_back(node->children(i));
+        stack.Push(node->children(i));
       }
       continue;
     }
@@ -1926,14 +2163,13 @@
 
 void Tile::BuildBitMask(int row4x4, int column4x4, BlockSize block_size) {
   if (!post_filter_.DoDeblock()) return;
-  const int block_width4x4 = kNum4x4BlocksWide[block_size];
-  const int block_height4x4 = kNum4x4BlocksHigh[block_size];
-  if (block_width4x4 <= kNum4x4BlocksWide[kBlock64x64] &&
-      block_height4x4 <= kNum4x4BlocksHigh[kBlock64x64]) {
+  if (block_size <= kBlock64x64) {
     BuildBitMaskHelper(row4x4, column4x4, block_size, true, true);
   } else {
-    for (int y = 0; y < block_height4x4; y += kNum4x4BlocksHigh[kBlock64x64]) {
-      for (int x = 0; x < block_width4x4; x += kNum4x4BlocksWide[kBlock64x64]) {
+    const int block_width4x4 = kNum4x4BlocksWide[block_size];
+    const int block_height4x4 = kNum4x4BlocksHigh[block_size];
+    for (int y = 0; y < block_height4x4; y += 16) {
+      for (int x = 0; x < block_width4x4; x += 16) {
         BuildBitMaskHelper(row4x4 + y, column4x4 + x, kBlock64x64, x == 0,
                            y == 0);
       }
@@ -1951,12 +2187,6 @@
   LoopFilterMask* const masks = post_filter_.masks();
   const int unit_id = DivideBy16(row4x4) * masks->num_64x64_blocks_per_row() +
                       DivideBy16(column4x4);
-  const int row_limit = row4x4 + block_height4x4;
-  const int column_limit = column4x4 + block_width4x4;
-  const TransformSize current_block_uv_tx_size = GetTransformSize(
-      frame_header_.segmentation.lossless[bp.segment_id], block_size, kPlaneU,
-      kNumTransformSizes,  // This parameter is unused when plane != Y.
-      SubsamplingX(kPlaneU), SubsamplingY(kPlaneU));
 
   for (int plane = kPlaneY; plane < PlaneCount(); ++plane) {
     // For U and V planes, do not build bit masks if level == 0.
@@ -1964,38 +2194,39 @@
       continue;
     }
     // Build bit mask for vertical edges.
-    const int subsampling_x = SubsamplingX(static_cast<Plane>(plane));
-    const int subsampling_y = SubsamplingY(static_cast<Plane>(plane));
+    const int subsampling_x = subsampling_x_[plane];
+    const int subsampling_y = subsampling_y_[plane];
     const int plane_width =
         RightShiftWithRounding(frame_header_.width, subsampling_x);
+    const int column_limit =
+        std::min({column4x4 + block_width4x4, frame_header_.columns4x4,
+                  DivideBy4(plane_width + 3) << subsampling_x});
     const int plane_height =
         RightShiftWithRounding(frame_header_.height, subsampling_y);
-    const int vertical_step = 1 << subsampling_y;
-    const int horizontal_step = 1 << subsampling_x;
+    const int row_limit =
+        std::min({row4x4 + block_height4x4, frame_header_.rows4x4,
+                  DivideBy4(plane_height + 3) << subsampling_y});
     const int row_start = GetDeblockPosition(row4x4, subsampling_y);
     const int column_start = GetDeblockPosition(column4x4, subsampling_x);
-    if (row_start >= row4x4 + block_height4x4 ||
-        MultiplyBy4(row_start >> subsampling_y) >= plane_height ||
-        column_start >= column4x4 + block_width4x4 ||
-        MultiplyBy4(column_start >> subsampling_x) >= plane_width) {
+    if (row_start >= row_limit || column_start >= column_limit) {
       continue;
     }
+    const int vertical_step = 1 << subsampling_y;
+    const int horizontal_step = 1 << subsampling_x;
     const BlockParameters& bp =
         *block_parameters_holder_.Find(row_start, column_start);
+    const int horizontal_level_index =
+        kDeblockFilterLevelIndex[plane][kLoopFilterTypeHorizontal];
+    const int vertical_level_index =
+        kDeblockFilterLevelIndex[plane][kLoopFilterTypeVertical];
     const uint8_t vertical_level =
-        bp.deblock_filter_level[plane][kLoopFilterTypeVertical];
+        bp.deblock_filter_level[vertical_level_index];
 
-    for (int row = row_start;
-         row < row_limit && MultiplyBy4(row >> subsampling_y) < plane_height &&
-         row < frame_header_.rows4x4;
-         row += vertical_step) {
-      for (int column = column_start;
-           column < column_limit &&
-           MultiplyBy4(column >> subsampling_x) < plane_width &&
-           column < frame_header_.columns4x4;) {
+    for (int row = row_start; row < row_limit; row += vertical_step) {
+      for (int column = column_start; column < column_limit;) {
         const TransformSize tx_size = (plane == kPlaneY)
                                           ? inter_transform_sizes_[row][column]
-                                          : current_block_uv_tx_size;
+                                          : bp.uv_transform_size;
         // (1). Don't filter frame boundary.
         // (2). For tile boundary, we don't know whether the previous tile is
         // available or not, thus we handle it after all tiles are decoded.
@@ -2013,22 +2244,19 @@
             *block_parameters_holder_.Find(row, column - horizontal_step);
         const uint8_t left_level =
             is_vertical_border
-                ? bp_left.deblock_filter_level[plane][kLoopFilterTypeVertical]
+                ? bp_left.deblock_filter_level[vertical_level_index]
                 : vertical_level;
         // We don't have to check if the left block is skipped or not,
         // because if the current transform block is on the edge of the coding
         // block, is_vertical_border is true; if it's not on the edge,
         // left skip is equal to skip.
         if (vertical_level != 0 || left_level != 0) {
-          const TransformSize left_tx_size = GetTransformSize(
-              frame_header_.segmentation.lossless[bp_left.segment_id],
-              bp_left.size, static_cast<Plane>(plane),
-              inter_transform_sizes_[row][column - horizontal_step],
-              subsampling_x, subsampling_y);
-          // 0: 4x4, 1: 8x8, 2: 16x16.
-          const int transform_size_id =
-              std::min({kTransformWidthLog2[tx_size] - 2,
-                        kTransformWidthLog2[left_tx_size] - 2, 2});
+          const TransformSize left_tx_size =
+              (plane == kPlaneY)
+                  ? inter_transform_sizes_[row][column - horizontal_step]
+                  : bp_left.uv_transform_size;
+          const LoopFilterTransformSizeId transform_size_id =
+              GetTransformSizeIdWidth(tx_size, left_tx_size);
           const int r = row & (kNum4x4InLoopFilterMaskUnit - 1);
           const int c = column & (kNum4x4InLoopFilterMaskUnit - 1);
           const int shift = LoopFilterMask::GetShift(r, c);
@@ -2047,19 +2275,13 @@
 
     // Build bit mask for horizontal edges.
     const uint8_t horizontal_level =
-        bp.deblock_filter_level[plane][kLoopFilterTypeHorizontal];
-    for (int column = column_start;
-         column < column_limit &&
-         MultiplyBy4(column >> subsampling_x) < plane_width &&
-         column < frame_header_.columns4x4;
+        bp.deblock_filter_level[horizontal_level_index];
+    for (int column = column_start; column < column_limit;
          column += horizontal_step) {
-      for (int row = row_start;
-           row < row_limit &&
-           MultiplyBy4(row >> subsampling_y) < plane_height &&
-           row < frame_header_.rows4x4;) {
+      for (int row = row_start; row < row_limit;) {
         const TransformSize tx_size = (plane == kPlaneY)
                                           ? inter_transform_sizes_[row][column]
-                                          : current_block_uv_tx_size;
+                                          : bp.uv_transform_size;
 
         // (1). Don't filter frame boundary.
         // (2). For tile boundary, we don't know whether the previous tile is
@@ -2078,22 +2300,21 @@
             *block_parameters_holder_.Find(row - vertical_step, column);
         const uint8_t top_level =
             is_horizontal_border
-                ? bp_top.deblock_filter_level[plane][kLoopFilterTypeHorizontal]
+                ? bp_top.deblock_filter_level[horizontal_level_index]
                 : horizontal_level;
         // We don't have to check if the top block is skipped or not,
         // because if the current transform block is on the edge of the coding
         // block, is_horizontal_border is true; if it's not on the edge,
         // top skip is equal to skip.
         if (horizontal_level != 0 || top_level != 0) {
-          const TransformSize top_tx_size = GetTransformSize(
-              frame_header_.segmentation.lossless[bp_top.segment_id],
-              bp_top.size, static_cast<Plane>(plane),
-              inter_transform_sizes_[row - vertical_step][column],
-              subsampling_x, subsampling_y);
-          // 0: 4x4, 1: 8x8, 2: 16x16.
-          const int transform_size_id =
-              std::min({kTransformHeightLog2[tx_size] - 2,
-                        kTransformHeightLog2[top_tx_size] - 2, 2});
+          const TransformSize top_tx_size =
+              (plane == kPlaneY)
+                  ? inter_transform_sizes_[row - vertical_step][column]
+                  : bp_top.uv_transform_size;
+          const LoopFilterTransformSizeId transform_size_id =
+              static_cast<LoopFilterTransformSizeId>(
+                  std::min({kTransformHeightLog2[tx_size] - 2,
+                            kTransformHeightLog2[top_tx_size] - 2, 2}));
           const int r = row & (kNum4x4InLoopFilterMaskUnit - 1);
           const int c = column & (kNum4x4InLoopFilterMaskUnit - 1);
           const int shift = LoopFilterMask::GetShift(r, c);
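
Note: the ProcessPartition() and DecodeSuperBlock() hunks above replace the heap-allocated std::vector DFS stack with a fixed-capacity Stack<ParameterTree*, kDfsStackSize>. The stack utility itself is not part of this diff; the sketch below only illustrates a container consistent with the Push()/Pop()/Empty() calls seen above, and everything beyond those three names is an assumption rather than libgav1's actual src/utils code.

// Hypothetical sketch of a fixed-capacity stack matching the Push()/Pop()/
// Empty() usage in ProcessPartition() and DecodeSuperBlock() above. The real
// implementation may differ.
#include <cassert>

template <typename T, int capacity>
class Stack {
 public:
  // Pushes |value|. The caller guarantees that the number of pending elements
  // never exceeds |capacity| (kDfsStackSize bounds the number of pending
  // partition-tree nodes in the depth-first search).
  void Push(T value) {
    assert(top_ + 1 < capacity);
    elements_[++top_] = value;
  }

  // Returns the most recently pushed element and removes it.
  T Pop() {
    assert(top_ >= 0);
    return elements_[top_--];
  }

  bool Empty() const { return top_ < 0; }

 private:
  T elements_[capacity];
  int top_ = -1;
};

A fixed-capacity stack avoids the per-superblock heap allocation of the old std::vector (reserve(kDfsStackSize) always allocated) and lets Pop() return the element directly, which is what the new call sites rely on.
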
diff --git a/libgav1/src/utils/allocator.cc b/libgav1/src/utils/allocator.cc
deleted file mode 100644
index 309aabc..0000000
--- a/libgav1/src/utils/allocator.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "src/utils/allocator.h"
-
-#include <cstdlib>
-
-namespace libgav1 {
-
-AllocatorBase::AllocatorBase(size_t element_size)
-    : element_size_(element_size), ok_(true) {}
-
-void* AllocatorBase::Allocate(size_t s) {
-  if (!ok_) return nullptr;
-  void* const ptr = malloc(s);
-  ok_ = !(s > 0 && ptr == nullptr);
-  return ptr;
-}
-
-bool AllocatorBase::Deallocate(void* p, size_t /*nb_elements*/) {
-  if (!ok_) return false;
-  free(p);
-  return true;
-}
-
-void AllocatorBase::Init(size_t /*element_size*/) {}
-
-}  // namespace libgav1
diff --git a/libgav1/src/utils/allocator.h b/libgav1/src/utils/allocator.h
deleted file mode 100644
index 79e5355..0000000
--- a/libgav1/src/utils/allocator.h
+++ /dev/null
@@ -1,94 +0,0 @@
-#ifndef LIBGAV1_SRC_UTILS_ALLOCATOR_H_
-#define LIBGAV1_SRC_UTILS_ALLOCATOR_H_
-
-#include <cstddef>
-#include <utility>
-
-namespace libgav1 {
-
-class AllocatorBase {
- public:
-  bool ok() const { return ok_; }
-
- protected:
-  explicit AllocatorBase(size_t element_size);
-
-  void* Allocate(size_t s);
-  bool Deallocate(void* p, size_t nb_elements);
-  void Init(size_t element_size);
-
-  size_t element_size_;
-  bool ok_;
-};
-
-// This allocator will NOT call the constructor, since the construct()
-// method has been voided.
-template <typename T>
-class AllocatorNoCtor : protected AllocatorBase {
- public:
-  using value_type = T;
-  AllocatorNoCtor() : AllocatorBase(sizeof(T)) {}
-  template <typename U>
-  explicit AllocatorNoCtor(const AllocatorNoCtor<U>& /*other*/)
-      : AllocatorBase(sizeof(T)) {}
-
-  T* allocate(size_t nb_elements) {
-    return static_cast<T*>(AllocatorBase::Allocate(nb_elements * sizeof(T)));
-  }
-  void deallocate(T* p, size_t nb_elements) {
-    ok_ = ok_ && AllocatorBase::Deallocate(p, nb_elements);
-  }
-  bool ok() const { return AllocatorBase::ok(); }
-
-  // The allocator disables any construction...
-  template <typename U, typename... Args>
-  void construct(U*, Args&&...) noexcept {}
-
-  // ...but copy and move constructions, which are called by the vector
-  // implementation itself.
-  void construct(T* p, const T& v) noexcept {
-    static_assert(noexcept(new ((void*)p) T(v)),
-                  "needs a noexcept copy constructor");
-    if (ok_) new ((void*)p) T(v);
-  }
-  void construct(T* p, T&& v) noexcept {
-    static_assert(noexcept(new ((void*)p) T(std::move(v))),
-                  "needs a noexcept move constructor");
-    if (ok_) new ((void*)p) T(std::move(v));
-  }
-
-  template <typename U>
-  void destroy(U* p) noexcept {
-    if (ok_) p->~U();
-  }
-};
-
-// This allocator calls the constructors
-template <typename T>
-class Allocator : public AllocatorNoCtor<T> {
- public:
-  using value_type = T;
-  Allocator() : AllocatorNoCtor<T>() {}
-  template <typename U>
-  explicit Allocator(const Allocator<U>& other) : AllocatorNoCtor<T>(other) {}
-  // Enable the constructor.
-  template <typename U, typename... Args>
-  void construct(U* p, Args&&... args) noexcept {
-    static_assert(noexcept(new ((void*)p) U(std::forward<Args>(args)...)),
-                  "needs a noexcept constructor");
-    if (AllocatorBase::ok_) new ((void*)p) U(std::forward<Args>(args)...);
-  }
-};
-
-template <typename U, typename V>
-bool operator==(const AllocatorNoCtor<U>&, const AllocatorNoCtor<V>&) {
-  return true;
-}
-template <typename U, typename V>
-bool operator!=(const AllocatorNoCtor<U>&, const AllocatorNoCtor<V>&) {
-  return false;
-}
-
-}  // namespace libgav1
-
-#endif  // LIBGAV1_SRC_UTILS_ALLOCATOR_H_
diff --git a/libgav1/src/utils/array_2d.h b/libgav1/src/utils/array_2d.h
index 0b093ff..c054178 100644
--- a/libgav1/src/utils/array_2d.h
+++ b/libgav1/src/utils/array_2d.h
@@ -63,8 +63,11 @@
     const size_t size = rows * columns;
     if (size_ < size) {
       // Note: This invokes the global operator new if T is a non-class type,
-      // such as integer or enum types, or a class that we don't own, such as
-      // std::unique_ptr.
+      // such as integer or enum types, or a class type that is not derived
+      // from libgav1::Allocable, such as std::unique_ptr. If we enforce a
+      // maximum allocation size or keep track of our own heap memory
+      // consumption, we will need to handle the allocations here that use the
+      // global operator new.
       if (zero_initialize) {
         data_.reset(new (std::nothrow) T[size]());
       } else {
diff --git a/libgav1/src/utils/bit_mask_set.h b/libgav1/src/utils/bit_mask_set.h
new file mode 100644
index 0000000..6e10315
--- /dev/null
+++ b/libgav1/src/utils/bit_mask_set.h
@@ -0,0 +1,63 @@
+#ifndef LIBGAV1_SRC_UTILS_BIT_MASK_SET_H_
+#define LIBGAV1_SRC_UTILS_BIT_MASK_SET_H_
+
+#include <cstdint>
+
+namespace libgav1 {
+
+// This class is used to check if a given value is equal to one of several
+// predetermined values using a bit mask instead of a chain of comparisons and
+// ||s. This usually results in fewer instructions.
+//
+// Usage:
+//   constexpr BitMaskSet set(value1, value2);
+//   set.Contains(value1) => returns true.
+//   set.Contains(value3) => returns false.
+class BitMaskSet {
+ public:
+  explicit constexpr BitMaskSet(uint32_t mask) : mask_(mask) {}
+
+  constexpr BitMaskSet(int v1, int v2) : mask_((1U << v1) | (1U << v2)) {}
+
+  constexpr BitMaskSet(int v1, int v2, int v3)
+      : mask_((1U << v1) | (1U << v2) | (1U << v3)) {}
+
+  constexpr BitMaskSet(int v1, int v2, int v3, int v4)
+      : mask_((1U << v1) | (1U << v2) | (1U << v3) | (1U << v4)) {}
+
+  constexpr BitMaskSet(int v1, int v2, int v3, int v4, int v5)
+      : mask_((1U << v1) | (1U << v2) | (1U << v3) | (1U << v4) | (1U << v5)) {}
+
+  constexpr BitMaskSet(int v1, int v2, int v3, int v4, int v5, int v6)
+      : mask_((1U << v1) | (1U << v2) | (1U << v3) | (1U << v4) | (1U << v5) |
+              (1U << v6)) {}
+
+  constexpr BitMaskSet(int v1, int v2, int v3, int v4, int v5, int v6, int v7)
+      : mask_((1U << v1) | (1U << v2) | (1U << v3) | (1U << v4) | (1U << v5) |
+              (1U << v6) | (1U << v7)) {}
+
+  constexpr BitMaskSet(int v1, int v2, int v3, int v4, int v5, int v6, int v7,
+                       int v8, int v9)
+      : mask_((1U << v1) | (1U << v2) | (1U << v3) | (1U << v4) | (1U << v5) |
+              (1U << v6) | (1U << v7) | (1U << v8) | (1U << v9)) {}
+
+  constexpr BitMaskSet(int v1, int v2, int v3, int v4, int v5, int v6, int v7,
+                       int v8, int v9, int v10)
+      : mask_((1U << v1) | (1U << v2) | (1U << v3) | (1U << v4) | (1U << v5) |
+              (1U << v6) | (1U << v7) | (1U << v8) | (1U << v9) | (1U << v10)) {
+  }
+
+  constexpr bool Contains(uint8_t value) const {
+    return MaskContainsValue(mask_, value);
+  }
+
+  static constexpr bool MaskContainsValue(uint32_t mask, uint8_t value) {
+    return ((mask >> value) & 1) != 0;
+  }
+
+ private:
+  const uint32_t mask_;
+};
+
+}  // namespace libgav1
+#endif  // LIBGAV1_SRC_UTILS_BIT_MASK_SET_H_
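
The class above is exercised later in this diff by GetTransformClass() in common.h, where three chained equality comparisons collapse into one mask test. Below is a self-contained illustration of the same idea; the DummyType enum and its values are hypothetical stand-ins for a small enum such as TransformType (fewer than 32 values, so a uint32_t mask is sufficient).

#include <cstdint>
#include <cstdio>

// Standalone copy of the BitMaskSet idea, reduced to one constructor for the
// example; the enum values below are made up for illustration only.
enum DummyType : uint8_t { kTypeA = 0, kTypeB = 1, kTypeC = 5, kTypeD = 9 };

class BitMaskSet {
 public:
  constexpr BitMaskSet(int v1, int v2, int v3)
      : mask_((1U << v1) | (1U << v2) | (1U << v3)) {}
  constexpr bool Contains(uint8_t value) const {
    return ((mask_ >> value) & 1) != 0;
  }

 private:
  const uint32_t mask_;
};

int main() {
  // One mask test replaces a chain of == comparisons joined with ||.
  constexpr BitMaskSet set(kTypeA, kTypeC, kTypeD);
  std::printf("%d %d\n", set.Contains(kTypeC), set.Contains(kTypeB));  // 1 0
  return 0;
}
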
diff --git a/libgav1/src/utils/bit_reader.cc b/libgav1/src/utils/bit_reader.cc
index 1f58fed..8975965 100644
--- a/libgav1/src/utils/bit_reader.cc
+++ b/libgav1/src/utils/bit_reader.cc
@@ -8,13 +8,13 @@
 namespace libgav1 {
 namespace {
 
-inline bool Assign(int* const value, int assignment, bool return_value) {
+bool Assign(int* const value, int assignment, bool return_value) {
   *value = assignment;
   return return_value;
 }
 
 // 5.9.29.
-inline int InverseRecenter(int r, int v) {
+int InverseRecenter(int r, int v) {
   if (v > (r << 1)) {
     return v;
   }
diff --git a/libgav1/src/utils/block_parameters_holder.cc b/libgav1/src/utils/block_parameters_holder.cc
index a7ee0c8..a207aa9 100644
--- a/libgav1/src/utils/block_parameters_holder.cc
+++ b/libgav1/src/utils/block_parameters_holder.cc
@@ -23,26 +23,37 @@
 
 BlockParametersHolder::BlockParametersHolder(int rows4x4, int columns4x4,
                                              bool use_128x128_superblock)
-    : rows4x4_(rows4x4), columns4x4_(columns4x4) {
-  if (!block_parameters_cache_.Reset(rows4x4, columns4x4)) {
+    : rows4x4_(rows4x4),
+      columns4x4_(columns4x4),
+      use_128x128_superblock_(use_128x128_superblock) {}
+
+bool BlockParametersHolder::Init() {
+  if (!block_parameters_cache_.Reset(rows4x4_, columns4x4_)) {
     LIBGAV1_DLOG(ERROR, "block_parameters_cache_.Reset() failed.");
+    return false;
   }
   const int rows =
-      RowsOrColumns4x4ToSuperBlocks(rows4x4, use_128x128_superblock);
+      RowsOrColumns4x4ToSuperBlocks(rows4x4_, use_128x128_superblock_);
   const int columns =
-      RowsOrColumns4x4ToSuperBlocks(columns4x4, use_128x128_superblock);
+      RowsOrColumns4x4ToSuperBlocks(columns4x4_, use_128x128_superblock_);
   const BlockSize sb_size =
-      use_128x128_superblock ? kBlock128x128 : kBlock64x64;
+      use_128x128_superblock_ ? kBlock128x128 : kBlock64x64;
   const int multiplier = kNum4x4BlocksWide[sb_size];
   if (!trees_.Reset(rows, columns)) {
     LIBGAV1_DLOG(ERROR, "trees_.Reset() failed.");
+    return false;
   }
   for (int i = 0; i < rows; ++i) {
     for (int j = 0; j < columns; ++j) {
-      trees_[i][j].reset(new (std::nothrow) ParameterTree(
-          i * multiplier, j * multiplier, sb_size));
+      trees_[i][j] =
+          ParameterTree::Create(i * multiplier, j * multiplier, sb_size);
+      if (trees_[i][j] == nullptr) {
+        LIBGAV1_DLOG(ERROR, "Allocation of trees_[%d][%d] failed.", i, j);
+        return false;
+      }
     }
   }
+  return true;
 }
 
 void BlockParametersHolder::FillCache(int row4x4, int column4x4,
diff --git a/libgav1/src/utils/block_parameters_holder.h b/libgav1/src/utils/block_parameters_holder.h
index a55d0d6..d92791b 100644
--- a/libgav1/src/utils/block_parameters_holder.h
+++ b/libgav1/src/utils/block_parameters_holder.h
@@ -4,6 +4,7 @@
 #include <memory>
 
 #include "src/utils/array_2d.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
 #include "src/utils/parameter_tree.h"
 #include "src/utils/types.h"
@@ -23,10 +24,13 @@
   BlockParametersHolder(const BlockParametersHolder&) = delete;
   BlockParametersHolder& operator=(const BlockParametersHolder&) = delete;
 
+  // Must be called first.
+  LIBGAV1_MUST_USE_RESULT bool Init();
+
   // Finds the BlockParameters corresponding to |row4x4| and |column4x4|. This
   // is done as a simple look up of the |block_parameters_cache_| matrix.
   // Returns nullptr if the BlockParameters cannot be found.
-  BlockParameters* Find(int row4x4, int column4x4) {
+  BlockParameters* Find(int row4x4, int column4x4) const {
     return block_parameters_cache_[row4x4][column4x4];
   }
 
@@ -40,8 +44,9 @@
                  BlockParameters* bp);
 
  private:
-  int rows4x4_;
-  int columns4x4_;
+  const int rows4x4_;
+  const int columns4x4_;
+  const bool use_128x128_superblock_;
   Array2D<std::unique_ptr<ParameterTree>> trees_;
 
   // This is a 2d array of size |rows4x4_| * |columns4x4_|. This is filled in by
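
With the fallible work moved out of the constructor, callers are expected to construct the holder and then check Init(). The call site is not part of this diff, so the caller below is hypothetical; only the constructor signature and Init() come from the header and source above.

#include <memory>
#include <new>

#include "src/utils/block_parameters_holder.h"

// Hypothetical caller showing the two-phase initialization introduced above;
// the real call site in the decoder is not part of this diff.
std::unique_ptr<libgav1::BlockParametersHolder> CreateHolder(
    int rows4x4, int columns4x4, bool use_128x128_superblock) {
  std::unique_ptr<libgav1::BlockParametersHolder> holder(
      new (std::nothrow) libgav1::BlockParametersHolder(
          rows4x4, columns4x4, use_128x128_superblock));
  // Init() is LIBGAV1_MUST_USE_RESULT: an allocation failure is reported to
  // the caller instead of being swallowed inside the constructor.
  if (holder == nullptr || !holder->Init()) return nullptr;
  return holder;
}
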
diff --git a/libgav1/src/utils/blocking_counter.h b/libgav1/src/utils/blocking_counter.h
new file mode 100644
index 0000000..7fe3c97
--- /dev/null
+++ b/libgav1/src/utils/blocking_counter.h
@@ -0,0 +1,81 @@
+#ifndef LIBGAV1_SRC_UTILS_BLOCKING_COUNTER_H_
+#define LIBGAV1_SRC_UTILS_BLOCKING_COUNTER_H_
+
+#include <cassert>
+#include <condition_variable>  // NOLINT (unapproved c++11 header)
+#include <mutex>               // NOLINT (unapproved c++11 header)
+
+#include "src/utils/compiler_attributes.h"
+
+namespace libgav1 {
+
+// Implementation of a Blocking Counter that is used for the "fork-join"
+// use case. Typical usage would be as follows:
+//   BlockingCounter counter(num_jobs);
+//     - spawn the jobs.
+//     - call counter.Wait() on the master thread.
+//     - worker threads will call counter.Decrement().
+//     - master thread will return from counter.Wait() when all workers are
+//     complete.
+template <bool has_failure_status>
+class BlockingCounterImpl {
+ public:
+  explicit BlockingCounterImpl(int initial_count)
+      : count_(initial_count), job_failed_(false) {}
+
+  // Increment the counter by |count|. This must be called before Wait() is
+  // called. This must be called from the same thread that will call Wait().
+  void IncrementBy(int count) {
+    assert(count >= 0);
+    std::unique_lock<std::mutex> lock(mutex_);
+    count_ += count;
+  }
+
+  // Decrement the counter by 1. This function can be called only when
+  // |has_failure_status| is false, i.e., when this class is being used with the
+  // |BlockingCounter| alias.
+  void Decrement() {
+    static_assert(!has_failure_status, "");
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (--count_ == 0) {
+      condition_.notify_one();
+    }
+  }
+
+  // Decrement the counter by 1. This function can be called only when
+  // |has_failure_status| is true, i.e., when this class is being used with the
+  // |BlockingCounterWithStatus| alias. |job_succeeded| is used to update the
+  // state of |job_failed_|.
+  void Decrement(bool job_succeeded) {
+    static_assert(has_failure_status, "");
+    std::unique_lock<std::mutex> lock(mutex_);
+    job_failed_ |= !job_succeeded;
+    if (--count_ == 0) {
+      condition_.notify_one();
+    }
+  }
+
+  // Block until the counter becomes 0. This function can be called only once
+  // per object. If |has_failure_status| is true, true is returned if all the
+  // jobs succeeded and false is returned if any of the jobs failed. If
+  // |has_failure_status| is false, this function always returns true.
+  bool Wait() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    condition_.wait(lock, [this]() { return count_ == 0; });
+    // If |has_failure_status| is false, we simply return true.
+    return has_failure_status ? !job_failed_ : true;
+  }
+
+ private:
+  std::mutex mutex_;
+  std::condition_variable condition_;
+  int count_ LIBGAV1_GUARDED_BY(mutex_);
+  bool job_failed_ LIBGAV1_GUARDED_BY(mutex_);
+};
+
+using BlockingCounterWithStatus = BlockingCounterImpl<true>;
+using BlockingCounter = BlockingCounterImpl<false>;
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_UTILS_BLOCKING_COUNTER_H_
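
A minimal fork-join sketch matching the usage comment in the header above. It uses plain std::thread rather than libgav1's ThreadPool, and the per-job "work" is a placeholder; only the BlockingCounterWithStatus calls reflect the API added in this diff.

#include <cstdio>
#include <thread>
#include <vector>

#include "src/utils/blocking_counter.h"

// Spawns |num_jobs| workers, waits for all of them, and reports whether any
// job failed, using BlockingCounterWithStatus from the header above.
bool RunJobs(int num_jobs) {
  libgav1::BlockingCounterWithStatus counter(num_jobs);
  std::vector<std::thread> workers;
  for (int i = 0; i < num_jobs; ++i) {
    workers.emplace_back([&counter, i]() {
      const bool job_succeeded = (i % 2 == 0);  // placeholder "work".
      counter.Decrement(job_succeeded);         // signal completion + status.
    });
  }
  // Blocks until every worker has called Decrement(); returns false if any
  // job reported failure.
  const bool all_succeeded = counter.Wait();
  for (auto& worker : workers) worker.join();
  return all_succeeded;
}

int main() {
  std::printf("all_succeeded = %d\n", static_cast<int>(RunJobs(4)));
  return 0;
}
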
diff --git a/libgav1/src/utils/common.h b/libgav1/src/utils/common.h
index bfbdb8c..9803364 100644
--- a/libgav1/src/utils/common.h
+++ b/libgav1/src/utils/common.h
@@ -14,6 +14,7 @@
 #include <cstddef>
 #include <cstdint>
 
+#include "src/utils/bit_mask_set.h"
 #include "src/utils/constants.h"
 
 namespace libgav1 {
@@ -114,7 +115,11 @@
 }
 
 inline int CeilLog2(unsigned int n) {
-  return (n < 2) ? 0 : FloorLog2(n) + static_cast<int>((n & (n - 1)) != 0);
+  // The expression FloorLog2(n - 1) + 1 is undefined not only for n == 0 but
+  // also for n == 1, so this expression must be guarded by the n < 2 test. An
+  // alternative implementation is:
+  // return (n == 0) ? 0 : FloorLog2(n) + static_cast<int>((n & (n - 1)) != 0);
+  return (n < 2) ? 0 : FloorLog2(n - 1) + 1;
 }
 
 inline int32_t RightShiftWithRounding(int32_t value, int bits) {
@@ -146,13 +151,13 @@
                       : -RightShiftWithRounding(-value, bits);
 }
 
-inline int DivideBy2(int n) { return n >> 1; }
-inline int DivideBy4(int n) { return n >> 2; }
-inline int DivideBy8(int n) { return n >> 3; }
-inline int DivideBy16(int n) { return n >> 4; }
-inline int DivideBy32(int n) { return n >> 5; }
-inline int DivideBy64(int n) { return n >> 6; }
-inline int DivideBy128(int n) { return n >> 7; }
+constexpr int DivideBy2(int n) { return n >> 1; }
+constexpr int DivideBy4(int n) { return n >> 2; }
+constexpr int DivideBy8(int n) { return n >> 3; }
+constexpr int DivideBy16(int n) { return n >> 4; }
+constexpr int DivideBy32(int n) { return n >> 5; }
+constexpr int DivideBy64(int n) { return n >> 6; }
+constexpr int DivideBy128(int n) { return n >> 7; }
 
 // Convert |value| to unsigned before shifting to avoid undefined behavior with
 // negative values.
@@ -169,25 +174,27 @@
 inline int MultiplyBy32(int n) { return LeftShift(n, 5); }
 inline int MultiplyBy64(int n) { return LeftShift(n, 6); }
 
-inline int Mod32(int n) { return n & 0x1f; }
-inline int Mod64(int n) { return n & 0x3f; }
+constexpr int Mod32(int n) { return n & 0x1f; }
+constexpr int Mod64(int n) { return n & 0x3f; }
 
 //------------------------------------------------------------------------------
 // Bitstream functions
 
-inline bool IsIntraFrame(FrameType type) {
+constexpr bool IsIntraFrame(FrameType type) {
   return type == kFrameKey || type == kFrameIntraOnly;
 }
 
 inline TransformClass GetTransformClass(TransformType tx_type) {
-  if (tx_type == kTransformTypeIdentityDct ||
-      tx_type == kTransformTypeIdentityAdst ||
-      tx_type == kTransformTypeIdentityFlipadst) {
+  constexpr BitMaskSet kTransformClassVerticalMask(
+      kTransformTypeIdentityDct, kTransformTypeIdentityAdst,
+      kTransformTypeIdentityFlipadst);
+  if (kTransformClassVerticalMask.Contains(tx_type)) {
     return kTransformClassVertical;
   }
-  if (tx_type == kTransformTypeDctIdentity ||
-      tx_type == kTransformTypeAdstIdentity ||
-      tx_type == kTransformTypeFlipadstIdentity) {
+  constexpr BitMaskSet kTransformClassHorizontalMask(
+      kTransformTypeDctIdentity, kTransformTypeAdstIdentity,
+      kTransformTypeFlipadstIdentity);
+  if (kTransformClassHorizontalMask.Contains(tx_type)) {
     return kTransformClassHorizontal;
   }
   return kTransformClass2D;
@@ -198,12 +205,12 @@
   return MultiplyBy4(row_or_column4x4) >> (plane == kPlaneY ? 0 : subsampling);
 }
 
-inline PlaneType GetPlaneType(Plane plane) {
-  return (plane == kPlaneY) ? kPlaneTypeY : kPlaneTypeUV;
+constexpr PlaneType GetPlaneType(Plane plane) {
+  return static_cast<PlaneType>(plane != kPlaneY);
 }
 
 // 5.11.44.
-inline bool IsDirectionalMode(PredictionMode mode) {
+constexpr bool IsDirectionalMode(PredictionMode mode) {
   return mode >= kPredictionModeVertical && mode <= kPredictionModeD67;
 }
 
@@ -254,14 +261,8 @@
   return (diff & (m - 1)) - (diff & m);
 }
 
-// part of 5.11.23.
-inline bool HasNearMv(PredictionMode mode) {
-  return mode == kPredictionModeNearMv || mode == kPredictionModeNearNearMv ||
-         mode == kPredictionModeNearNewMv || mode == kPredictionModeNewNearMv;
-}
-
 inline bool IsBlockSmallerThan8x8(BlockSize size) {
-  return size == kBlock4x4 || size == kBlock4x8 || size == kBlock8x4;
+  return size < kBlock8x8 && size != kBlock4x16;
 }
 
 // Maps a square transform to an index between [0, 4]. kTransformSize4x4 maps
@@ -279,33 +280,12 @@
   return DivideBy4(tx_size);
 }
 
-// 5.11.37.
-inline TransformSize GetTransformSize(const bool lossless,
-                                      const BlockSize block_size,
-                                      const Plane plane,
-                                      const TransformSize tx_size,
-                                      const int subsampling_x,
-                                      const int subsampling_y) {
-  if (lossless) return kTransformSize4x4;
-  if (plane == kPlaneY) return tx_size;
-  const BlockSize plane_size =
-      kPlaneResidualSize[block_size][subsampling_x][subsampling_y];
-  assert(plane_size != kBlockInvalid);
-  const TransformSize uv_tx_size = kMaxTransformSizeRectangle[plane_size];
-  const uint32_t mask = 1U << uv_tx_size;
-  if ((mask & kTransformSize64Mask) == 0) {
-    return uv_tx_size;
-  }
-  if ((mask & kTransformWidth16Mask) != 0) return kTransformSize16x32;
-  if ((mask & kTransformHeight16Mask) != 0) return kTransformSize32x16;
-  return kTransformSize32x32;
-}
-
 // Gets the corresponding Y/U/V position, to set and get filter masks
 // in deblock filtering.
 // Returns luma_position if it's Y plane, whose subsampling must be 0.
 // Returns the odd position for U/V plane, if there is subsampling.
-inline int GetDeblockPosition(const int luma_position, const int subsampling) {
+constexpr int GetDeblockPosition(const int luma_position,
+                                 const int subsampling) {
   return luma_position | subsampling;
 }
 
@@ -328,6 +308,19 @@
   return (residual_size * rows * columns * subsampling_multiplier_num) >> 1;
 }
 
+// This function is equivalent to:
+// std::min({kTransformWidthLog2[tx_size] - 2,
+//           kTransformWidthLog2[left_tx_size] - 2,
+//           2});
+constexpr LoopFilterTransformSizeId GetTransformSizeIdWidth(
+    TransformSize tx_size, TransformSize left_tx_size) {
+  return static_cast<LoopFilterTransformSizeId>(
+      static_cast<int>(tx_size > kTransformSize4x16 &&
+                       left_tx_size > kTransformSize4x16) +
+      static_cast<int>(tx_size > kTransformSize8x32 &&
+                       left_tx_size > kTransformSize8x32));
+}
+
 }  // namespace libgav1
 
 #endif  // LIBGAV1_SRC_UTILS_COMMON_H_
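
Two of the common.h changes above are behavior-preserving rewrites that are easy to check numerically. The sketch below compares the new CeilLog2() with the alternative given in its comment; the local FloorLog2() is an assumption standing in for the existing helper (not shown in this diff) and only needs to agree on the highest-set-bit definition.

#include <cassert>
#include <cstdio>

// Assumed FloorLog2 (index of the highest set bit); libgav1's own helper may
// be implemented differently but must agree on these values for n >= 1.
int FloorLog2(unsigned int n) {
  int result = -1;
  while (n != 0) {
    n >>= 1;
    ++result;
  }
  return result;
}

// New form from the diff above.
int CeilLog2New(unsigned int n) { return (n < 2) ? 0 : FloorLog2(n - 1) + 1; }

// Alternative form mentioned in the comment above.
int CeilLog2Alt(unsigned int n) {
  return (n == 0) ? 0 : FloorLog2(n) + static_cast<int>((n & (n - 1)) != 0);
}

int main() {
  // The two forms agree for all small n; e.g. CeilLog2(5) == 3, CeilLog2(8) == 3.
  for (unsigned int n = 0; n <= 1024; ++n) {
    assert(CeilLog2New(n) == CeilLog2Alt(n));
  }
  std::printf("CeilLog2(5) = %d\n", CeilLog2New(5));
  return 0;
}

GetTransformSizeIdWidth() is equivalent to the std::min expression it replaces because, in the TransformSize ordering used by constants.cc, every size up to kTransformSize4x16 has width 4 and every size up to kTransformSize8x32 has width at most 8; each comparison therefore contributes one step of the clamped min(kTransformWidthLog2[...] - 2, 2) for both operands.
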
diff --git a/libgav1/src/utils/constants.cc b/libgav1/src/utils/constants.cc
index 823f591..cc8a33f 100644
--- a/libgav1/src/utils/constants.cc
+++ b/libgav1/src/utils/constants.cc
@@ -117,25 +117,18 @@
 const uint8_t kTransformHeight[kNumTransformSizes] = {
     4, 8, 16, 4, 8, 16, 32, 4, 8, 16, 32, 64, 8, 16, 32, 64, 16, 32, 64};
 
+const uint8_t kTransformWidth4x4[kNumTransformSizes] = {
+    1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4, 4, 8, 8, 8, 8, 16, 16, 16};
+
+const uint8_t kTransformHeight4x4[kNumTransformSizes] = {
+    1, 2, 4, 1, 2, 4, 8, 1, 2, 4, 8, 16, 2, 4, 8, 16, 4, 8, 16};
+
 const uint8_t kTransformWidthLog2[kNumTransformSizes] = {
     2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6};
 
 const uint8_t kTransformHeightLog2[kNumTransformSizes] = {
     2, 3, 4, 2, 3, 4, 5, 2, 3, 4, 5, 6, 3, 4, 5, 6, 4, 5, 6};
 
-const TransformSize kMaxTransformSizeRectangle[kMaxBlockSizes] = {
-    kTransformSize4x4,   kTransformSize4x8,   kTransformSize4x16,
-    kTransformSize8x4,   kTransformSize8x8,   kTransformSize8x16,
-    kTransformSize8x32,  kTransformSize16x4,  kTransformSize16x8,
-    kTransformSize16x16, kTransformSize16x32, kTransformSize16x64,
-    kTransformSize32x8,  kTransformSize32x16, kTransformSize32x32,
-    kTransformSize32x64, kTransformSize64x16, kTransformSize64x32,
-    kTransformSize64x64, kTransformSize64x64, kTransformSize64x64,
-    kTransformSize64x64};
-
-const int kMaxTransformDepth[kMaxBlockSizes] = {
-    0, 1, 2, 1, 1, 2, 3, 2, 2, 2, 3, 4, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4};
-
 // 9.3 -- Split_Tx_Size[]
 const TransformSize kSplitTransformSize[kNumTransformSizes] = {
     kTransformSize4x4,   kTransformSize4x4,   kTransformSize4x8,
@@ -166,51 +159,8 @@
     kTransformSize64x64, kTransformSize64x64, kTransformSize64x64,
     kTransformSize64x64};
 
-// Defined in section 9.3 of the spec.
-const TransformType kModeToTransformType[kIntraPredictionModesUV] = {
-    kTransformTypeDctDct,   kTransformTypeDctAdst,  kTransformTypeAdstDct,
-    kTransformTypeDctDct,   kTransformTypeAdstAdst, kTransformTypeDctAdst,
-    kTransformTypeAdstDct,  kTransformTypeAdstDct,  kTransformTypeDctAdst,
-    kTransformTypeAdstAdst, kTransformTypeDctAdst,  kTransformTypeAdstDct,
-    kTransformTypeAdstAdst, kTransformTypeDctDct};
-
 const uint8_t kNumTransformTypesInSet[kNumTransformSets] = {1, 7, 5, 16, 12, 2};
 
-// Defined in section 5.11.47 of the spec. This array does not contain an entry
-// for kTransformSetDctOnly, so the first dimension needs to be
-// |kNumTransformSets| - 1.
-const TransformType kInverseTransformTypeBySet[kNumTransformSets - 1][16] = {
-    {kTransformTypeIdentityIdentity, kTransformTypeDctDct,
-     kTransformTypeIdentityDct, kTransformTypeDctIdentity,
-     kTransformTypeAdstAdst, kTransformTypeDctAdst, kTransformTypeAdstDct},
-    {kTransformTypeIdentityIdentity, kTransformTypeDctDct,
-     kTransformTypeAdstAdst, kTransformTypeDctAdst, kTransformTypeAdstDct},
-    {kTransformTypeIdentityIdentity, kTransformTypeIdentityDct,
-     kTransformTypeDctIdentity, kTransformTypeIdentityAdst,
-     kTransformTypeAdstIdentity, kTransformTypeIdentityFlipadst,
-     kTransformTypeFlipadstIdentity, kTransformTypeDctDct,
-     kTransformTypeDctAdst, kTransformTypeAdstDct, kTransformTypeDctFlipadst,
-     kTransformTypeFlipadstDct, kTransformTypeAdstAdst,
-     kTransformTypeFlipadstFlipadst, kTransformTypeFlipadstAdst,
-     kTransformTypeAdstFlipadst},
-    {kTransformTypeIdentityIdentity, kTransformTypeIdentityDct,
-     kTransformTypeDctIdentity, kTransformTypeDctDct, kTransformTypeDctAdst,
-     kTransformTypeAdstDct, kTransformTypeDctFlipadst,
-     kTransformTypeFlipadstDct, kTransformTypeAdstAdst,
-     kTransformTypeFlipadstFlipadst, kTransformTypeFlipadstAdst,
-     kTransformTypeAdstFlipadst},
-    {kTransformTypeIdentityIdentity, kTransformTypeDctDct}};
-
-// Replaces all occurrences of 64x* and *x64 with 32x* and *x32 respectively.
-const TransformSize kAdjustedTransformSize[kNumTransformSizes] = {
-    kTransformSize4x4,   kTransformSize4x8,   kTransformSize4x16,
-    kTransformSize8x4,   kTransformSize8x8,   kTransformSize8x16,
-    kTransformSize8x32,  kTransformSize16x4,  kTransformSize16x8,
-    kTransformSize16x16, kTransformSize16x32, kTransformSize16x32,
-    kTransformSize32x8,  kTransformSize32x16, kTransformSize32x32,
-    kTransformSize32x32, kTransformSize32x16, kTransformSize32x32,
-    kTransformSize32x32};
-
 const uint8_t kSgrProjParams[1 << kSgrProjParamsBits][4] = {
     {2, 12, 1, 4},  {2, 15, 1, 6},  {2, 18, 1, 8},  {2, 21, 1, 9},
     {2, 24, 1, 10}, {2, 29, 1, 11}, {2, 36, 1, 12}, {2, 45, 1, 13},
@@ -221,10 +171,6 @@
 
 const int8_t kSgrProjMultiplierMax[2] = {31, 95};
 
-const int8_t kSgrProjDefaultMultiplier[2] = {-32, 31};
-
-const int8_t kWienerDefaultFilter[3] = {3, -7, 15};
-
 const int8_t kWienerTapsMin[3] = {-5, -23, -17};
 
 const int8_t kWienerTapsMax[3] = {10, 8, 46};
@@ -595,11 +541,314 @@
     3,           // 87, ...
 };
 
-const uint8_t kPredictionModeDeltasLookup[kNumPredictionModes] = {
-    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  // Intra modes.
-    0,                      // ChromaFromLuma mode, not used in y mode.
-    1, 1, 0, 1,             // Inter modes.
-    1, 1, 1, 1, 1, 1, 0, 1  // Compound inter modes.
+const uint8_t kDeblockFilterLevelIndex[kMaxPlanes][kNumLoopFilterTypes] = {
+    {0, 1}, {2, 2}, {3, 3}};
+
+const int8_t kMaskIdLookup[4][kMaxBlockSizes] = {
+    // transform size 4x4.
+    {0,  1,  13, 2, 3,  4,  15, 14, 5,  6,  7,
+     17, 16, 8,  9, 10, 18, 11, 12, -1, -1, -1},
+    // transform size 8x8.
+    {-1, -1, -1, -1, 19, 20, 29, -1, 21, 22, 23,
+     31, 30, 24, 25, 26, 32, 27, 28, -1, -1, -1},
+    // transform size 16x16.
+    {-1, -1, -1, -1, -1, -1, -1, -1, -1, 33, 34,
+     40, -1, 35, 36, 37, 41, 38, 39, -1, -1, -1},
+    // transform size 32x32.
+    {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+     -1, -1, -1, 42, 43, -1, 44, 45, -1, -1, -1},
+};
+
+const int8_t kVerticalBorderMaskIdLookup[kMaxBlockSizes] = {
+    0,  47, 61, 49, 19, 51, 63, 62, 53, 33, 55,
+    65, 64, 57, 42, 59, 66, 60, 46, -1, -1, -1};
+
+const uint64_t kTopMaskLookup[67][4] = {
+    // transform size 4X4
+    {0x0000000000000001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X4, transform size 4X4
+    {0x0000000000010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X8, transform size 4X4
+    {0x0000000000000003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X4, transform size 4X4
+    {0x0000000000030003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X8, transform size 4X4
+    {0x0003000300030003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X16, transform size 4X4
+    {0x00000000000f000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X8, transform size 4X4
+    {0x000f000f000f000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X16, transform size 4X4
+    {0x000f000f000f000fULL, 0x000f000f000f000fULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X32, transform size 4X4
+    {0x00ff00ff00ff00ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X16, transform size 4X4
+    {0x00ff00ff00ff00ffULL, 0x00ff00ff00ff00ffULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X32, transform size 4X4
+    {0x00ff00ff00ff00ffULL, 0x00ff00ff00ff00ffULL, 0x00ff00ff00ff00ffULL,
+     0x00ff00ff00ff00ffULL},  // block size 32X64, transform size 4X4
+    {0xffffffffffffffffULL, 0xffffffffffffffffULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 4X4
+    {0xffffffffffffffffULL, 0xffffffffffffffffULL, 0xffffffffffffffffULL,
+     0xffffffffffffffffULL},  // block size 64X64, transform size 4X4
+    {0x0001000100010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X16, transform size 4X4
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X4, transform size 4X4
+    {0x0003000300030003ULL, 0x0003000300030003ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X32, transform size 4X4
+    {0x0000000000ff00ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X8, transform size 4X4
+    {0x000f000f000f000fULL, 0x000f000f000f000fULL, 0x000f000f000f000fULL,
+     0x000f000f000f000fULL},  // block size 16X64, transform size 4X4
+    {0xffffffffffffffffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 4X4
+    // transform size 8X8
+    {0x0000000000000003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X8, transform size 8X8
+    {0x0000000300000003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X16, transform size 8X8
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X8, transform size 8X8
+    {0x0000000f0000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X16, transform size 8X8
+    {0x0000000f0000000fULL, 0x0000000f0000000fULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X32, transform size 8X8
+    {0x000000ff000000ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X16, transform size 8X8
+    {0x000000ff000000ffULL, 0x000000ff000000ffULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X32, transform size 8X8
+    {0x000000ff000000ffULL, 0x000000ff000000ffULL, 0x000000ff000000ffULL,
+     0x000000ff000000ffULL},  // block size 32X64, transform size 8X8
+    {0x0000ffff0000ffffULL, 0x0000ffff0000ffffULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 8X8
+    {0x0000ffff0000ffffULL, 0x0000ffff0000ffffULL, 0x0000ffff0000ffffULL,
+     0x0000ffff0000ffffULL},  // block size 64X64, transform size 8X8
+    {0x0000000300000003ULL, 0x0000000300000003ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X32, transform size 8X8
+    {0x00000000000000ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X8, transform size 8X8
+    {0x0000000f0000000fULL, 0x0000000f0000000fULL, 0x0000000f0000000fULL,
+     0x0000000f0000000fULL},  // block size 16X64, transform size 8X8
+    {0x0000ffff0000ffffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 8X8
+    // transform size 16X16
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X16, transform size 16X16
+    {0x000000000000000fULL, 0x000000000000000fULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X32, transform size 16X16
+    {0x00000000000000ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X16, transform size 16X16
+    {0x00000000000000ffULL, 0x00000000000000ffULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X32, transform size 16X16
+    {0x00000000000000ffULL, 0x00000000000000ffULL, 0x00000000000000ffULL,
+     0x00000000000000ffULL},  // block size 32X64, transform size 16X16
+    {0x000000000000ffffULL, 0x000000000000ffffULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 16X16
+    {0x000000000000ffffULL, 0x000000000000ffffULL, 0x000000000000ffffULL,
+     0x000000000000ffffULL},  // block size 64X64, transform size 16X16
+    {0x000000000000000fULL, 0x000000000000000fULL, 0x000000000000000fULL,
+     0x000000000000000fULL},  // block size 16X64, transform size 16X16
+    {0x000000000000ffffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 16X16
+    // transform size 32X32
+    {0x00000000000000ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X32, transform size 32X32
+    {0x00000000000000ffULL, 0x0000000000000000ULL, 0x00000000000000ffULL,
+     0x0000000000000000ULL},  // block size 32X64, transform size 32X32
+    {0x000000000000ffffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 32X32
+    {0x000000000000ffffULL, 0x0000000000000000ULL, 0x000000000000ffffULL,
+     0x0000000000000000ULL},  // block size 64X64, transform size 32X32
+    // transform size 64X64
+    {0x000000000000ffffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X64, transform size 64X64
+    // 2:1, 1:2 transform sizes.
+    {0x0000000000000001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X8, transform size 4X8
+    {0x0000000100000001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X16, transform size 4X8
+    {0x0000000000000003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X4, transform size 8X4
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X4, transform size 8X4
+    {0x0000000000000003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X16, transform size 8X16
+    {0x0000000000000003ULL, 0x0000000000000003ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X32, transform size 8X16
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X8, transform size 16X8
+    {0x00000000000000ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X8, transform size 16X8
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X32, transform size 16X32
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x000000000000000fULL,
+     0x0000000000000000ULL},  // block size 16X64, transform size 16X32
+    {0x00000000000000ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X16, transform size 32X16
+    {0x000000000000ffffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 32X16
+    {0x00000000000000ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X64, transform size 32X64
+    {0x000000000000ffffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 64X32
+    // 4:1, 1:4 transform sizes.
+    {0x0000000000000001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X16, transform size 4X16
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X4, transform size 16X4
+    {0x0000000000000003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X32, transform size 8X32
+    {0x00000000000000ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X8, transform size 32X8
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X64, transform size 16X64
+    {0x000000000000ffffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 64X16
+};
+
+const uint64_t kLeftMaskLookup[67][4] = {
+    // transform size 4X4
+    {0x0000000000000001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X4, transform size 4X4
+    {0x0000000000010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X8, transform size 4X4
+    {0x0000000000000003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X4, transform size 4X4
+    {0x0000000000030003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X8, transform size 4X4
+    {0x0003000300030003ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X16, transform size 4X4
+    {0x00000000000f000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X8, transform size 4X4
+    {0x000f000f000f000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X16, transform size 4X4
+    {0x000f000f000f000fULL, 0x000f000f000f000fULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X32, transform size 4X4
+    {0x00ff00ff00ff00ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X16, transform size 4X4
+    {0x00ff00ff00ff00ffULL, 0x00ff00ff00ff00ffULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X32, transform size 4X4
+    {0x00ff00ff00ff00ffULL, 0x00ff00ff00ff00ffULL, 0x00ff00ff00ff00ffULL,
+     0x00ff00ff00ff00ffULL},  // block size 32X64, transform size 4X4
+    {0xffffffffffffffffULL, 0xffffffffffffffffULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 4X4
+    {0xffffffffffffffffULL, 0xffffffffffffffffULL, 0xffffffffffffffffULL,
+     0xffffffffffffffffULL},  // block size 64X64, transform size 4X4
+    {0x0001000100010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X16, transform size 4X4
+    {0x000000000000000fULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X4, transform size 4X4
+    {0x0003000300030003ULL, 0x0003000300030003ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X32, transform size 4X4
+    {0x0000000000ff00ffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X8, transform size 4X4
+    {0x000f000f000f000fULL, 0x000f000f000f000fULL, 0x000f000f000f000fULL,
+     0x000f000f000f000fULL},  // block size 16X64, transform size 4X4
+    {0xffffffffffffffffULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 4X4
+    // transform size 8X8
+    {0x0000000000010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X8, transform size 8X8
+    {0x0001000100010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X16, transform size 8X8
+    {0x0000000000050005ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X8, transform size 8X8
+    {0x0005000500050005ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X16, transform size 8X8
+    {0x0005000500050005ULL, 0x0005000500050005ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X32, transform size 8X8
+    {0x0055005500550055ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X16, transform size 8X8
+    {0x0055005500550055ULL, 0x0055005500550055ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X32, transform size 8X8
+    {0x0055005500550055ULL, 0x0055005500550055ULL, 0x0055005500550055ULL,
+     0x0055005500550055ULL},  // block size 32X64, transform size 8X8
+    {0x5555555555555555ULL, 0x5555555555555555ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 8X8
+    {0x5555555555555555ULL, 0x5555555555555555ULL, 0x5555555555555555ULL,
+     0x5555555555555555ULL},  // block size 64X64, transform size 8X8
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X32, transform size 8X8
+    {0x0000000000550055ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X8, transform size 8X8
+    {0x0005000500050005ULL, 0x0005000500050005ULL, 0x0005000500050005ULL,
+     0x0005000500050005ULL},  // block size 16X64, transform size 8X8
+    {0x5555555555555555ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 8X8
+    // transform size 16X16
+    {0x0001000100010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X16, transform size 16X16
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X32, transform size 16X16
+    {0x0011001100110011ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X16, transform size 16X16
+    {0x0011001100110011ULL, 0x0011001100110011ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X32, transform size 16X16
+    {0x0011001100110011ULL, 0x0011001100110011ULL, 0x0011001100110011ULL,
+     0x0011001100110011ULL},  // block size 32X64, transform size 16X16
+    {0x1111111111111111ULL, 0x1111111111111111ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 16X16
+    {0x1111111111111111ULL, 0x1111111111111111ULL, 0x1111111111111111ULL,
+     0x1111111111111111ULL},  // block size 64X64, transform size 16X16
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0001000100010001ULL,
+     0x0001000100010001ULL},  // block size 16X64, transform size 16X16
+    {0x1111111111111111ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 16X16
+    // transform size 32X32
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X32, transform size 32X32
+    {0x0101010101010101ULL, 0x0101010101010101ULL, 0x0101010101010101ULL,
+     0x0101010101010101ULL},  // block size 32X64, transform size 32X32
+    {0x0101010101010101ULL, 0x0101010101010101ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 32X32
+    {0x0101010101010101ULL, 0x0101010101010101ULL, 0x0101010101010101ULL,
+     0x0101010101010101ULL},  // block size 64X64, transform size 32X32
+    // transform size 64X64
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0001000100010001ULL,
+     0x0001000100010001ULL},  // block size 64X64, transform size 64X64
+    // 2:1, 1:2 transform sizes.
+    {0x0000000000010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X8, transform size 4X8
+    {0x0001000100010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X16, transform size 4X8
+    {0x0000000000000001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X4, transform size 8X4
+    {0x0000000000000005ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X4, transform size 8X4
+    {0x0001000100010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X16, transform size 8X16
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X32, transform size 8X16
+    {0x0000000000010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X8, transform size 16X8
+    {0x0000000000110011ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X8, transform size 16X8
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X32, transform size 16X32
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0001000100010001ULL,
+     0x0001000100010001ULL},  // block size 16X64, transform size 16X32
+    {0x0001000100010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X16, transform size 32X16
+    {0x0101010101010101ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 32X16
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0001000100010001ULL,
+     0x0001000100010001ULL},  // block size 32X64, transform size 32X64
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X32, transform size 64X32
+    // 4:1, 1:4 transform sizes.
+    {0x0001000100010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 4X16, transform size 4X16
+    {0x0000000000000001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 16X4, transform size 16X4
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 8X32, transform size 8X32
+    {0x0000000000010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 32X8, transform size 32X8
+    {0x0001000100010001ULL, 0x0001000100010001ULL, 0x0001000100010001ULL,
+     0x0001000100010001ULL},  // block size 16X64, transform size 16X64
+    {0x0001000100010001ULL, 0x0000000000000000ULL, 0x0000000000000000ULL,
+     0x0000000000000000ULL},  // block size 64X16, transform size 64X16
 };
 
 }  // namespace libgav1
diff --git a/libgav1/src/utils/constants.h b/libgav1/src/utils/constants.h
index 1f7c23d..5799a91 100644
--- a/libgav1/src/utils/constants.h
+++ b/libgav1/src/utils/constants.h
@@ -2,6 +2,9 @@
 #define LIBGAV1_SRC_UTILS_CONSTANTS_H_
 
 #include <cstdint>
+#include <cstdlib>
+
+#include "src/utils/bit_mask_set.h"
 
 namespace libgav1 {
 
@@ -18,7 +21,6 @@
   kFrameLfCount = 4,
   kMaxLoopFilterValue = 63,
   kNum4x4In64x64 = 256,
-  kNumTransformSizesLoopFilter = 3,  // 0: 4x4, 1: 8x8, 2: 16x16.
   kNumLoopFilterMasks = 4,
   kMaxAngleDelta = 3,
   kDirectionalIntraModes = 8,
@@ -30,9 +32,10 @@
   kRestorationTypeSymbolCount = 3,
   kSgrProjParamsBits = 4,
   kSgrProjPrecisionBits = 7,
-  kRestorationBorder = 3,  // Horizontal and vertical border are both 3.
-  kConvolveBorderLeftTop = 3,
-  kConvolveBorderRightBottom = 4,
+  kRestorationBorder = 3,      // Padding on each side of a restoration block.
+  kCdefBorder = 2,             // Padding on each side of a cdef block.
+  kConvolveBorderLeftTop = 3,  // Left/top padding of a convolve block.
+  kConvolveBorderRightBottom = 4,  // Right/bottom padding of a convolve block.
   kSubPixelTaps = 8,
   kWienerFilterBits = 7,
   kMaxPaletteSize = 8,
@@ -178,6 +181,29 @@
   kBlockInvalid
 };
 
+//  Partition types.  R: Recursive
+//
+//  None          Horizontal    Vertical      Split
+//  +-------+     +-------+     +---+---+     +---+---+
+//  |       |     |       |     |   |   |     | R | R |
+//  |       |     +-------+     |   |   |     +---+---+
+//  |       |     |       |     |   |   |     | R | R |
+//  +-------+     +-------+     +---+---+     +---+---+
+//
+//  Horizontal    Horizontal    Vertical      Vertical
+//  with top      with bottom   with left     with right
+//  split         split         split         split
+//  +---+---+     +-------+     +---+---+     +---+---+
+//  |   |   |     |       |     |   |   |     |   |   |
+//  +---+---+     +---+---+     +---+   |     |   +---+
+//  |       |     |   |   |     |   |   |     |   |   |
+//  +-------+     +---+---+     +---+---+     +---+---+
+//
+//  Horizontal4   Vertical4
+//  +-----+       +-+-+-+
+//  +-----+       | | | |
+//  +-----+       | | | |
+//  +-----+       +-+-+-+
 enum Partition : uint8_t {
   kPartitionNone,
   kPartitionHorizontal,
@@ -276,19 +302,14 @@
   kNumTransformTypes
 };
 
-// Allows checking whether a transform requires rows or columns to be flipped
-// with a single comparison rather than a chain of ||s. This should result in
-// fewer instructions overall.
-enum : uint32_t {
-  kTransformFlipColumnsMask = (1U << kTransformTypeFlipadstDct) |
-                              (1U << kTransformTypeFlipadstAdst) |
-                              (1U << kTransformTypeFlipadstIdentity) |
-                              (1U << kTransformTypeFlipadstFlipadst),
-  kTransformFlipRowsMask = (1U << kTransformTypeDctFlipadst) |
-                           (1U << kTransformTypeAdstFlipadst) |
-                           (1U << kTransformTypeIdentityFlipadst) |
-                           (1U << kTransformTypeFlipadstFlipadst)
-};
+constexpr BitMaskSet kTransformFlipColumnsMask(kTransformTypeFlipadstDct,
+                                               kTransformTypeFlipadstAdst,
+                                               kTransformTypeFlipadstIdentity,
+                                               kTransformTypeFlipadstFlipadst);
+constexpr BitMaskSet kTransformFlipRowsMask(kTransformTypeDctFlipadst,
+                                            kTransformTypeAdstFlipadst,
+                                            kTransformTypeIdentityFlipadst,
+                                            kTransformTypeFlipadstFlipadst);
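+
+// For example (illustrative; BitMaskSet is defined in src/utils/bit_mask_set.h
+// and is assumed to expose a Contains() membership query), whether |tx_type|
+// requires its columns to be flipped can then be checked with a single lookup:
+//   const bool flip_columns = kTransformFlipColumnsMask.Contains(tx_type);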
 
 enum TransformSize : uint8_t {
   kTransformSize4x4,
@@ -313,24 +334,6 @@
   kNumTransformSizes
 };
 
-enum : uint32_t {
-  // Mask of all transform sizes with either dimension equal to 64.
-  kTransformSize64Mask =
-      (1U << kTransformSize64x16) | (1U << kTransformSize64x32) |
-      (1U << kTransformSize64x64) | (1U << kTransformSize16x64) |
-      (1U << kTransformSize32x64),
-  // Mask of all transform sizes with width equal to 16.
-  kTransformWidth16Mask =
-      (1U << kTransformSize16x4) | (1U << kTransformSize16x8) |
-      (1U << kTransformSize16x16) | (1U << kTransformSize16x32) |
-      (1U << kTransformSize16x64),
-  // Mask of all transform sizes with height equal to 16.
-  kTransformHeight16Mask =
-      (1U << kTransformSize4x16) | (1U << kTransformSize8x16) |
-      (1U << kTransformSize16x16) | (1U << kTransformSize32x16) |
-      (1U << kTransformSize64x16)
-};
-
 enum TransformSet : uint8_t {
   // DCT Only (1).
   kTransformSetDctOnly,
@@ -364,6 +367,12 @@
   kNumFilterIntraPredictors
 };
 
+enum ObmcDirection : uint8_t {
+  kObmcDirectionVertical,
+  kObmcDirectionHorizontal,
+  kNumObmcDirections
+};
+
 // In AV1 the name of the filter refers to the direction of filter application.
 // Horizontal refers to the column edge and vertical the row edge.
 enum LoopFilterType : uint8_t {
@@ -372,6 +381,13 @@
   kNumLoopFilterTypes
 };
 
+enum LoopFilterTransformSizeId : uint8_t {
+  kLoopFilterTransformSizeId4x4,
+  kLoopFilterTransformSizeId8x8,
+  kLoopFilterTransformSizeId16x16,
+  kNumLoopFilterTransformSizeIds
+};
+
 enum LoopRestorationType : uint8_t {
   kLoopRestorationTypeNone,
   kLoopRestorationTypeSwitchable,
@@ -433,6 +449,155 @@
   kObuPadding = 15,
 };
 
+//------------------------------------------------------------------------------
+// ToString()
+//
+// These functions are meant to be used only in debug logging and within tests.
+// They are defined inline to avoid including the strings in the release
+// library when logging is disabled; unreferenced functions will not be added to
+// any object file in that case.
+
+inline const char* ToString(const BlockSize size) {
+  switch (size) {
+    case kBlock4x4:
+      return "kBlock4x4";
+    case kBlock4x8:
+      return "kBlock4x8";
+    case kBlock4x16:
+      return "kBlock4x16";
+    case kBlock8x4:
+      return "kBlock8x4";
+    case kBlock8x8:
+      return "kBlock8x8";
+    case kBlock8x16:
+      return "kBlock8x16";
+    case kBlock8x32:
+      return "kBlock8x32";
+    case kBlock16x4:
+      return "kBlock16x4";
+    case kBlock16x8:
+      return "kBlock16x8";
+    case kBlock16x16:
+      return "kBlock16x16";
+    case kBlock16x32:
+      return "kBlock16x32";
+    case kBlock16x64:
+      return "kBlock16x64";
+    case kBlock32x8:
+      return "kBlock32x8";
+    case kBlock32x16:
+      return "kBlock32x16";
+    case kBlock32x32:
+      return "kBlock32x32";
+    case kBlock32x64:
+      return "kBlock32x64";
+    case kBlock64x16:
+      return "kBlock64x16";
+    case kBlock64x32:
+      return "kBlock64x32";
+    case kBlock64x64:
+      return "kBlock64x64";
+    case kBlock64x128:
+      return "kBlock64x128";
+    case kBlock128x64:
+      return "kBlock128x64";
+    case kBlock128x128:
+      return "kBlock128x128";
+    case kMaxBlockSizes:
+      return "kMaxBlockSizes";
+    case kBlockInvalid:
+      return "kBlockInvalid";
+  }
+  abort();
+}
+
+inline const char* ToString(const InterIntraMode mode) {
+  switch (mode) {
+    case kInterIntraModeDc:
+      return "kInterIntraModeDc";
+    case kInterIntraModeVertical:
+      return "kInterIntraModeVertical";
+    case kInterIntraModeHorizontal:
+      return "kInterIntraModeHorizontal";
+    case kInterIntraModeSmooth:
+      return "kInterIntraModeSmooth";
+    case kNumInterIntraModes:
+      return "kNumInterIntraModes";
+  }
+  abort();
+}
+
+inline const char* ToString(const ObmcDirection direction) {
+  switch (direction) {
+    case kObmcDirectionVertical:
+      return "kObmcDirectionVertical";
+    case kObmcDirectionHorizontal:
+      return "kObmcDirectionHorizontal";
+    case kNumObmcDirections:
+      return "kNumObmcDirections";
+  }
+  abort();
+}
+
+inline const char* ToString(const LoopRestorationType type) {
+  switch (type) {
+    case kLoopRestorationTypeNone:
+      return "kLoopRestorationTypeNone";
+    case kLoopRestorationTypeSwitchable:
+      return "kLoopRestorationTypeSwitchable";
+    case kLoopRestorationTypeWiener:
+      return "kLoopRestorationTypeWiener";
+    case kLoopRestorationTypeSgrProj:
+      return "kLoopRestorationTypeSgrProj";
+    case kNumLoopRestorationTypes:
+      return "kNumLoopRestorationTypes";
+  }
+  abort();
+}
+
+inline const char* ToString(const TransformType type) {
+  switch (type) {
+    case kTransformTypeDctDct:
+      return "kTransformTypeDctDct";
+    case kTransformTypeAdstDct:
+      return "kTransformTypeAdstDct";
+    case kTransformTypeDctAdst:
+      return "kTransformTypeDctAdst";
+    case kTransformTypeAdstAdst:
+      return "kTransformTypeAdstAdst";
+    case kTransformTypeFlipadstDct:
+      return "kTransformTypeFlipadstDct";
+    case kTransformTypeDctFlipadst:
+      return "kTransformTypeDctFlipadst";
+    case kTransformTypeFlipadstFlipadst:
+      return "kTransformTypeFlipadstFlipadst";
+    case kTransformTypeAdstFlipadst:
+      return "kTransformTypeAdstFlipadst";
+    case kTransformTypeFlipadstAdst:
+      return "kTransformTypeFlipadstAdst";
+    case kTransformTypeIdentityIdentity:
+      return "kTransformTypeIdentityIdentity";
+    case kTransformTypeIdentityDct:
+      return "kTransformTypeIdentityDct";
+    case kTransformTypeDctIdentity:
+      return "kTransformTypeDctIdentity";
+    case kTransformTypeIdentityAdst:
+      return "kTransformTypeIdentityAdst";
+    case kTransformTypeAdstIdentity:
+      return "kTransformTypeAdstIdentity";
+    case kTransformTypeIdentityFlipadst:
+      return "kTransformTypeIdentityFlipadst";
+    case kTransformTypeFlipadstIdentity:
+      return "kTransformTypeFlipadstIdentity";
+    // case to quiet compiler
+    case kNumTransformTypes:
+      return "kNumTransformTypes";
+  }
+  abort();
+}
+
+//------------------------------------------------------------------------------
+
 extern const uint8_t k4x4WidthLog2[kMaxBlockSizes];
 
 extern const uint8_t k4x4HeightLog2[kMaxBlockSizes];
@@ -453,14 +618,14 @@
 
 extern const uint8_t kTransformHeight[kNumTransformSizes];
 
+extern const uint8_t kTransformWidth4x4[kNumTransformSizes];
+
+extern const uint8_t kTransformHeight4x4[kNumTransformSizes];
+
 extern const uint8_t kTransformWidthLog2[kNumTransformSizes];
 
 extern const uint8_t kTransformHeightLog2[kNumTransformSizes];
 
-extern const TransformSize kMaxTransformSizeRectangle[kMaxBlockSizes];
-
-extern const int kMaxTransformDepth[kMaxBlockSizes];
-
 extern const TransformSize kSplitTransformSize[kNumTransformSizes];
 
 // Square transform of size min(w,h).
@@ -469,26 +634,14 @@
 // Square transform of size max(w,h).
 extern const TransformSize kTransformSizeSquareMax[kNumTransformSizes];
 
-extern const TransformType kModeToTransformType[kIntraPredictionModesUV];
-
 extern const uint8_t kNumTransformTypesInSet[kNumTransformSets];
 
-extern const TransformType kInverseTransformTypeBySet[kNumTransformSets - 1]
-                                                     [16];
-
-// Replaces all occurrences of 64x* and *x64 with 32x* and *x32 respectively.
-extern const TransformSize kAdjustedTransformSize[kNumTransformSizes];
-
 extern const uint8_t kSgrProjParams[1 << kSgrProjParamsBits][4];
 
 extern const int8_t kSgrProjMultiplierMin[2];
 
 extern const int8_t kSgrProjMultiplierMax[2];
 
-extern const int8_t kSgrProjDefaultMultiplier[2];
-
-extern const int8_t kWienerDefaultFilter[3];
-
 extern const int8_t kWienerTapsMin[3];
 
 extern const int8_t kWienerTapsMax[3];
@@ -501,7 +654,15 @@
 
 extern const int16_t kDirectionalIntraPredictorDerivative[44];
 
-extern const uint8_t kPredictionModeDeltasLookup[kNumPredictionModes];
+extern const uint8_t kDeblockFilterLevelIndex[kMaxPlanes][kNumLoopFilterTypes];
+
+extern const int8_t kMaskIdLookup[4][kMaxBlockSizes];
+
+extern const int8_t kVerticalBorderMaskIdLookup[kMaxBlockSizes];
+
+extern const uint64_t kTopMaskLookup[67][4];
+
+extern const uint64_t kLeftMaskLookup[67][4];
 
 }  // namespace libgav1
 
diff --git a/libgav1/src/utils/entropy_decoder.cc b/libgav1/src/utils/entropy_decoder.cc
index a925da4..ff6db39 100644
--- a/libgav1/src/utils/entropy_decoder.cc
+++ b/libgav1/src/utils/entropy_decoder.cc
@@ -7,98 +7,37 @@
 
 namespace {
 
-constexpr uint32_t kWindowSize = static_cast<uint32_t>(sizeof(uint32_t)) * 8;
-constexpr int kCdfPrecision = 6;
-constexpr int kMinimumProbabilityPerSymbol = 4;
 constexpr uint32_t kReadBitMask = ~255;
 // This constant is used to set the value of |bits_| so that bits can be read
 // after end of stream without trying to refill the buffer for a reasonably long
 // time.
 constexpr int kLargeBitCount = 0x4000;
+constexpr int kCdfPrecision = 6;
+constexpr int kMinimumProbabilityPerSymbol = 4;
 
-void UpdateCdf(uint16_t* cdf, int value, int symbol_count) {
-  const uint16_t count = cdf[symbol_count];
-  // rate is computed in the spec as:
-  //  3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
-  // In this case cdf[N] is |count|.
-  // Min(FloorLog2(N), 2) is 1 for symbol_count == {2, 3} and 2 for all
-  // symbol_count > 3. So the equation becomes:
-  //  4 + (count > 15) + (count > 31) + (symbol_count > 3).
-  // Note that the largest value for count is 32 (it is not incremented beyond
-  // 32). So using that information:
-  //  count >> 4 is 0 for count from 0 to 15.
-  //  count >> 4 is 1 for count from 16 to 31.
-  //  count >> 4 is 2 for count == 31.
-  // Now, the equation becomes:
-  //  4 + (count >> 4) + (symbol_count > 3).
-  // Since (count >> 4) can only be 0 or 1 or 2, the addition can be replaced
-  // with bitwise or. So the final equation is:
-  // (4 | (count >> 4)) + (symbol_count > 3).
-  const int rate = (4 | (count >> 4)) + static_cast<int>(symbol_count > 3);
-  for (int i = 0; i < symbol_count - 1; ++i) {
-    if (i < value) {
-      cdf[i] += (libgav1::kCdfMaxProbability - cdf[i]) >> rate;
-    } else {
-      cdf[i] -= cdf[i] >> rate;
-    }
-  }
-  cdf[symbol_count] += static_cast<uint16_t>(count < 32);
+// This function computes the "cur" variable as specified inside the do-while
+// loop in Section 8.2.6 of the spec. Its return value is monotonically
+// decreasing as |index| increases (note that the |cdf| array is sorted in
+// decreasing order).
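+// For example (illustrative values), with values_in_range_shifted == 160,
+// cdf[index] == 16384, and symbol_count - index == 3, the result is
+// ((160 * (16384 >> 6)) >> 1) + 4 * 3 == 20492.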
+uint32_t ScaleCdf(uint16_t values_in_range_shifted, const uint16_t* const cdf,
+                  int index, int symbol_count) {
+  return ((values_in_range_shifted * (cdf[index] >> kCdfPrecision)) >> 1) +
+         (kMinimumProbabilityPerSymbol * (symbol_count - index));
 }
 
 }  // namespace
 
 namespace libgav1 {
 
-#if defined(LIBGAV1_USE_LIBAOM_BIT_READER)
-const int kProbabilityHalf = 128;
+constexpr uint32_t DaalaBitReader::kWindowSize;  // static.
 
-DaalaBitReaderAom::DaalaBitReaderAom(const uint8_t* data, size_t size,
-                                     bool allow_update_cdf)
-    : allow_update_cdf_(allow_update_cdf) {
-  aom_daala_reader_init(&reader_, data, static_cast<int>(size));
-}
-
-int DaalaBitReaderAom::ReadBit() {
-  return aom_daala_read(&reader_, kProbabilityHalf);
-}
-
-int64_t DaalaBitReaderAom::ReadLiteral(int num_bits) {
-  if (num_bits > 32) return -1;
-  uint32_t literal = 0;
-  for (int bit = num_bits - 1; bit >= 0; bit--) {
-    literal |= static_cast<uint32_t>(ReadBit()) << bit;
-  }
-  return literal;
-}
-
-int DaalaBitReaderAom::ReadSymbol(uint16_t* cdf, int symbol_count) {
-  const int symbol = daala_read_symbol(&reader_, cdf, symbol_count);
-  if (allow_update_cdf_) {
-    UpdateCdf(cdf, symbol, symbol_count);
-  }
-  return symbol;
-}
-
-bool DaalaBitReaderAom::ReadSymbol(uint16_t* cdf) {
-  const int symbol = daala_read_symbol(&reader_, cdf, kBooleanSymbolCount);
-  if (allow_update_cdf_) {
-    UpdateCdf(cdf, symbol, kBooleanSymbolCount);
-  }
-  return symbol != 0;
-}
-
-bool DaalaBitReaderAom::ReadSymbolWithoutCdfUpdate(uint16_t* cdf) {
-  return daala_read_symbol(&reader_, cdf, kBooleanSymbolCount) != 0;
-}
-#endif  // defined(LIBGAV1_USE_LIBAOM_BIT_READER)
-
-DaalaBitReaderNative::DaalaBitReaderNative(const uint8_t* data, size_t size,
-                                           bool allow_update_cdf)
+DaalaBitReader::DaalaBitReader(const uint8_t* data, size_t size,
+                               bool allow_update_cdf)
     : data_(data),
       size_(size),
       data_index_(0),
       allow_update_cdf_(allow_update_cdf) {
-  window_diff_ = (uint32_t{1} << (kWindowSize - 1)) - 1;
+  window_diff_ = (WindowSize{1} << (kWindowSize - 1)) - 1;
   values_in_range_ = kCdfMaxProbability;
   bits_ = -15;
   PopulateBits();
@@ -109,10 +48,11 @@
 //   * The probability is fixed at half. So some multiplications can be replaced
 //     with bit operations.
 //   * Symbol count is fixed at 2.
-int DaalaBitReaderNative::ReadBit() {
+int DaalaBitReader::ReadBit() {
   const uint32_t curr =
       ((values_in_range_ & kReadBitMask) >> 1) + kMinimumProbabilityPerSymbol;
-  const uint32_t zero_threshold = curr << (kWindowSize - 16);
+  const WindowSize zero_threshold = static_cast<WindowSize>(curr)
+                                    << (kWindowSize - 16);
   int bit = 1;
   if (window_diff_ >= zero_threshold) {
     values_in_range_ -= curr;
@@ -125,7 +65,7 @@
   return bit;
 }
 
-int64_t DaalaBitReaderNative::ReadLiteral(int num_bits) {
+int64_t DaalaBitReader::ReadLiteral(int num_bits) {
   if (num_bits > 32) return -1;
   uint32_t literal = 0;
   for (int bit = num_bits - 1; bit >= 0; --bit) {
@@ -134,18 +74,16 @@
   return literal;
 }
 
-int DaalaBitReaderNative::ReadSymbol(uint16_t* const cdf, int symbol_count) {
+int DaalaBitReader::ReadSymbol(uint16_t* const cdf, int symbol_count) {
   const int symbol = ReadSymbolImpl(cdf, symbol_count);
   if (allow_update_cdf_) {
-    // TODO(vigneshv): This call can be replaced with the function contents
-    // inline once the DaalaBitReaderAom is removed.
-    UpdateCdf(cdf, symbol, symbol_count);
+    UpdateCdf(cdf, symbol_count, symbol);
   }
   return symbol;
 }
 
-bool DaalaBitReaderNative::ReadSymbol(uint16_t* cdf) {
-  const bool symbol = ReadSymbolImpl(cdf, kBooleanSymbolCount) != 0;
+bool DaalaBitReader::ReadSymbol(uint16_t* cdf) {
+  const bool symbol = ReadSymbolImpl(cdf) != 0;
   if (allow_update_cdf_) {
     const uint16_t count = cdf[2];
     // rate is computed in the spec as:
@@ -173,35 +111,98 @@
   return symbol;
 }
 
-bool DaalaBitReaderNative::ReadSymbolWithoutCdfUpdate(uint16_t* cdf) {
-  return ReadSymbolImpl(cdf, kBooleanSymbolCount) != 0;
+bool DaalaBitReader::ReadSymbolWithoutCdfUpdate(uint16_t* cdf) {
+  return ReadSymbolImpl(cdf) != 0;
 }
 
-int DaalaBitReaderNative::ReadSymbolImpl(const uint16_t* const cdf,
-                                         int symbol_count) {
+int DaalaBitReader::ReadSymbolImpl(const uint16_t* const cdf,
+                                   int symbol_count) {
   assert(cdf[symbol_count - 1] == 0);
   --symbol_count;
   uint32_t curr = values_in_range_;
   int symbol = -1;
   uint32_t prev;
-  uint32_t symbol_value = window_diff_ >> (kWindowSize - 16);
+  const auto symbol_value =
+      static_cast<uint32_t>(window_diff_ >> (kWindowSize - 16));
+  uint32_t delta = kMinimumProbabilityPerSymbol * symbol_count;
+  // Search through the |cdf| array to determine where the scaled cdf value and
+  // |symbol_value| cross over.
   do {
     prev = curr;
-    curr = values_in_range_ >> 8;
-    curr *= cdf[++symbol] >> kCdfPrecision;
-    curr >>= 1;
-    curr += kMinimumProbabilityPerSymbol * (symbol_count - symbol);
+    curr = (((values_in_range_ >> 8) * (cdf[++symbol] >> kCdfPrecision)) >> 1) +
+           delta;
+    delta -= kMinimumProbabilityPerSymbol;
   } while (symbol_value < curr);
   values_in_range_ = prev - curr;
-  window_diff_ -= curr << (kWindowSize - 16);
+  window_diff_ -= static_cast<WindowSize>(curr) << (kWindowSize - 16);
   NormalizeRange();
   return symbol;
 }
 
-void DaalaBitReaderNative::PopulateBits() {
+int DaalaBitReader::ReadSymbolImplBinarySearch(const uint16_t* const cdf,
+                                               int symbol_count) {
+  assert(cdf[symbol_count - 1] == 0);
+  assert(symbol_count > 1 && symbol_count <= 16);
+  --symbol_count;
+  const auto symbol_value =
+      static_cast<uint32_t>(window_diff_ >> (kWindowSize - 16));
+  // Search through the |cdf| array to determine where the scaled cdf value and
+  // |symbol_value| cross over. Since the CDFs are sorted, we can use binary
+  // search to do this. Let |symbol| be the index of the first |cdf| array
+  // entry whose scaled cdf value is less than or equal to |symbol_value|. The
+  // binary search maintains the invariant:
+  //   low <= symbol <= high + 1
+  // and terminates when low == high + 1.
+  int low = 0;
+  int high = symbol_count - 1;
+  // The binary search maintains the invariants that |prev| is the scaled cdf
+  // value for low - 1 and |curr| is the scaled cdf value for high + 1. (By
+  // convention, the scaled cdf value for -1 is values_in_range_.) When the
+  // binary search terminates, |prev| is the scaled cdf value for symbol - 1
+  // and |curr| is the scaled cdf value for |symbol|.
+  uint32_t prev = values_in_range_;
+  uint32_t curr = 0;
+  const uint16_t values_in_range_shifted = values_in_range_ >> 8;
+  do {
+    const int mid = DivideBy2(low + high);
+    const uint32_t scaled_cdf =
+        ScaleCdf(values_in_range_shifted, cdf, mid, symbol_count);
+    if (symbol_value < scaled_cdf) {
+      low = mid + 1;
+      prev = scaled_cdf;
+    } else {
+      high = mid - 1;
+      curr = scaled_cdf;
+    }
+  } while (low <= high);
+  assert(low == high + 1);
+  // At this point, |low| is the symbol that has been decoded.
+  values_in_range_ = prev - curr;
+  window_diff_ -= static_cast<WindowSize>(curr) << (kWindowSize - 16);
+  NormalizeRange();
+  return low;
+}
+
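+// Specialization of ReadSymbolImpl for the boolean case (symbol_count == 2):
+// only cdf[0] needs to be scaled and compared against |window_diff_|, so the
+// search loop of the general implementation reduces to a single comparison.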
+int DaalaBitReader::ReadSymbolImpl(const uint16_t* const cdf) {
+  assert(cdf[1] == 0);
+  const auto symbol_value =
+      static_cast<uint32_t>(window_diff_ >> (kWindowSize - 16));
+  const uint32_t curr = ScaleCdf(values_in_range_ >> 8, cdf, 0, 1);
+  const int symbol = static_cast<int>(symbol_value < curr);
+  if (symbol == 1) {
+    values_in_range_ = curr;
+  } else {
+    values_in_range_ -= curr;
+    window_diff_ -= static_cast<WindowSize>(curr) << (kWindowSize - 16);
+  }
+  NormalizeRange();
+  return symbol;
+}
+
+void DaalaBitReader::PopulateBits() {
   int shift = kWindowSize - 9 - (bits_ + 15);
   for (; shift >= 0 && data_index_ < size_; shift -= 8) {
-    window_diff_ ^= static_cast<uint32_t>(data_[data_index_++]) << shift;
+    window_diff_ ^= static_cast<WindowSize>(data_[data_index_++]) << shift;
     bits_ += 8;
   }
   if (data_index_ >= size_) {
@@ -209,7 +210,7 @@
   }
 }
 
-void DaalaBitReaderNative::NormalizeRange() {
+void DaalaBitReader::NormalizeRange() {
   const int bits_used = 15 - FloorLog2(values_in_range_);
   bits_ -= bits_used;
   window_diff_ = ((window_diff_ + 1) << bits_used) - 1;
@@ -217,4 +218,55 @@
   if (bits_ < 0) PopulateBits();
 }
 
+void DaalaBitReader::UpdateCdf(uint16_t* const cdf, int symbol_count,
+                               int symbol) {
+  const uint16_t count = cdf[symbol_count];
+  // rate is computed in the spec as:
+  //  3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
+  // In this case cdf[N] is |count|.
+  // Min(FloorLog2(N), 2) is 1 for symbol_count == {2, 3} and 2 for all
+  // symbol_count > 3. So the equation becomes:
+  //  4 + (count > 15) + (count > 31) + (symbol_count > 3).
+  // Note that the largest value for count is 32 (it is not incremented beyond
+  // 32). So using that information:
+  //  count >> 4 is 0 for count from 0 to 15.
+  //  count >> 4 is 1 for count from 16 to 31.
+  //  count >> 4 is 2 for count == 32.
+  // Now, the equation becomes:
+  //  4 + (count >> 4) + (symbol_count > 3).
+  // Since (count >> 4) can only be 0 or 1 or 2, the addition can be replaced
+  // with bitwise or. So the final equation is:
+  // (4 | (count >> 4)) + (symbol_count > 3).
+  const int rate = (4 | (count >> 4)) + static_cast<int>(symbol_count > 3);
+  // Hints for further optimizations:
+  //
+  // 1. clang can vectorize this for loop with width 4, even though the loop
+  // contains an if-else statement. Therefore, it may be advantageous to use
+  // "i < symbol_count" as the loop condition when symbol_count is 8, 12, or 16
+  // (a multiple of 4 that's not too small).
+  //
+  // 2. The for loop can be rewritten in the following form, which would enable
+  // clang to vectorize the loop with width 8:
+  //
+  //   const int mask = (1 << rate) - 1;
+  //   for (int i = 0; i < symbol_count - 1; ++i) {
+  //     const uint16_t a = (i < symbol) ? kCdfMaxProbability : mask;
+  //     cdf[i] += static_cast<int16_t>(a - cdf[i]) >> rate;
+  //   }
+  //
+  // The subtraction (a - cdf[i]) relies on the overflow semantics of unsigned
+  // integer arithmetic. The result of the unsigned subtraction is cast to a
+  // signed integer and right-shifted. This requires the right shift of a
+  // signed integer be an arithmetic shift, which is true for clang, gcc, and
+  // Visual C++.
+  for (int i = 0; i < symbol_count - 1; ++i) {
+    if (i < symbol) {
+      cdf[i] += (libgav1::kCdfMaxProbability - cdf[i]) >> rate;
+    } else {
+      cdf[i] -= cdf[i] >> rate;
+    }
+  }
+  cdf[symbol_count] += static_cast<uint16_t>(count < 32);
+}
+
 }  // namespace libgav1
diff --git a/libgav1/src/utils/entropy_decoder.h b/libgav1/src/utils/entropy_decoder.h
index f6292bc..b891bae 100644
--- a/libgav1/src/utils/entropy_decoder.h
+++ b/libgav1/src/utils/entropy_decoder.h
@@ -4,53 +4,69 @@
 #include <cstddef>
 #include <cstdint>
 
-#if defined(LIBGAV1_USE_LIBAOM_BIT_READER)
-#include "third_party/libaom/git_root/aom_dsp/daalaboolreader.h"
-#endif
-
 #include "src/utils/bit_reader.h"
 
 namespace libgav1 {
 
-#if defined(LIBGAV1_USE_LIBAOM_BIT_READER)
-class DaalaBitReaderAom : public BitReader {
+class DaalaBitReader : public BitReader {
  public:
-  DaalaBitReaderAom(const uint8_t* data, size_t size, bool allow_update_cdf);
-  ~DaalaBitReaderAom() override = default;
-
-  int ReadBit() override;
-  int64_t ReadLiteral(int num_bits) override;
-  int ReadSymbol(uint16_t* cdf, int symbol_count);
-  bool ReadSymbol(uint16_t* cdf);
-  bool ReadSymbolWithoutCdfUpdate(uint16_t* cdf);
-
- private:
-  bool allow_update_cdf_;
-  daala_reader reader_;
-};
-#endif  // defined(LIBGAV1_USE_LIBAOM_BIT_READER)
-
-class DaalaBitReaderNative : public BitReader {
- public:
-  DaalaBitReaderNative(const uint8_t* data, size_t size, bool allow_update_cdf);
-  ~DaalaBitReaderNative() override = default;
+  DaalaBitReader(const uint8_t* data, size_t size, bool allow_update_cdf);
+  ~DaalaBitReader() override = default;
 
   // Move only.
-  DaalaBitReaderNative(DaalaBitReaderNative&& rhs) noexcept;
-  DaalaBitReaderNative& operator=(DaalaBitReaderNative&& rhs) noexcept;
+  DaalaBitReader(DaalaBitReader&& rhs) noexcept;
+  DaalaBitReader& operator=(DaalaBitReader&& rhs) noexcept;
 
   int ReadBit() override;
   int64_t ReadLiteral(int num_bits) override;
+  // ReadSymbol() calls for which the |symbol_count| is only known at runtime
+  // will use this variant.
   int ReadSymbol(uint16_t* cdf, int symbol_count);
+  // ReadSymbol() calls for which the |symbol_count| is equal to 2 (boolean
+  // symbols) will use this variant.
   bool ReadSymbol(uint16_t* cdf);
   bool ReadSymbolWithoutCdfUpdate(uint16_t* cdf);
+  // ReadSymbol() calls for which |symbol_count| is known at compile time will
+  // use this variant. It uses either linear search or binary search to decode
+  // the symbol, depending on |symbol_count|.
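+  // An illustrative (hypothetical) call site with four symbols: |cdf| holds
+  // three decreasing probability entries, a terminating zero, and the update
+  // counter, e.g.
+  //   uint16_t cdf[5] = {24576, 16384, 8192, 0, 0};
+  //   const int symbol = reader.ReadSymbol<4>(cdf);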
+  template <int symbol_count>
+  int ReadSymbol(uint16_t* const cdf) {
+    static_assert(symbol_count >= 3 && symbol_count <= 16, "");
+    const int symbol = (symbol_count <= 13)
+                           ? ReadSymbolImpl(cdf, symbol_count)
+                           : ReadSymbolImplBinarySearch(cdf, symbol_count);
+    if (allow_update_cdf_) {
+      UpdateCdf(cdf, symbol_count, symbol);
+    }
+    return symbol;
+  }
 
  private:
+  using WindowSize = uint32_t;
+  static constexpr uint32_t kWindowSize =
+      static_cast<uint32_t>(sizeof(WindowSize)) * 8;
+
+  // Reads a symbol using the |cdf| table, which contains the probabilities of
+  // each symbol. At a high level, this function does the following:
+  //   1) Scale the |cdf| values.
+  //   2) Find the index in the |cdf| array where the scaled CDF value crosses
+  //   the modified |window_diff_| threshold.
+  //   3) That index is the symbol that has been decoded.
+  //   4) Update |window_diff_| and |values_in_range_| based on the symbol that
+  //   has been decoded.
   int ReadSymbolImpl(const uint16_t* cdf, int symbol_count);
+  // Similar to ReadSymbolImpl but uses binary search to perform step 2 in the
+  // comment above. As of now, this function is only called by the templated
+  // ReadSymbol() when |symbol_count| is greater than 13.
+  int ReadSymbolImplBinarySearch(const uint16_t* cdf, int symbol_count);
+  // Specialized implementation of ReadSymbolImpl based on the fact that
+  // symbol_count == 2.
+  int ReadSymbolImpl(const uint16_t* cdf);
   void PopulateBits();
   // Normalizes the range so that 32768 <= |values_in_range_| < 65536. Also
   // calls PopulateBits() if necessary.
   void NormalizeRange();
+  void UpdateCdf(uint16_t* cdf, int symbol_count, int symbol);
 
   const uint8_t* data_;
   const size_t size_;
@@ -63,15 +79,9 @@
   // The difference between the high end of the current range and the coded
   // value minus 1. The 16 least significant bits of this variable is used to
   // decode the next symbol. It is filled in whenever |bits_| is less than 0.
-  uint32_t window_diff_;
+  WindowSize window_diff_;
 };
 
-#if defined(LIBGAV1_USE_LIBAOM_BIT_READER)
-using DaalaBitReader = DaalaBitReaderAom;
-#else
-using DaalaBitReader = DaalaBitReaderNative;
-#endif
-
 }  // namespace libgav1
 
 #endif  // LIBGAV1_SRC_UTILS_ENTROPY_DECODER_H_
diff --git a/libgav1/src/utils/memory.h b/libgav1/src/utils/memory.h
index 1b73dbb..c5eeb26 100644
--- a/libgav1/src/utils/memory.h
+++ b/libgav1/src/utils/memory.h
@@ -15,6 +15,16 @@
 
 namespace libgav1 {
 
+enum {
+// The byte alignment required for buffers used with SIMD code to be read or
+// written with aligned operations.
+#if defined(__i386__) || defined(_M_IX86)
+  kMaxAlignment = 16,  // extended alignment is safe on x86.
+#else
+  kMaxAlignment = alignof(max_align_t),
+#endif
+};
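+
+// For instance, a buffer declared as
+//   alignas(kMaxAlignment) uint8_t buffer[64];
+// can be read or written with 16-byte aligned SIMD loads and stores on targets
+// where kMaxAlignment is at least 16 (e.g. x86 with SSE2).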
+
 // AlignedAlloc, AlignedFree
 //
 // void* AlignedAlloc(size_t alignment, size_t size);
@@ -113,28 +123,54 @@
   // Class-specific non-throwing allocation functions
   static void* operator new(size_t size, const std::nothrow_t& tag) noexcept {
     if (size > 0x40000000) return nullptr;
+#ifdef __cpp_aligned_new
+    return ::operator new(size, std::align_val_t(kMaxAlignment), tag);
+#else
     return ::operator new(size, tag);
+#endif
   }
   static void* operator new[](size_t size, const std::nothrow_t& tag) noexcept {
     if (size > 0x40000000) return nullptr;
+#ifdef __cpp_aligned_new
+    return ::operator new[](size, std::align_val_t(kMaxAlignment), tag);
+#else
     return ::operator new[](size, tag);
+#endif
   }
 
   // Class-specific deallocation functions.
-  static void operator delete(void* ptr) noexcept { ::operator delete(ptr); }
+  static void operator delete(void* ptr) noexcept {
+#ifdef __cpp_aligned_new
+    ::operator delete(ptr, std::align_val_t(kMaxAlignment));
+#else
+    ::operator delete(ptr);
+#endif
+  }
   static void operator delete[](void* ptr) noexcept {
+#ifdef __cpp_aligned_new
+    ::operator delete[](ptr, std::align_val_t(kMaxAlignment));
+#else
     ::operator delete[](ptr);
+#endif
   }
 
   // Only called if new (std::nothrow) is used and the constructor throws an
   // exception.
   static void operator delete(void* ptr, const std::nothrow_t& tag) noexcept {
+#ifdef __cpp_aligned_new
+    ::operator delete(ptr, std::align_val_t(kMaxAlignment), tag);
+#else
     ::operator delete(ptr, tag);
+#endif
   }
   // Only called if new[] (std::nothrow) is used and the constructor throws an
   // exception.
   static void operator delete[](void* ptr, const std::nothrow_t& tag) noexcept {
+#ifdef __cpp_aligned_new
+    ::operator delete[](ptr, std::align_val_t(kMaxAlignment), tag);
+#else
     ::operator delete[](ptr, tag);
+#endif
   }
 };
 
diff --git a/libgav1/src/utils/mutex.h b/libgav1/src/utils/mutex.h
deleted file mode 100644
index bd84c1e..0000000
--- a/libgav1/src/utils/mutex.h
+++ /dev/null
@@ -1,4 +0,0 @@
-#ifndef LIBGAV1_SRC_UTILS_MUTEX_H_
-#define LIBGAV1_SRC_UTILS_MUTEX_H_
-
-#endif  // LIBGAV1_SRC_UTILS_MUTEX_H_
diff --git a/libgav1/src/utils/parameter_tree.cc b/libgav1/src/utils/parameter_tree.cc
index a6983e6..cd4b24b 100644
--- a/libgav1/src/utils/parameter_tree.cc
+++ b/libgav1/src/utils/parameter_tree.cc
@@ -1,6 +1,7 @@
 #include "src/utils/parameter_tree.h"
 
 #include <cassert>
+#include <memory>
 #include <new>
 
 #include "src/utils/common.h"
@@ -10,7 +11,19 @@
 
 namespace libgav1 {
 
-void ParameterTree::SetPartitionType(Partition partition) {
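+// Allocations inside the tree (the node itself and its BlockParameters) can
+// fail, so construction goes through this factory, which returns nullptr on
+// failure instead of using a constructor that cannot report errors.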
+// static
+std::unique_ptr<ParameterTree> ParameterTree::Create(int row4x4, int column4x4,
+                                                     BlockSize block_size,
+                                                     bool is_leaf) {
+  std::unique_ptr<ParameterTree> tree(
+      new (std::nothrow) ParameterTree(row4x4, column4x4, block_size));
+  if (tree != nullptr && is_leaf && !tree->SetPartitionType(kPartitionNone)) {
+    tree = nullptr;
+  }
+  return tree;
+}
+
+bool ParameterTree::SetPartitionType(Partition partition) {
   assert(!partition_type_set_);
   partition_ = partition;
   partition_type_set_ = true;
@@ -23,189 +36,84 @@
   switch (partition) {
     case kPartitionNone:
       parameters_.reset(new (std::nothrow) BlockParameters());
-      return;
+      return parameters_ != nullptr;
     case kPartitionHorizontal:
-      children_[0].reset(new (std::nothrow) ParameterTree(row4x4_, column4x4_,
-                                                          sub_size, true));
-      children_[1].reset(new (std::nothrow) ParameterTree(
-          row4x4_ + half_block4x4, column4x4_, sub_size, true));
-      return;
+      children_[0] = ParameterTree::Create(row4x4_, column4x4_, sub_size, true);
+      children_[1] = ParameterTree::Create(row4x4_ + half_block4x4, column4x4_,
+                                           sub_size, true);
+      return children_[0] != nullptr && children_[1] != nullptr;
     case kPartitionVertical:
-      children_[0].reset(new (std::nothrow) ParameterTree(row4x4_, column4x4_,
-                                                          sub_size, true));
-      children_[1].reset(new (std::nothrow) ParameterTree(
-          row4x4_, column4x4_ + half_block4x4, sub_size, true));
-      return;
+      children_[0] = ParameterTree::Create(row4x4_, column4x4_, sub_size, true);
+      children_[1] = ParameterTree::Create(row4x4_, column4x4_ + half_block4x4,
+                                           sub_size, true);
+      return children_[0] != nullptr && children_[1] != nullptr;
     case kPartitionSplit:
-      children_[0].reset(new (std::nothrow) ParameterTree(row4x4_, column4x4_,
-                                                          sub_size, false));
-      children_[1].reset(new (std::nothrow) ParameterTree(
-          row4x4_, column4x4_ + half_block4x4, sub_size, false));
-      children_[2].reset(new (std::nothrow) ParameterTree(
-          row4x4_ + half_block4x4, column4x4_, sub_size, false));
-      children_[3].reset(new (std::nothrow) ParameterTree(
-          row4x4_ + half_block4x4, column4x4_ + half_block4x4, sub_size,
-          false));
-      return;
+      children_[0] =
+          ParameterTree::Create(row4x4_, column4x4_, sub_size, false);
+      children_[1] = ParameterTree::Create(row4x4_, column4x4_ + half_block4x4,
+                                           sub_size, false);
+      children_[2] = ParameterTree::Create(row4x4_ + half_block4x4, column4x4_,
+                                           sub_size, false);
+      children_[3] = ParameterTree::Create(
+          row4x4_ + half_block4x4, column4x4_ + half_block4x4, sub_size, false);
+      return children_[0] != nullptr && children_[1] != nullptr &&
+             children_[2] != nullptr && children_[3] != nullptr;
     case kPartitionHorizontalWithTopSplit:
       assert(split_size != kBlockInvalid);
-      children_[0].reset(new (std::nothrow) ParameterTree(row4x4_, column4x4_,
-                                                          split_size, true));
-      children_[1].reset(new (std::nothrow) ParameterTree(
-          row4x4_, column4x4_ + half_block4x4, split_size, true));
-      children_[2].reset(new (std::nothrow) ParameterTree(
-          row4x4_ + half_block4x4, column4x4_, sub_size, true));
-      return;
+      children_[0] =
+          ParameterTree::Create(row4x4_, column4x4_, split_size, true);
+      children_[1] = ParameterTree::Create(row4x4_, column4x4_ + half_block4x4,
+                                           split_size, true);
+      children_[2] = ParameterTree::Create(row4x4_ + half_block4x4, column4x4_,
+                                           sub_size, true);
+      return children_[0] != nullptr && children_[1] != nullptr &&
+             children_[2] != nullptr;
     case kPartitionHorizontalWithBottomSplit:
       assert(split_size != kBlockInvalid);
-      children_[0].reset(new (std::nothrow) ParameterTree(row4x4_, column4x4_,
-                                                          sub_size, true));
-      children_[1].reset(new (std::nothrow) ParameterTree(
-          row4x4_ + half_block4x4, column4x4_, split_size, true));
-      children_[2].reset(new (std::nothrow) ParameterTree(
-          row4x4_ + half_block4x4, column4x4_ + half_block4x4, split_size,
-          true));
-      return;
+      children_[0] = ParameterTree::Create(row4x4_, column4x4_, sub_size, true);
+      children_[1] = ParameterTree::Create(row4x4_ + half_block4x4, column4x4_,
+                                           split_size, true);
+      children_[2] =
+          ParameterTree::Create(row4x4_ + half_block4x4,
+                                column4x4_ + half_block4x4, split_size, true);
+      return children_[0] != nullptr && children_[1] != nullptr &&
+             children_[2] != nullptr;
     case kPartitionVerticalWithLeftSplit:
       assert(split_size != kBlockInvalid);
-      children_[0].reset(new (std::nothrow) ParameterTree(row4x4_, column4x4_,
-                                                          split_size, true));
-      children_[1].reset(new (std::nothrow) ParameterTree(
-          row4x4_ + half_block4x4, column4x4_, split_size, true));
-      children_[2].reset(new (std::nothrow) ParameterTree(
-          row4x4_, column4x4_ + half_block4x4, sub_size, true));
-      return;
+      children_[0] =
+          ParameterTree::Create(row4x4_, column4x4_, split_size, true);
+      children_[1] = ParameterTree::Create(row4x4_ + half_block4x4, column4x4_,
+                                           split_size, true);
+      children_[2] = ParameterTree::Create(row4x4_, column4x4_ + half_block4x4,
+                                           sub_size, true);
+      return children_[0] != nullptr && children_[1] != nullptr &&
+             children_[2] != nullptr;
     case kPartitionVerticalWithRightSplit:
       assert(split_size != kBlockInvalid);
-      children_[0].reset(new (std::nothrow) ParameterTree(row4x4_, column4x4_,
-                                                          sub_size, true));
-      children_[1].reset(new (std::nothrow) ParameterTree(
-          row4x4_, column4x4_ + half_block4x4, split_size, true));
-      children_[2].reset(new (std::nothrow) ParameterTree(
-          row4x4_ + half_block4x4, column4x4_ + half_block4x4, split_size,
-          true));
-      return;
+      children_[0] = ParameterTree::Create(row4x4_, column4x4_, sub_size, true);
+      children_[1] = ParameterTree::Create(row4x4_, column4x4_ + half_block4x4,
+                                           split_size, true);
+      children_[2] =
+          ParameterTree::Create(row4x4_ + half_block4x4,
+                                column4x4_ + half_block4x4, split_size, true);
+      return children_[0] != nullptr && children_[1] != nullptr &&
+             children_[2] != nullptr;
     case kPartitionHorizontal4:
       for (int i = 0; i < 4; ++i) {
-        children_[i].reset(new (std::nothrow) ParameterTree(
-            row4x4_ + i * quarter_block4x4, column4x4_, sub_size, true));
+        children_[i] = ParameterTree::Create(row4x4_ + i * quarter_block4x4,
+                                             column4x4_, sub_size, true);
+        if (children_[i] == nullptr) return false;
       }
-      return;
-    case kPartitionVertical4:
+      return true;
+    default:
+      assert(partition == kPartitionVertical4);
       for (int i = 0; i < 4; ++i) {
-        children_[i].reset(new (std::nothrow) ParameterTree(
-            row4x4_, column4x4_ + i * quarter_block4x4, sub_size, true));
+        children_[i] = ParameterTree::Create(
+            row4x4_, column4x4_ + i * quarter_block4x4, sub_size, true);
+        if (children_[i] == nullptr) return false;
       }
-      return;
+      return true;
   }
 }
 
-BlockParameters* ParameterTree::Find(int row4x4, int column4x4) const {
-  if (!partition_type_set_ || row4x4 < row4x4_ || column4x4 < column4x4_ ||
-      row4x4 >= row4x4_ + kNum4x4BlocksHigh[block_size_] ||
-      column4x4 >= column4x4_ + kNum4x4BlocksWide[block_size_]) {
-    // Either partition type is not set or the search range is out of bound.
-    return nullptr;
-  }
-  const ParameterTree* node = this;
-  while (node->partition_ != kPartitionNone) {
-    if (!node->partition_type_set_) {
-      LIBGAV1_DLOG(ERROR,
-                   "Partition type was not set for one of the nodes in the "
-                   "path to row4x4: %d column4x4: %d.",
-                   row4x4, column4x4);
-      return nullptr;
-    }
-    const int block_width4x4 = kNum4x4BlocksWide[node->block_size_];
-    const int half_block4x4 = block_width4x4 >> 1;
-    const int quarter_block4x4 = half_block4x4 >> 1;
-    switch (node->partition_) {
-      case kPartitionNone:
-        assert(false);
-        break;
-      case kPartitionHorizontal:
-        if (row4x4 < node->row4x4_ + half_block4x4) {
-          node = node->children_[0].get();
-        } else {
-          node = node->children_[1].get();
-        }
-        break;
-      case kPartitionVertical:
-        if (column4x4 < node->column4x4_ + half_block4x4) {
-          node = node->children_[0].get();
-        } else {
-          node = node->children_[1].get();
-        }
-        break;
-      case kPartitionSplit:
-        if (row4x4 < node->row4x4_ + half_block4x4 &&
-            column4x4 < node->column4x4_ + half_block4x4) {
-          node = node->children_[0].get();
-        } else if (row4x4 < node->row4x4_ + half_block4x4) {
-          node = node->children_[1].get();
-        } else if (column4x4 < node->column4x4_ + half_block4x4) {
-          node = node->children_[2].get();
-        } else {
-          node = node->children_[3].get();
-        }
-        break;
-      case kPartitionHorizontalWithTopSplit:
-        if (row4x4 < node->row4x4_ + half_block4x4 &&
-            column4x4 < node->column4x4_ + half_block4x4) {
-          node = node->children_[0].get();
-        } else if (row4x4 < node->row4x4_ + half_block4x4) {
-          node = node->children_[1].get();
-        } else {
-          node = node->children_[2].get();
-        }
-        break;
-      case kPartitionHorizontalWithBottomSplit:
-        if (row4x4 < node->row4x4_ + half_block4x4) {
-          node = node->children_[0].get();
-        } else if (column4x4 < node->column4x4_ + half_block4x4) {
-          node = node->children_[1].get();
-        } else {
-          node = node->children_[2].get();
-        }
-        break;
-      case kPartitionVerticalWithLeftSplit:
-        if (row4x4 < node->row4x4_ + half_block4x4 &&
-            column4x4 < node->column4x4_ + half_block4x4) {
-          node = node->children_[0].get();
-        } else if (column4x4 < node->column4x4_ + half_block4x4) {
-          node = node->children_[1].get();
-        } else {
-          node = node->children_[2].get();
-        }
-        break;
-      case kPartitionVerticalWithRightSplit:
-        if (column4x4 < node->column4x4_ + half_block4x4) {
-          node = node->children_[0].get();
-        } else if (row4x4 < node->row4x4_ + half_block4x4) {
-          node = node->children_[1].get();
-        } else {
-          node = node->children_[2].get();
-        }
-        break;
-      case kPartitionHorizontal4:
-        for (int i = 0; i < 4; ++i) {
-          if (row4x4 < node->row4x4_ + quarter_block4x4 * (i + 1)) {
-            node = node->children_[i].get();
-            break;
-          }
-        }
-        break;
-      case kPartitionVertical4:
-        for (int i = 0; i < 4; ++i) {
-          if (column4x4 < node->column4x4_ + quarter_block4x4 * (i + 1)) {
-            node = node->children_[i].get();
-            break;
-          }
-        }
-        break;
-    }
-  }
-  return node->parameters_.get();
-}
-
 }  // namespace libgav1
diff --git a/libgav1/src/utils/parameter_tree.h b/libgav1/src/utils/parameter_tree.h
index 7cd0e0b..fabf51a 100644
--- a/libgav1/src/utils/parameter_tree.h
+++ b/libgav1/src/utils/parameter_tree.h
@@ -5,6 +5,7 @@
 #include <memory>
 
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
 #include "src/utils/memory.h"
 #include "src/utils/types.h"
@@ -20,13 +21,9 @@
   // false, |block_size| must be a square block, i.e.,
   // kBlockWidthPixels[block_size] must be equal to
   // kBlockHeightPixels[block_size].
-  ParameterTree(int row4x4, int column4x4, BlockSize block_size,
-                bool is_leaf = false)
-      : row4x4_(row4x4), column4x4_(column4x4), block_size_(block_size) {
-    if (is_leaf) {
-      SetPartitionType(kPartitionNone);
-    }
-  }
+  static std::unique_ptr<ParameterTree> Create(int row4x4, int column4x4,
+                                               BlockSize block_size,
+                                               bool is_leaf = false);
 
   // Move only (not Copyable).
   ParameterTree(ParameterTree&& other) = default;
@@ -46,12 +43,7 @@
   //   will have to set them or their descendants to a terminal type.
   // }
   // This function must be called only once per node.
-  void SetPartitionType(Partition partition);
-
-  // Traverses the tree and searches for the node that contains the
-  // BlockParameters for |row4x4| and |column4x4|. Returns nullptr, if the tree
-  // does not contain the BlockParameters for the given coordinates.
-  BlockParameters* Find(int row4x4, int column4x4) const;
+  LIBGAV1_MUST_USE_RESULT bool SetPartitionType(Partition partition);
 
   // Basic getters.
   int row4x4() const { return row4x4_; }
@@ -68,6 +60,9 @@
   BlockParameters* parameters() const { return parameters_.get(); }
 
  private:
+  ParameterTree(int row4x4, int column4x4, BlockSize block_size)
+      : row4x4_(row4x4), column4x4_(column4x4), block_size_(block_size) {}
+
   Partition partition_ = kPartitionNone;
   std::unique_ptr<BlockParameters> parameters_ = nullptr;
   int row4x4_ = -1;
@@ -93,6 +88,8 @@
   //  * Vertical4: 0 left partition; 1 second left partition; 2 third left
   //    partition; 3 right partition;
   std::unique_ptr<ParameterTree> children_[4] = {};
+
+  friend class ParameterTreeTest;
 };
 
 }  // namespace libgav1
diff --git a/libgav1/src/utils/queue.h b/libgav1/src/utils/queue.h
new file mode 100644
index 0000000..614acef
--- /dev/null
+++ b/libgav1/src/utils/queue.h
@@ -0,0 +1,68 @@
+#ifndef LIBGAV1_SRC_UTILS_QUEUE_H_
+#define LIBGAV1_SRC_UTILS_QUEUE_H_
+
+#include <cassert>
+#include <cstddef>
+#include <memory>
+#include <new>
+
+#include "src/utils/compiler_attributes.h"
+
+namespace libgav1 {
+
+// A FIFO queue of a fixed capacity. The elements are copied, so the element
+// type T should be small.
+//
+// WARNING: No error checking is performed.
+template <typename T>
+class Queue {
+ public:
+  LIBGAV1_MUST_USE_RESULT bool Init(size_t capacity) {
+    elements_.reset(new (std::nothrow) T[capacity]);
+    if (elements_ == nullptr) return false;
+    capacity_ = capacity;
+    return true;
+  }
+
+  // Pushes the element |value| to the end of the queue. It is an error to call
+  // Push() when the queue is full.
+  void Push(T value) {
+    assert(size_ < capacity_);
+    elements_[back_++] = value;
+    if (back_ == capacity_) back_ = 0;
+    ++size_;
+  }
+
+  // Returns the element at the front of the queue and removes it from the
+  // queue. It is an error to call Pop() when the queue is empty.
+  T Pop() {
+    assert(size_ != 0);
+    const T front_element = elements_[front_++];
+    if (front_ == capacity_) front_ = 0;
+    --size_;
+    return front_element;
+  }
+
+  // Returns true if the queue is empty.
+  bool Empty() const { return size_ == 0; }
+
+  // Returns true if the queue is full.
+  bool Full() const { return size_ >= capacity_; }
+
+  // Returns the number of elements in the queue.
+  size_t Size() const { return size_; }
+
+ private:
+  // An array of |capacity| elements. Used as a circular array.
+  std::unique_ptr<T[]> elements_;
+  size_t capacity_ = 0;
+  // The index of the element to be removed by Pop().
+  size_t front_ = 0;
+  // The index where the new element is inserted by Push().
+  size_t back_ = 0;
+  size_t size_ = 0;
+};
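+
+// Illustrative use (not part of the library): a fixed-capacity queue of ints.
+//   Queue<int> queue;
+//   if (!queue.Init(/*capacity=*/8)) return false;
+//   queue.Push(1);
+//   queue.Push(2);
+//   while (!queue.Empty()) {
+//     const int value = queue.Pop();  // Pops 1, then 2.
+//   }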
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_UTILS_QUEUE_H_
diff --git a/libgav1/src/utils/raw_bit_reader.cc b/libgav1/src/utils/raw_bit_reader.cc
index 3c3964d..084535a 100644
--- a/libgav1/src/utils/raw_bit_reader.cc
+++ b/libgav1/src/utils/raw_bit_reader.cc
@@ -16,18 +16,16 @@
 namespace libgav1 {
 namespace {
 
-const int kMaximumLeb128Size = 8;
-const uint8_t kLeb128ValueByteMask = 0x7f;
-const uint8_t kLeb128TerminationByteMask = 0x80;
+constexpr int kMaximumLeb128Size = 8;
+constexpr uint8_t kLeb128ValueByteMask = 0x7f;
+constexpr uint8_t kLeb128TerminationByteMask = 0x80;
 
-inline uint8_t Mod8(size_t n) {
+uint8_t Mod8(size_t n) {
   // Last 3 bits are the value of mod 8.
   return n & 0x07;
 }
 
-inline size_t DivideBy8(size_t n, bool ceil) {
-  return (n + (ceil ? 7 : 0)) >> 3;
-}
+size_t DivideBy8(size_t n, bool ceil) { return (n + (ceil ? 7 : 0)) >> 3; }
 
 }  // namespace
 
diff --git a/libgav1/src/utils/scan.h b/libgav1/src/utils/scan.h
deleted file mode 100644
index c880496..0000000
--- a/libgav1/src/utils/scan.h
+++ /dev/null
@@ -1,15 +0,0 @@
-#ifndef LIBGAV1_SRC_UTILS_SCAN_H_
-#define LIBGAV1_SRC_UTILS_SCAN_H_
-
-#include <cstdint>
-
-#include "src/utils/constants.h"
-
-namespace libgav1 {
-
-const uint16_t* GetScan(TransformSize tx_size,
-                        TransformType tx_type);  // 5.11.41.
-
-}  // namespace libgav1
-
-#endif  // LIBGAV1_SRC_UTILS_SCAN_H_
diff --git a/libgav1/src/utils/segmentation.h b/libgav1/src/utils/segmentation.h
index 9229783..f74bdb8 100644
--- a/libgav1/src/utils/segmentation.h
+++ b/libgav1/src/utils/segmentation.h
@@ -53,7 +53,9 @@
   bool feature_enabled[kMaxSegments][kSegmentFeatureMax];
   int16_t feature_data[kMaxSegments][kSegmentFeatureMax];
   bool lossless[kMaxSegments];
-  int16_t qindex[kMaxSegments];
+  // Cached values of get_qindex(1, segmentId), to be consumed by
+  // Tile::ReadTransformType(). The values are in the range [0, 255].
+  uint8_t qindex[kMaxSegments];
 };
 
 }  // namespace libgav1
diff --git a/libgav1/src/utils/stack.h b/libgav1/src/utils/stack.h
new file mode 100644
index 0000000..59c6061
--- /dev/null
+++ b/libgav1/src/utils/stack.h
@@ -0,0 +1,42 @@
+#ifndef LIBGAV1_SRC_UTILS_STACK_H_
+#define LIBGAV1_SRC_UTILS_STACK_H_
+
+#include <cassert>
+
+namespace libgav1 {
+
+// A LIFO stack of a fixed capacity. The elements are copied, so the element
+// type T should be small.
+//
+// WARNING: No error checking is performed.
+template <typename T, int capacity>
+class Stack {
+ public:
+  // Pushes the element |value| to the top of the stack. It is an error to call
+  // Push() when the stack is full.
+  void Push(T value) {
+    ++top_;
+    assert(top_ < capacity);
+    elements_[top_] = value;
+  }
+
+  // Returns the element at the top of the stack and removes it from the stack.
+  // It is an error to call Pop() when the stack is empty.
+  T Pop() {
+    assert(top_ >= 0);
+    return elements_[top_--];
+  }
+
+  // Returns true if the stack is empty.
+  bool Empty() const { return top_ < 0; }
+
+ private:
+  static_assert(capacity > 0, "");
+  T elements_[capacity];
+  // The array index of the top of the stack. The stack is empty if top_ is -1.
+  int top_ = -1;
+};
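+
+// Illustrative use (not part of the library): a bounded LIFO of indices.
+//   Stack<int, 4> stack;
+//   stack.Push(10);
+//   stack.Push(20);
+//   while (!stack.Empty()) {
+//     const int value = stack.Pop();  // Pops 20, then 10.
+//   }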
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_UTILS_STACK_H_
diff --git a/libgav1/src/utils/threadpool.cc b/libgav1/src/utils/threadpool.cc
index dea65f9..681ade6 100644
--- a/libgav1/src/utils/threadpool.cc
+++ b/libgav1/src/utils/threadpool.cc
@@ -83,10 +83,15 @@
 
 void ThreadPool::Schedule(std::function<void()> closure) {
   LockMutex();
-  queue_.push_back(std::move(closure));
-  // TODO(jzern): the mutex doesn't need to be locked to signal the condition.
-  SignalOne();
+  if (!queue_.GrowIfNeeded()) {
+    // queue_ is full and we can't grow it. Run |closure| directly.
+    UnlockMutex();
+    closure();
+    return;
+  }
+  queue_.Push(std::move(closure));
   UnlockMutex();
+  SignalOne();
 }
 
 int ThreadPool::num_threads() const { return num_threads_; }
@@ -178,6 +183,7 @@
     assert(rv >= 0);
     rv = pthread_setname_np(name);
     assert(rv == 0);
+    static_cast<void>(rv);
 #elif defined(__ANDROID__) || defined(__GLIBC__)
     // If the |name| buffer is longer than 16 bytes, pthread_setname_np fails
     // with error 34 (ERANGE) on Android.
@@ -188,6 +194,7 @@
     assert(rv >= 0);
     rv = pthread_setname_np(pthread_self(), name);
     assert(rv == 0);
+    static_cast<void>(rv);
 #endif
   }
 }
@@ -195,13 +202,14 @@
 #endif  // defined(_MSC_VER)
 
 void* ThreadPool::WorkerThread::ThreadBody(void* arg) {
-  auto thread = static_cast<WorkerThread*>(arg);
+  auto* thread = static_cast<WorkerThread*>(arg);
   thread->SetupName();
   thread->pool_->WorkerFunction();
   return nullptr;
 }
 
 bool ThreadPool::StartWorkers() {
+  if (!queue_.Init()) return false;
   for (int i = 0; i < num_threads_; ++i) {
     threads_[i] = new (std::nothrow) WorkerThread(this);
     if (threads_[i] == nullptr) return false;
@@ -217,7 +225,7 @@
 void ThreadPool::WorkerFunction() {
   LockMutex();
   while (true) {
-    if (queue_.empty()) {
+    if (queue_.Empty()) {
       if (exit_threads_) {
         break;  // Queue is empty and exit was requested.
       }
@@ -233,7 +241,7 @@
       const auto wait_start = Clock::now();
       while (Clock::now() - wait_start < kBusyWaitDuration) {
         LockMutex();
-        if (!queue_.empty()) {
+        if (!queue_.Empty()) {
           found_job = true;
           break;
         }
@@ -246,7 +254,7 @@
       // point.
       LockMutex();
       // Ensure that the queue is still empty.
-      if (!queue_.empty()) continue;
+      if (!queue_.Empty()) continue;
       if (exit_threads_) {
         break;  // Queue is empty and exit was requested.
       }
@@ -255,8 +263,8 @@
       Wait();
     } else {
       // Take a job from the queue.
-      std::function<void()> job = std::move(queue_.front());
-      queue_.pop_front();
+      std::function<void()> job = std::move(queue_.Front());
+      queue_.Pop();
 
       UnlockMutex();
       // Note that it is good practice to surround this with a try/catch so
@@ -275,9 +283,8 @@
   // Tell worker threads how to exit.
   LockMutex();
   exit_threads_ = true;
-  // TODO(jzern): the mutex doesn't need to be locked to signal the condition.
-  SignalAll();
   UnlockMutex();
+  SignalAll();
 
   // Join all workers. This will block.
   for (int i = 0; i < num_threads_; ++i) {
diff --git a/libgav1/src/utils/threadpool.h b/libgav1/src/utils/threadpool.h
index 2e710d6..238bc44 100644
--- a/libgav1/src/utils/threadpool.h
+++ b/libgav1/src/utils/threadpool.h
@@ -1,7 +1,6 @@
 #ifndef LIBGAV1_SRC_UTILS_THREADPOOL_H_
 #define LIBGAV1_SRC_UTILS_THREADPOOL_H_
 
-#include <deque>
 #include <functional>
 #include <memory>
 
@@ -20,6 +19,7 @@
 #include "src/utils/compiler_attributes.h"
 #include "src/utils/executor.h"
 #include "src/utils/memory.h"
+#include "src/utils/unbounded_queue.h"
 
 namespace libgav1 {
 
@@ -56,6 +56,13 @@
   // Adds the specified "closure" to the queue for processing. If worker threads
   // are available, "closure" will run immediately. Otherwise "closure" is
   // queued for later execution.
+  //
+  // NOTE: If the internal queue is full and cannot be resized because of an
+  // out-of-memory error, the current thread runs "closure" before returning
+  // from Schedule(). For our use cases, this seems better than the
+  // alternatives:
+  //   1. Return a failure status.
+  //   2. Have the current thread wait until the queue is not full.
   void Schedule(std::function<void()> closure) override;
 
   int num_threads() const;
@@ -113,7 +120,7 @@
 
 #endif  // LIBGAV1_THREADPOOL_USE_STD_MUTEX
 
-  std::deque<std::function<void()>> queue_ LIBGAV1_GUARDED_BY(queue_mutex_);
+  UnboundedQueue<std::function<void()>> queue_ LIBGAV1_GUARDED_BY(queue_mutex_);
   // If not all the worker threads are created, the first entry after the
   // created worker threads is a null pointer.
   const std::unique_ptr<WorkerThread*[]> threads_;
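
For illustration only: a hedged sketch of what the new Schedule() contract means for callers. Because a failed queue growth makes Schedule() run the closure on the calling thread, the closure must be safe to execute synchronously.

#include "src/utils/threadpool.h"

// Sketch, assuming the caller already holds a constructed ThreadPool; the
// closure body must not rely on running on a worker thread.
void ScheduleSketch(libgav1::ThreadPool* pool) {
  pool->Schedule([]() {
    // May run on a worker thread, or directly inside Schedule() on the
    // calling thread if the internal queue is full and cannot grow.
  });
}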
diff --git a/libgav1/src/utils/types.h b/libgav1/src/utils/types.h
index d7093ea..53bfe5c 100644
--- a/libgav1/src/utils/types.h
+++ b/libgav1/src/utils/types.h
@@ -4,6 +4,7 @@
 #include <cstdint>
 #include <memory>
 
+#include "src/utils/array_2d.h"
 #include "src/utils/constants.h"
 #include "src/utils/memory.h"
 
@@ -63,7 +64,7 @@
   int8_t cfl_alpha_v;
   int max_luma_width;
   int max_luma_height;
-  uint8_t color_index_map[kNumPlaneTypes][kMaxPaletteSquare][kMaxPaletteSquare];
+  Array2D<uint8_t> color_index_map[kNumPlaneTypes];
   bool use_intra_block_copy;
   InterIntraMode inter_intra_mode;
   bool is_wedge_inter_intra;
@@ -97,13 +98,19 @@
   PredictionMode y_mode;
   PredictionMode uv_mode;
   TransformSize transform_size;
+  TransformSize uv_transform_size;
   PaletteModeInfo palette_mode_info;
   ReferenceFrameType reference_frame[2];
   MotionVector mv[2];
   bool is_explicit_compound_type;  // comp_group_idx in the spec.
   bool is_compound_type_average;   // compound_idx in the spec.
   InterpolationFilter interpolation_filter[2];
-  uint8_t deblock_filter_level[kMaxPlanes][kNumLoopFilterTypes];
+  // The index of this array is as follows:
+  //  0 - Y plane vertical filtering.
+  //  1 - Y plane horizontal filtering.
+  //  2 - U plane (both directions).
+  //  3 - V plane (both directions).
+  uint8_t deblock_filter_level[kFrameLfCount];
   // When |Tile::split_parse_and_decode_| is true, each block gets its own
   // instance of |prediction_parameters|. When it is false, all the blocks point
   // to |Tile::prediction_parameters_|. This field is valid only as long as the
diff --git a/libgav1/src/utils/unbounded_queue.h b/libgav1/src/utils/unbounded_queue.h
new file mode 100644
index 0000000..611ac50
--- /dev/null
+++ b/libgav1/src/utils/unbounded_queue.h
@@ -0,0 +1,227 @@
+#ifndef LIBGAV1_SRC_UTILS_UNBOUNDED_QUEUE_H_
+#define LIBGAV1_SRC_UTILS_UNBOUNDED_QUEUE_H_
+
+#include <cassert>
+#include <cstddef>
+#include <memory>
+#include <new>
+#include <utility>
+
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/memory.h"
+
+namespace libgav1 {
+
+// A FIFO queue of an unbounded capacity.
+//
+// This implementation uses the general approach used in std::deque
+// implementations. See, for example,
+// https://stackoverflow.com/questions/6292332/what-really-is-a-deque-in-stl
+//
+// It is much simpler because it just needs to support the queue interface.
+// The blocks are chained into a circular list, not managed by a "map". It
+// does not shrink the internal buffer.
+//
+// An alternative implementation approach is a resizable circular array. See,
+// for example, ResizingArrayQueue.java in https://algs4.cs.princeton.edu/code/
+// and base::circular_deque in Chromium's base/containers library.
+template <typename T>
+class UnboundedQueue {
+ public:
+  UnboundedQueue() = default;
+
+  // Move only.
+  UnboundedQueue(UnboundedQueue&& other)
+      : first_block_(other.first_block_),
+        front_(other.front_),
+        last_block_(other.last_block_),
+        back_(other.back_) {
+    other.first_block_ = nullptr;
+    other.front_ = 0;
+    other.last_block_ = nullptr;
+    other.back_ = 0;
+  }
+  UnboundedQueue& operator=(UnboundedQueue&& other) {
+    if (this != &other) {
+      Destroy();
+      first_block_ = other.first_block_;
+      front_ = other.front_;
+      last_block_ = other.last_block_;
+      back_ = other.back_;
+      other.first_block_ = nullptr;
+      other.front_ = 0;
+      other.last_block_ = nullptr;
+      other.back_ = 0;
+    }
+    return *this;
+  }
+
+  ~UnboundedQueue() { Destroy(); }
+
+  // Allocates two Blocks upfront because most access patterns require at
+  // least two Blocks. Returns false if the allocation of the Blocks failed.
+  LIBGAV1_MUST_USE_RESULT bool Init() {
+    std::unique_ptr<Block> new_block0(new (std::nothrow) Block);
+    std::unique_ptr<Block> new_block1(new (std::nothrow) Block);
+    if (new_block0 == nullptr || new_block1 == nullptr) return false;
+    first_block_ = last_block_ = new_block0.release();
+    new_block1->next = first_block_;
+    last_block_->next = new_block1.release();
+    return true;
+  }
+
+  // Checks if the queue has room for a new element. If the queue is full,
+  // tries to grow it. Returns false if the queue is full and the attempt to
+  // grow it failed.
+  //
+  // NOTE: GrowIfNeeded() must be called before each call to Push(). This
+  // inconvenient design is necessary to guarantee a successful Push() call.
+  //
+  // Push(T&& value) is often called with the argument std::move(value). The
+  // moved-from object |value| won't be usable afterwards, so it would be
+  // problematic if Push(T&& value) failed and we lost access to the original
+  // |value| object.
+  LIBGAV1_MUST_USE_RESULT bool GrowIfNeeded() {
+    assert(last_block_ != nullptr);
+    if (back_ == kBlockCapacity) {
+      if (last_block_->next == first_block_) {
+        // All Blocks are in use.
+        std::unique_ptr<Block> new_block(new (std::nothrow) Block);
+        if (new_block == nullptr) return false;
+        new_block->next = first_block_;
+        last_block_->next = new_block.release();
+      }
+      last_block_ = last_block_->next;
+      back_ = 0;
+    }
+    return true;
+  }
+
+  // Pushes the element |value| to the end of the queue. It is an error to call
+  // Push() when the queue is full.
+  void Push(const T& value) {
+    assert(last_block_ != nullptr);
+    assert(back_ < kBlockCapacity);
+    T* elements = reinterpret_cast<T*>(last_block_->buffer);
+    new (&elements[back_++]) T(value);
+  }
+
+  void Push(T&& value) {
+    assert(last_block_ != nullptr);
+    assert(back_ < kBlockCapacity);
+    T* elements = reinterpret_cast<T*>(last_block_->buffer);
+    new (&elements[back_++]) T(std::move(value));
+  }
+
+  // Returns the element at the front of the queue. It is an error to call
+  // Front() when the queue is empty.
+  T& Front() {
+    assert(!Empty());
+    T* elements = reinterpret_cast<T*>(first_block_->buffer);
+    return elements[front_];
+  }
+
+  const T& Front() const {
+    assert(!Empty());
+    T* elements = reinterpret_cast<T*>(first_block_->buffer);
+    return elements[front_];
+  }
+
+  // Removes the element at the front of the queue from the queue. It is an
+  // error to call Pop() when the queue is empty.
+  void Pop() {
+    assert(!Empty());
+    T* elements = reinterpret_cast<T*>(first_block_->buffer);
+    elements[front_++].~T();
+    if (front_ == kBlockCapacity) {
+      // The first block has become empty.
+      front_ = 0;
+      if (first_block_ == last_block_) {
+        // Only one Block is in use. Simply reset back_.
+        back_ = 0;
+      } else {
+        first_block_ = first_block_->next;
+      }
+    }
+  }
+
+  // Returns true if the queue is empty.
+  bool Empty() const { return first_block_ == last_block_ && front_ == back_; }
+
+ private:
+  // kBlockCapacity is the maximum number of elements each Block can hold.
+  // sizeof(void*) is subtracted from 2048 to account for the |next| pointer in
+  // the Block struct.
+  //
+  // On Linux x86_64, sizeof(std::function<void()>) is 32, so each Block can
+  // hold 63 std::function<void()> objects.
+  //
+  // NOTE: The corresponding value in <deque> in libc++ revision
+  // 245b5ba3448b9d3f6de5962066557e253a6bc9a4 is:
+  //   template <class _ValueType, class _DiffType>
+  //   struct __deque_block_size {
+  //     static const _DiffType value =
+  //         sizeof(_ValueType) < 256 ? 4096 / sizeof(_ValueType) : 16;
+  //   };
+  //
+  // Note that 4096 / 256 = 16, so apparently this expression is intended to
+  // ensure the block size is at least 4096 bytes and each block can hold at
+  // least 16 elements.
+  static constexpr size_t kBlockCapacity =
+      (sizeof(T) < 128) ? (2048 - sizeof(void*)) / sizeof(T) : 16;
+
+  struct Block : public Allocable {
+    alignas(T) char buffer[kBlockCapacity * sizeof(T)];
+    Block* next;
+  };
+
+  void Destroy() {
+    if (first_block_ == nullptr) return;  // An uninitialized queue.
+
+    // First free the unused blocks, which are located after last_block_ and
+    // before first_block_.
+    Block* block = last_block_->next;
+    // Cut the circular list open after last_block_.
+    last_block_->next = nullptr;
+    while (block != first_block_) {
+      Block* next = block->next;
+      delete block;
+      block = next;
+    }
+
+    // Then free the used blocks. Destruct the elements in the used blocks.
+    while (block != nullptr) {
+      const size_t begin = (block == first_block_) ? front_ : 0;
+      const size_t end = (block == last_block_) ? back_ : kBlockCapacity;
+      T* elements = reinterpret_cast<T*>(block->buffer);
+      for (size_t i = begin; i < end; ++i) {
+        elements[i].~T();
+      }
+      Block* next = block->next;
+      delete block;
+      block = next;
+    }
+  }
+
+  // Blocks are chained in a circular singly-linked list. If the list of Blocks
+  // is empty, both first_block_ and last_block_ are null pointers. If the list
+  // is nonempty, first_block_ points to the first used Block and last_block_
+  // points to the last used Block.
+  //
+  // Invariant: If Init() is called and succeeds, the queue is always nonempty.
+  // This allows all methods (except the destructor) to avoid null pointer
+  // checks for first_block_ and last_block_.
+  Block* first_block_ = nullptr;
+  // The index of the element in first_block_ to be removed by Pop().
+  size_t front_ = 0;
+  Block* last_block_ = nullptr;
+  // The index in last_block_ where the new element is inserted by Push().
+  size_t back_ = 0;
+};
+
+template <typename T>
+constexpr size_t UnboundedQueue<T>::kBlockCapacity;
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_UTILS_UNBOUNDED_QUEUE_H_
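
For illustration only (not part of the patch): a usage sketch of UnboundedQueue based on the interface above, mirroring the ThreadPool changes. GrowIfNeeded() must succeed before each Push(), which is what makes Push(std::move(...)) safe to call without risking loss of the moved-from value.

#include <functional>
#include <utility>

#include "src/utils/unbounded_queue.h"

bool QueueUsageSketch() {
  libgav1::UnboundedQueue<std::function<void()>> queue;
  if (!queue.Init()) return false;          // Allocates the first two Blocks.
  if (!queue.GrowIfNeeded()) return false;  // Reserve a slot; may allocate.
  queue.Push([]() { /* work item */ });     // Guaranteed to succeed now.
  while (!queue.Empty()) {
    std::function<void()> job = std::move(queue.Front());
    queue.Pop();
    job();
  }
  return true;
}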
diff --git a/libgav1/src/utils/vector.h b/libgav1/src/utils/vector.h
index ad10a6f..f60a0a0 100644
--- a/libgav1/src/utils/vector.h
+++ b/libgav1/src/utils/vector.h
@@ -1,90 +1,325 @@
+// libgav1::Vector implementation
+
 #ifndef LIBGAV1_SRC_UTILS_VECTOR_H_
 #define LIBGAV1_SRC_UTILS_VECTOR_H_
 
-#include <algorithm>
-#include <initializer_list>
+#include <cassert>
+#include <cstddef>
+#include <cstdlib>
+#include <cstring>
+#include <iterator>
+#include <type_traits>
 #include <utility>
-#include <vector>
 
-#include "src/utils/allocator.h"
 #include "src/utils/compiler_attributes.h"
 
 namespace libgav1 {
+namespace internal {
+
+static constexpr size_t kMinVectorAllocation = 16;
+
+// Returns the smallest power of two greater than or equal to 'value'.
+inline size_t NextPow2(size_t value) {
+  if (value == 0) return 0;
+  --value;
+  for (size_t i = 1; i < sizeof(size_t) * 8; i *= 2) value |= value >> i;
+  return value + 1;
+}
+
+// Returns the smallest capacity greater than or equal to 'value'.
+inline size_t NextCapacity(size_t value) {
+  if (value == 0) return 0;
+  if (value <= kMinVectorAllocation) return kMinVectorAllocation;
+  return NextPow2(value);
+}
 
 //------------------------------------------------------------------------------
-// Vector class that does *NOT* initialize the content by default, unless an
-// explicit value is passed to the constructor.
-// Should be reserved to POD preferably.
+// Data structure equivalent to std::vector, but which returns false and
+// reverts to its last valid state on memory allocation failure.
+// std::vector with a custom allocator does not fill this need without
+// exceptions.
 
-// resize(), reserve(), and push_back() are overridden to return bool.
-//
-// New methods: ok() and CopyFrom().
-//
-// DO NOT USE emplace_back(), insert(), and emplace().
-template <typename T, typename super = std::vector<T, AllocatorNoCtor<T>>>
-class VectorNoCtor : public super {
+template <typename T>
+class VectorBase {
  public:
-  using super::super;
-  bool ok() const { return this->get_allocator().ok(); }
-  T* operator*() = delete;
-  LIBGAV1_MUST_USE_RESULT inline bool resize(size_t n) {
-    return ok() && (super::resize(n), ok());
+  using iterator = T*;
+  using const_iterator = const T*;
+
+  VectorBase() noexcept = default;
+  // Move only.
+  VectorBase(const VectorBase&) = delete;
+  VectorBase& operator=(const VectorBase&) = delete;
+  VectorBase(VectorBase&& other) noexcept
+      : items_(other.items_),
+        capacity_(other.capacity_),
+        num_items_(other.num_items_) {
+    other.items_ = nullptr;
+    other.capacity_ = 0;
+    other.num_items_ = 0;
   }
-  LIBGAV1_MUST_USE_RESULT inline bool reserve(size_t n) {
-    return ok() && (super::reserve(n), ok());
+  VectorBase& operator=(VectorBase&& other) noexcept {
+    if (this != &other) {
+      clear();
+      free(items_);
+      items_ = other.items_;
+      capacity_ = other.capacity_;
+      num_items_ = other.num_items_;
+      other.items_ = nullptr;
+      other.capacity_ = 0;
+      other.num_items_ = 0;
+    }
+    return *this;
   }
-  // Release the memory.
-  inline void reset() {
-    VectorNoCtor<T, super> tmp;
-    super::swap(tmp);
+  ~VectorBase() {
+    clear();
+    free(items_);
   }
 
-  // disable resizing ctors
-  VectorNoCtor(size_t size) noexcept = delete;
-  VectorNoCtor(size_t size, const T&) noexcept = delete;
-  VectorNoCtor& operator=(const VectorNoCtor& A) noexcept = delete;
-  VectorNoCtor(const VectorNoCtor& other) noexcept = delete;
-  template <typename InputIt>
-  VectorNoCtor(InputIt first, InputIt last) = delete;
-  VectorNoCtor(std::initializer_list<T> init) = delete;
-
-  // benign ctors
-  VectorNoCtor() noexcept : super() {}
-  VectorNoCtor& operator=(VectorNoCtor&& A) = default;
-  VectorNoCtor(VectorNoCtor&& A) noexcept : super(std::move(A)) {}
-
-  void assign(size_t count, const T& value) = delete;
-  template <typename InputIt>
-  void assign(InputIt first, InputIt last) = delete;
-  void assign(std::initializer_list<T> ilist) = delete;
-
-  // To be used instead of copy-ctor:
-  bool CopyFrom(const VectorNoCtor& A) {
-    if (!resize(A.size())) return false;
-    std::copy(A.begin(), A.end(), super::begin());
+  // Reallocates just enough memory if needed so that 'new_cap' items can fit.
+  LIBGAV1_MUST_USE_RESULT bool reserve(size_t new_cap) {
+    if (capacity_ < new_cap) {
+      T* const new_items = static_cast<T*>(malloc(new_cap * sizeof(T)));
+      if (new_items == nullptr) return false;
+      if (num_items_ > 0) {
+        if (std::is_trivial<T>::value) {
+          memcpy(new_items, items_, num_items_ * sizeof(T));
+        } else {
+          for (size_t i = 0; i < num_items_; ++i) {
+            new (&new_items[i]) T(std::move(items_[i]));
+            items_[i].~T();
+          }
+        }
+      }
+      free(items_);
+      items_ = new_items;
+      capacity_ = new_cap;
+    }
     return true;
   }
-  // Performs a push back *if* the vector was properly allocated.
-  // *NO* re-allocation happens.
-  LIBGAV1_MUST_USE_RESULT inline bool push_back(const T& v) {
-    if (super::size() < super::capacity()) {
-      super::push_back(v);
+
+  // Reallocates less memory so that only the existing items can fit.
+  bool shrink_to_fit() {
+    if (capacity_ == num_items_) return true;
+    if (num_items_ == 0) {
+      free(items_);
+      items_ = nullptr;
+      capacity_ = 0;
       return true;
     }
+    const size_t previous_capacity = capacity_;
+    capacity_ = 0;  // Force reserve() to allocate and copy.
+    if (reserve(num_items_)) return true;
+    capacity_ = previous_capacity;
     return false;
   }
-  LIBGAV1_MUST_USE_RESULT inline bool push_back(T&& v) {
-    if (super::size() < super::capacity()) {
-      super::push_back(v);
-      return true;
+
+  // Constructs a new item by copy constructor. May reallocate if
+  // 'resize_if_needed'.
+  LIBGAV1_MUST_USE_RESULT bool push_back(const T& value,
+                                         bool resize_if_needed = true) {
+    if (num_items_ >= capacity_ &&
+        (!resize_if_needed ||
+         !reserve(internal::NextCapacity(num_items_ + 1)))) {
+      return false;
     }
-    return false;
+    new (&items_[num_items_]) T(value);
+    ++num_items_;
+    return true;
+  }
+
+  // Constructs a new item by copy constructor. reserve() must have been called
+  // with a sufficient capacity.
+  //
+  // WARNING: No error checking is performed.
+  void push_back_unchecked(const T& value) {
+    assert(num_items_ < capacity_);
+    new (&items_[num_items_]) T(value);
+    ++num_items_;
+  }
+
+  // Constructs a new item by move constructor. May reallocate if
+  // 'resize_if_needed'.
+  LIBGAV1_MUST_USE_RESULT bool push_back(T&& value,
+                                         bool resize_if_needed = true) {
+    if (num_items_ >= capacity_ &&
+        (!resize_if_needed ||
+         !reserve(internal::NextCapacity(num_items_ + 1)))) {
+      return false;
+    }
+    new (&items_[num_items_]) T(std::move(value));
+    ++num_items_;
+    return true;
+  }
+
+  // Constructs a new item by move constructor. reserve() must have been called
+  // with a sufficient capacity.
+  //
+  // WARNING: No error checking is performed.
+  void push_back_unchecked(T&& value) {
+    assert(num_items_ < capacity_);
+    new (&items_[num_items_]) T(std::move(value));
+    ++num_items_;
+  }
+
+  // Constructs a new item in place by forwarding the arguments args... to the
+  // constructor. May reallocate.
+  template <typename... Args>
+  LIBGAV1_MUST_USE_RESULT bool emplace_back(Args&&... args) {
+    if (num_items_ >= capacity_ &&
+        !reserve(internal::NextCapacity(num_items_ + 1))) {
+      return false;
+    }
+    new (&items_[num_items_]) T(std::forward<Args>(args)...);
+    ++num_items_;
+    return true;
+  }
+
+  // Destructs the last item.
+  void pop_back() {
+    --num_items_;
+    items_[num_items_].~T();
+  }
+
+  // Destructs the item at 'pos'.
+  void erase(iterator pos) { erase(pos, pos + 1); }
+
+  // Destructs the items in [first,last).
+  void erase(iterator first, iterator last) {
+    for (iterator it = first; it != last; ++it) it->~T();
+    if (last != end()) {
+      if (std::is_trivial<T>::value) {
+        memmove(first, last, (end() - last) * sizeof(T));
+      } else {
+        for (iterator it_src = last, it_dst = first; it_src != end();
+             ++it_src, ++it_dst) {
+          new (it_dst) T(std::move(*it_src));
+          it_src->~T();
+        }
+      }
+    }
+    num_items_ -= std::distance(first, last);
+  }
+
+  // Destructs all the items.
+  void clear() { erase(begin(), end()); }
+
+  // Destroys (including deallocating) all the items.
+  void reset() {
+    clear();
+    if (!shrink_to_fit()) assert(false);
+  }
+
+  // Accessors
+  bool empty() const { return (num_items_ == 0); }
+  size_t size() const { return num_items_; }
+  size_t capacity() const { return capacity_; }
+
+  T* data() { return items_; }
+  T& front() { return items_[0]; }
+  T& back() { return items_[num_items_ - 1]; }
+  T& operator[](size_t i) { return items_[i]; }
+  T& at(size_t i) { return items_[i]; }
+  const T* data() const { return items_; }
+  const T& front() const { return items_[0]; }
+  const T& back() const { return items_[num_items_ - 1]; }
+  const T& operator[](size_t i) const { return items_[i]; }
+  const T& at(size_t i) const { return items_[i]; }
+
+  iterator begin() { return &items_[0]; }
+  const_iterator begin() const { return &items_[0]; }
+  iterator end() { return &items_[num_items_]; }
+  const_iterator end() const { return &items_[num_items_]; }
+
+  void swap(VectorBase& b) {
+    // Although not necessary here, adding "using std::swap;" and then calling
+    // swap() without namespace qualification is recommended. See Effective
+    // C++, Item 25.
+    using std::swap;
+    swap(items_, b.items_);
+    swap(capacity_, b.capacity_);
+    swap(num_items_, b.num_items_);
+  }
+
+ protected:
+  T* items_ = nullptr;
+  size_t capacity_ = 0;
+  size_t num_items_ = 0;
+};
+
+}  // namespace internal
+
+//------------------------------------------------------------------------------
+
+// Vector class that does *NOT* construct the content on resize().
+// Should be reserved to plain old data.
+template <typename T>
+class VectorNoCtor : public internal::VectorBase<T> {
+ public:
+  // Creates or destructs items so that 'new_num_items' exist.
+  // Allocated memory grows every power-of-two items.
+  LIBGAV1_MUST_USE_RESULT bool resize(size_t new_num_items) {
+    using super = internal::VectorBase<T>;
+    if (super::num_items_ < new_num_items) {
+      if (super::capacity_ < new_num_items) {
+        if (!super::reserve(internal::NextCapacity(new_num_items))) {
+          return false;
+        }
+      }
+      super::num_items_ = new_num_items;
+    } else {
+      while (super::num_items_ > new_num_items) {
+        --super::num_items_;
+        super::items_[super::num_items_].~T();
+      }
+    }
+    return true;
   }
 };
 
-// This generic vector class will call the constructors
+// This generic vector class will call the constructors.
 template <typename T>
-using Vector = VectorNoCtor<T, std::vector<T, Allocator<T>>>;
+class Vector : public internal::VectorBase<T> {
+ public:
+  // Constructs or destructs items so that 'new_num_items' exist.
+  // Allocated memory grows every power-of-two items.
+  LIBGAV1_MUST_USE_RESULT bool resize(size_t new_num_items) {
+    using super = internal::VectorBase<T>;
+    if (super::num_items_ < new_num_items) {
+      if (super::capacity_ < new_num_items) {
+        if (!super::reserve(internal::NextCapacity(new_num_items))) {
+          return false;
+        }
+      }
+      while (super::num_items_ < new_num_items) {
+        new (&super::items_[super::num_items_]) T();
+        ++super::num_items_;
+      }
+    } else {
+      while (super::num_items_ > new_num_items) {
+        --super::num_items_;
+        super::items_[super::num_items_].~T();
+      }
+    }
+    return true;
+  }
+};
+
+//------------------------------------------------------------------------------
+
+// Define non-member swap() functions in the namespace in which VectorNoCtor
+// and Vector are implemented. See Effective C++, Item 25.
+
+template <typename T>
+void swap(VectorNoCtor<T>& a, VectorNoCtor<T>& b) {
+  a.swap(b);
+}
+
+template <typename T>
+void swap(Vector<T>& a, Vector<T>& b) {
+  a.swap(b);
+}
+
+//------------------------------------------------------------------------------
 
 }  // namespace libgav1
 
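For illustration only: a usage sketch of the rewritten Vector/VectorNoCtor based on the interface above. Growing operations report failure via a bool instead of throwing, and capacities grow along NextCapacity(); for example, growing past kMinVectorAllocation = 16 items jumps to the next power of two, so 17 items reserve 32 slots.

#include "src/utils/vector.h"

bool VectorUsageSketch() {
  libgav1::Vector<int> values;
  if (!values.resize(17)) return false;     // Constructs 17 ints; capacity 32.
  if (!values.push_back(42)) return false;  // May reallocate on growth.
  values.pop_back();

  libgav1::VectorNoCtor<int> raw;           // resize() does not construct.
  if (!raw.reserve(8)) return false;
  raw.push_back_unchecked(1);               // No checks; capacity reserved.
  return raw.size() == 1;
}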
diff --git a/libgav1/src/warp_prediction.cc b/libgav1/src/warp_prediction.cc
index d0d5788..830f951 100644
--- a/libgav1/src/warp_prediction.cc
+++ b/libgav1/src/warp_prediction.cc
@@ -134,16 +134,12 @@
   int bx[2] = {};
   int by[2] = {};
 
-  // TODO(chengchen): for simplicity, the spec always uses absolute coordinates
+  // Note: for simplicity, the spec always uses absolute coordinates
   // in the warp estimation process. subpixel_mid_x, subpixel_mid_y,
   // and candidates are relative to the top left of the frame.
   // In contrast, libaom uses a mixture of coordinate systems.
   // In av1/common/warped_motion.c:find_affine_int(), the coordinate is relative
   // to the top left of the block.
-  // On one hand, we need to make sure libgav1 always keep consistency in
-  // coordinate system.
-  // On the other hand, we might investigate which representation is better for
-  // the sake of efficiency.
   // mid_y/mid_x: the row/column coordinate of the center of the block.
   const int mid_y = MultiplyBy4(row4x4) + MultiplyBy2(block_height4x4) - 1;
   const int mid_x = MultiplyBy4(column4x4) + MultiplyBy2(block_width4x4) - 1;
@@ -167,8 +163,6 @@
     // block, with center of the reference block as origin.
     const int dy = candidates[i][2] - reference_subpixel_mid_y;
     const int dx = candidates[i][3] - reference_subpixel_mid_x;
-    // TODO(chengchen): If this check is done somewhere (in find samples), we
-    // can remove it. If remove, remember to change unit tests input range.
     if (std::abs(sx - dx) < kLargestMotionVectorDiff &&
         std::abs(sy - dy) < kLargestMotionVectorDiff) {
       a[0][0] += LeastSquareProduct(sx, sx) + 8;