external/libgav1: update to v0.16.0

clears local cherry-picks.

single-threaded is ~1.5-3.4x faster and multi-threaded is ~1.4x-2.7x
faster than the previous snapshot on a sargo device.

Bug: 162267932
Test: playback, cts on sargo-userdebug / sargo_hwasan-userdebug
Test: aosp_arm-eng aosp_arm64-eng aosp_x86-eng aosp_x86_64-eng aosp_crosshatch-userdebug build
Change-Id: I63361c388d8cf6f652774b07c9475c45951599c5
Merged-In: I63361c388d8cf6f652774b07c9475c45951599c5
(cherry picked from commit 9a38f2a0f6e268ffad1856425ca8f429c48c2cc1)
diff --git a/Android.bp b/Android.bp
index 3d5b91a..e6a47be 100644
--- a/Android.bp
+++ b/Android.bp
@@ -18,6 +18,7 @@
 
     export_include_dirs: [
         ".",
+        "libgav1/src",
     ],
 
     cflags: [
@@ -40,10 +41,12 @@
         "libgav1/src/buffer_pool.cc",
         "libgav1/src/decoder.cc",
         "libgav1/src/decoder_impl.cc",
-        "libgav1/src/decoder_scratch_buffer.cc",
+        "libgav1/src/decoder_settings.cc",
         "libgav1/src/dsp/arm/average_blend_neon.cc",
+        "libgav1/src/dsp/arm/cdef_neon.cc",
         "libgav1/src/dsp/arm/convolve_neon.cc",
         "libgav1/src/dsp/arm/distance_weighted_blend_neon.cc",
+        "libgav1/src/dsp/arm/film_grain_neon.cc",
         "libgav1/src/dsp/arm/intra_edge_neon.cc",
         "libgav1/src/dsp/arm/intrapred_cfl_neon.cc",
         "libgav1/src/dsp/arm/intrapred_directional_neon.cc",
@@ -54,13 +57,16 @@
         "libgav1/src/dsp/arm/loop_filter_neon.cc",
         "libgav1/src/dsp/arm/loop_restoration_neon.cc",
         "libgav1/src/dsp/arm/mask_blend_neon.cc",
+        "libgav1/src/dsp/arm/motion_field_projection_neon.cc",
+        "libgav1/src/dsp/arm/motion_vector_search_neon.cc",
         "libgav1/src/dsp/arm/obmc_neon.cc",
+        "libgav1/src/dsp/arm/super_res_neon.cc",
         "libgav1/src/dsp/arm/warp_neon.cc",
+        "libgav1/src/dsp/arm/weight_mask_neon.cc",
         "libgav1/src/dsp/average_blend.cc",
         "libgav1/src/dsp/cdef.cc",
         "libgav1/src/dsp/constants.cc",
         "libgav1/src/dsp/convolve.cc",
-        "libgav1/src/dsp/cpu.cc",
         "libgav1/src/dsp/distance_weighted_blend.cc",
         "libgav1/src/dsp/dsp.cc",
         "libgav1/src/dsp/film_grain.cc",
@@ -70,9 +76,14 @@
         "libgav1/src/dsp/loop_filter.cc",
         "libgav1/src/dsp/loop_restoration.cc",
         "libgav1/src/dsp/mask_blend.cc",
+        "libgav1/src/dsp/motion_field_projection.cc",
+        "libgav1/src/dsp/motion_vector_search.cc",
         "libgav1/src/dsp/obmc.cc",
+        "libgav1/src/dsp/super_res.cc",
         "libgav1/src/dsp/warp.cc",
+        "libgav1/src/dsp/weight_mask.cc",
         "libgav1/src/dsp/x86/average_blend_sse4.cc",
+        "libgav1/src/dsp/x86/cdef_sse4.cc",
         "libgav1/src/dsp/x86/convolve_sse4.cc",
         "libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc",
         "libgav1/src/dsp/x86/intra_edge_sse4.cc",
@@ -82,17 +93,29 @@
         "libgav1/src/dsp/x86/inverse_transform_sse4.cc",
         "libgav1/src/dsp/x86/loop_filter_sse4.cc",
         "libgav1/src/dsp/x86/loop_restoration_sse4.cc",
+        "libgav1/src/dsp/x86/mask_blend_sse4.cc",
+        "libgav1/src/dsp/x86/motion_field_projection_sse4.cc",
+        "libgav1/src/dsp/x86/motion_vector_search_sse4.cc",
         "libgav1/src/dsp/x86/obmc_sse4.cc",
+        "libgav1/src/dsp/x86/super_res_sse4.cc",
+        "libgav1/src/dsp/x86/warp_sse4.cc",
+        "libgav1/src/dsp/x86/weight_mask_sse4.cc",
+        "libgav1/src/film_grain.cc",
+        "libgav1/src/frame_buffer.cc",
         "libgav1/src/internal_frame_buffer_list.cc",
-        "libgav1/src/loop_filter_mask.cc",
         "libgav1/src/loop_restoration_info.cc",
         "libgav1/src/motion_vector.cc",
         "libgav1/src/obu_parser.cc",
-        "libgav1/src/post_filter.cc",
+        "libgav1/src/post_filter/cdef.cc",
+        "libgav1/src/post_filter/deblock.cc",
+        "libgav1/src/post_filter/loop_restoration.cc",
+        "libgav1/src/post_filter/post_filter.cc",
+        "libgav1/src/post_filter/super_res.cc",
         "libgav1/src/prediction_mask.cc",
         "libgav1/src/quantizer.cc",
         "libgav1/src/reconstruction.cc",
         "libgav1/src/residual_buffer_pool.cc",
+        "libgav1/src/status_code.cc",
         "libgav1/src/symbol_decoder_context.cc",
         "libgav1/src/threading_strategy.cc",
         "libgav1/src/tile/bitstream/mode_info.cc",
@@ -100,10 +123,12 @@
         "libgav1/src/tile/bitstream/partition.cc",
         "libgav1/src/tile/bitstream/transform_size.cc",
         "libgav1/src/tile/prediction.cc",
+        "libgav1/src/tile_scratch_buffer.cc",
         "libgav1/src/tile/tile.cc",
         "libgav1/src/utils/bit_reader.cc",
         "libgav1/src/utils/block_parameters_holder.cc",
         "libgav1/src/utils/constants.cc",
+        "libgav1/src/utils/cpu.cc",
         "libgav1/src/utils/entropy_decoder.cc",
         "libgav1/src/utils/executor.cc",
         "libgav1/src/utils/logging.cc",
@@ -112,6 +137,7 @@
         "libgav1/src/utils/segmentation.cc",
         "libgav1/src/utils/segmentation_map.cc",
         "libgav1/src/utils/threadpool.cc",
+        "libgav1/src/version.cc",
         "libgav1/src/warp_prediction.cc",
         "libgav1/src/yuv_buffer.cc",
     ],
diff --git a/README.version b/README.version
index 5d15f7f..b65b65a 100644
--- a/README.version
+++ b/README.version
@@ -1,11 +1,5 @@
 URL: https://chromium.googlesource.com/codecs/libgav1
-Version: cl/267700628
+Version: v0.16.0
 BugComponent: 324837
 Local Modifications:
-- ab3390a external/libgav1,cosmetics: add license headers
-- backport cl/281117442: Fully use the frame border for reference block.
-- backport cl/289984918: convolve: Use the correct subsampling for ref frames
-- backport cl/289966078: Move initial_display_delay out of OperatingParamet
-- backport cl/290784565: Handle a change of sequence header parameters.
-- backport cl/291222461: Disallow change of sequence header during a frame.
-- backport cl/289910031: obu: Check for size validity in SetTileDataOffset
+None
diff --git a/libgav1/.gitignore b/libgav1/.gitignore
new file mode 100644
index 0000000..87ccf24
--- /dev/null
+++ b/libgav1/.gitignore
@@ -0,0 +1,2 @@
+/build
+/third_party
diff --git a/libgav1/AUTHORS b/libgav1/AUTHORS
new file mode 100644
index 0000000..d92ea0a
--- /dev/null
+++ b/libgav1/AUTHORS
@@ -0,0 +1,6 @@
+# This is the list of libgav1 authors for copyright purposes.
+#
+# This does not necessarily list everyone who has contributed code, since in
+# some cases, their employer may be the copyright holder.  To see the full list
+# of contributors, see the revision history in source control.
+Google LLC
diff --git a/libgav1/CMakeLists.txt b/libgav1/CMakeLists.txt
new file mode 100644
index 0000000..f033bae
--- /dev/null
+++ b/libgav1/CMakeLists.txt
@@ -0,0 +1,124 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# libgav1 requires modern CMake.
+cmake_minimum_required(VERSION 3.7.1 FATAL_ERROR)
+
+# libgav1 requires C++11.
+set(CMAKE_CXX_STANDARD 11)
+set(ABSL_CXX_STANDARD 11)
+
+project(libgav1 CXX)
+
+set(libgav1_root "${CMAKE_CURRENT_SOURCE_DIR}")
+set(libgav1_build "${CMAKE_BINARY_DIR}")
+
+if("${libgav1_root}" STREQUAL "${libgav1_build}")
+  message(
+    FATAL_ERROR
+      "Building from within the libgav1 source tree is not supported.\n"
+      "Hint: Run these commands\n" "$ rm -rf CMakeCache.txt CMakeFiles\n"
+      "$ mkdir -p ../libgav1_build\n" "$ cd ../libgav1_build\n"
+      "And re-run CMake from the libgav1_build directory.")
+endif()
+
+set(libgav1_examples "${libgav1_root}/examples")
+set(libgav1_source "${libgav1_root}/src")
+
+include(FindThreads)
+
+include("${libgav1_examples}/libgav1_examples.cmake")
+include("${libgav1_root}/cmake/libgav1_build_definitions.cmake")
+include("${libgav1_root}/cmake/libgav1_cpu_detection.cmake")
+include("${libgav1_root}/cmake/libgav1_flags.cmake")
+include("${libgav1_root}/cmake/libgav1_helpers.cmake")
+include("${libgav1_root}/cmake/libgav1_install.cmake")
+include("${libgav1_root}/cmake/libgav1_intrinsics.cmake")
+include("${libgav1_root}/cmake/libgav1_options.cmake")
+include("${libgav1_root}/cmake/libgav1_sanitizer.cmake")
+include("${libgav1_root}/cmake/libgav1_targets.cmake")
+include("${libgav1_root}/cmake/libgav1_variables.cmake")
+include("${libgav1_source}/dsp/libgav1_dsp.cmake")
+include("${libgav1_source}/libgav1_decoder.cmake")
+include("${libgav1_source}/utils/libgav1_utils.cmake")
+
+libgav1_option(NAME LIBGAV1_ENABLE_OPTIMIZATIONS HELPSTRING
+               "Enables optimized code." VALUE ON)
+libgav1_option(NAME LIBGAV1_ENABLE_NEON HELPSTRING "Enables neon optimizations."
+               VALUE ON)
+libgav1_option(NAME LIBGAV1_ENABLE_SSE4_1 HELPSTRING
+               "Enables sse4.1 optimizations." VALUE ON)
+libgav1_option(
+  NAME LIBGAV1_VERBOSE HELPSTRING
+  "Enables verbose build system output. Higher numbers are more verbose." VALUE
+  OFF)
+
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE Release)
+endif()
+
+libgav1_optimization_detect()
+libgav1_set_build_definitions()
+libgav1_set_cxx_flags()
+libgav1_configure_sanitizer()
+
+# Supported bit depth.
+libgav1_track_configuration_variable(LIBGAV1_MAX_BITDEPTH)
+
+# C++ and linker flags.
+libgav1_track_configuration_variable(LIBGAV1_CXX_FLAGS)
+libgav1_track_configuration_variable(LIBGAV1_EXE_LINKER_FLAGS)
+
+# Sanitizer integration.
+libgav1_track_configuration_variable(LIBGAV1_SANITIZE)
+
+# Generated source file directory.
+libgav1_track_configuration_variable(LIBGAV1_GENERATED_SOURCES_DIRECTORY)
+
+# Controls use of std::mutex and absl::Mutex in ThreadPool.
+libgav1_track_configuration_variable(LIBGAV1_THREADPOOL_USE_STD_MUTEX)
+
+if(LIBGAV1_VERBOSE)
+  libgav1_dump_cmake_flag_variables()
+  libgav1_dump_tracked_configuration_variables()
+  libgav1_dump_options()
+endif()
+
+set(libgav1_abseil_build "${libgav1_build}/abseil")
+set(libgav1_gtest_build "${libgav1_build}/gtest")
+
+# Compiler/linker flags must be lists, but come in from the environment as
+# strings. Break them up:
+if(NOT "${LIBGAV1_CXX_FLAGS}" STREQUAL "")
+  separate_arguments(LIBGAV1_CXX_FLAGS)
+endif()
+if(NOT "${LIBGAV1_EXE_LINKER_FLAGS}" STREQUAL "")
+  separate_arguments(LIBGAV1_EXE_LINKER_FLAGS)
+endif()
+
+add_subdirectory("${libgav1_root}/third_party/abseil-cpp"
+                 "${libgav1_abseil_build}" EXCLUDE_FROM_ALL)
+
+libgav1_reset_target_lists()
+libgav1_add_dsp_targets()
+libgav1_add_decoder_targets()
+libgav1_add_examples_targets()
+libgav1_add_utils_targets()
+libgav1_setup_install_target()
+
+if(LIBGAV1_VERBOSE)
+  libgav1_dump_cmake_flag_variables()
+  libgav1_dump_tracked_configuration_variables()
+  libgav1_dump_options()
+endif()
diff --git a/libgav1/CONTRIBUTING.md b/libgav1/CONTRIBUTING.md
new file mode 100644
index 0000000..69140ff
--- /dev/null
+++ b/libgav1/CONTRIBUTING.md
@@ -0,0 +1,27 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use a [Gerrit](https://www.gerritcodereview.com) instance hosted at
+https://chromium-review.googlesource.com for this purpose.
+
+## Community Guidelines
+
+This project follows
+[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
diff --git a/libgav1/README.md b/libgav1/README.md
new file mode 100644
index 0000000..b935679
--- /dev/null
+++ b/libgav1/README.md
@@ -0,0 +1,165 @@
+# libgav1 -- an AV1 decoder
+
+libgav1 is a Main profile (0) & High profile (1) compliant AV1 decoder. More
+information on the AV1 video format can be found at
+[aomedia.org](https://aomedia.org).
+
+[TOC]
+
+## Building
+
+### Prerequisites
+
+1.  A C++11 compiler. gcc 6+, clang 7+ or Microsoft Visual Studio 2017+ are
+    recommended.
+
+2.  [CMake >= 3.7.1](https://cmake.org/download/)
+
+3.  [Abseil](https://abseil.io)
+
+    From within the libgav1 directory:
+
+    ```shell
+      $ git clone https://github.com/abseil/abseil-cpp.git third_party/abseil-cpp
+    ```
+
+### Compile
+
+```shell
+  $ mkdir build && cd build
+  $ cmake -G "Unix Makefiles" ..
+  $ make
+```
+
+Configuration options:
+
+*   `LIBGAV1_MAX_BITDEPTH`: defines the maximum supported bitdepth (8, 10;
+    default: 10).
+*   `LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS`: define to a non-zero value to disable
+    [symbol reduction](#symbol-reduction) in an optimized build to keep all
+    versions of dsp functions available. Automatically defined in
+    `src/dsp/dsp.h` if unset.
+*   `LIBGAV1_ENABLE_NEON`: define to a non-zero value to enable NEON
+    optimizations. Automatically defined in `src/dsp/dsp.h` if unset.
+*   `LIBGAV1_ENABLE_SSE4_1`: define to a non-zero value to enable sse4.1
+    optimizations. Automatically defined in `src/dsp/dsp.h` if unset.
+*   `LIBGAV1_ENABLE_LOGGING`: define to 0/1 to control debug logging.
+    Automatically defined in `src/utils/logging.h` if unset.
+*   `LIBGAV1_EXAMPLES_ENABLE_LOGGING`: define to 0/1 to control error logging in
+    the examples. Automatically defined in `examples/logging.h` if unset.
+*   `LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK`: define to 1 to enable transform
+    coefficient range checks.
+*   `LIBGAV1_LOG_LEVEL`: controls the maximum allowed log level, see `enum
+    LogSeverity` in `src/utils/logging.h`. Automatically defined in
+    `src/utils/logging.cc` if unset.
+*   `LIBGAV1_THREADPOOL_USE_STD_MUTEX`: controls use of std::mutex and
+    absl::Mutex in ThreadPool. Defining this to 1 will remove any Abseil
+    dependency from the core library. Automatically defined in
+    `src/utils/threadpool.h` if unset.
+*   `LIBGAV1_MAX_THREADS`: sets the number of threads that the library is
+    allowed to create. Has to be an integer > 0. Otherwise this is ignored.
+    The default value is 128.
+*   `LIBGAV1_FRAME_PARALLEL_THRESHOLD_MULTIPLIER`: the threshold multiplier that
+    is used to determine when to use frame parallel decoding. Frame parallel
+    decoding will be used if |threads| > |tile_count| * this multiplier. Has to
+    be an integer > 0. The default value is 4. This is an advanced setting
+    intended for testing purposes.
+
+For additional options see:
+
+```shell
+  $ cmake .. -LH
+```
+
+## Testing
+
+*   `gav1_decode` can be used to decode IVF files, see `gav1_decode --help` for
+    options. Note: tools like [FFmpeg](https://ffmpeg.org) can be used to
+    convert other container formats to IVF.
+
+## Development
+
+### Contributing
+
+See [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to submit patches.
+
+### Style
+
+libgav1 follows the
+[Google C++ style guide](https://google.github.io/styleguide/cppguide.html) with
+formatting enforced by `clang-format`.
+
+### Comments
+
+Comments of the form '`// X.Y(.Z).`', '`Section X.Y(.Z).`' or '`... in the
+spec`' reference the relevant section(s) in the
+[AV1 specification](http://aomediacodec.github.io/av1-spec/av1-spec.pdf).
+
+### DSP structure
+
+*   `src/dsp/dsp.cc` defines the main entry point: `libgav1::dsp::DspInit()`.
+    This handles cpu-detection and initializing each logical unit which populate
+    `libgav1::dsp::Dsp` function tables.
+*   `src/dsp/dsp.h` contains function and type definitions for all logical units
+    (e.g., intra-predictors)
+*   `src/utils/cpu.h` contains definitions for cpu-detection
+*   base implementations are located in `src/dsp/*.{h,cc}` with platform
+    specific optimizations in sub-folders
+*   unit tests define `DISABLED_Speed` test(s) to allow timing of individual
+    functions
+
+#### Symbol reduction
+
+Based on the build configuration unneeded lesser optimizations are removed using
+a hierarchical include and define system. Each logical unit in `src/dsp` should
+include all platform specific headers in descending order to allow higher level
+optimizations to disable lower level ones. See `src/dsp/loop_filter.h` for an
+example.
+
+Each function receives a new define which can be checked in platform specific
+headers. The format is: `LIBGAV1_<Dsp-table>_FunctionName` or
+`LIBGAV1_<Dsp-table>_[sub-table-index1][...-indexN]`, e.g.,
+`LIBGAV1_Dsp8bpp_AverageBlend`,
+`LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc`. The Dsp-table name is of
+the form `Dsp<bitdepth>bpp` e.g. `Dsp10bpp` for bitdepth == 10 (bpp stands for
+bits per pixel). The indices correspond to enum values used as lookups with
+leading 'k' removed. Platform specific headers then should first check if the
+symbol is defined and if not set the value to the corresponding
+`LIBGAV1_CPU_<arch>` value from `src/utils/cpu.h`.
+
+```
+  #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc
+  #define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_SSE4_1
+  #endif
+```
+
+Within each module the code should check if the symbol is defined to its
+specific architecture or forced via `LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS` before
+defining the function. The `DSP_ENABLED_(8|10)BPP_*` macros are available to
+simplify this check for optimized code.
+
+```
+  #if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorDc)
+  ...
+
+  // In unoptimized code use the following structure; there's no equivalent
+  // define for LIBGAV1_CPU_C as it would require duplicating the function
+  // defines used in optimized code for only a small benefit to this
+  // boilerplate.
+  #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  ...
+  #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcFill
+  ...
+```
+
+## Bugs
+
+Please report all bugs to the issue tracker:
+https://issuetracker.google.com/issues/new?component=750480&template=1355007
+
+## Discussion
+
+Email: gav1-devel@googlegroups.com
+
+Web: https://groups.google.com/forum/#!forum/gav1-devel
diff --git a/libgav1/cmake/libgav1-config.cmake.template b/libgav1/cmake/libgav1-config.cmake.template
new file mode 100644
index 0000000..dc253d3
--- /dev/null
+++ b/libgav1/cmake/libgav1-config.cmake.template
@@ -0,0 +1,2 @@
+set(LIBGAV1_INCLUDE_DIRS "@LIBGAV1_INCLUDE_DIRS@")
+set(LIBGAV1_LIBRARIES "gav1")
diff --git a/libgav1/cmake/libgav1.pc.template b/libgav1/cmake/libgav1.pc.template
new file mode 100644
index 0000000..c571a43
--- /dev/null
+++ b/libgav1/cmake/libgav1.pc.template
@@ -0,0 +1,11 @@
+prefix=@prefix@
+exec_prefix=@exec_prefix@
+libdir=@libdir@
+includedir=@includedir@
+
+Name: @PROJECT_NAME@
+Description: AV1 decoder library (@LIBGAV1_MAX_BITDEPTH@-bit).
+Version: @LIBGAV1_VERSION@
+Cflags: -I${includedir}
+Libs: -L${libdir} -lgav1
+Libs.private: @CMAKE_THREAD_LIBS_INIT@
diff --git a/libgav1/cmake/libgav1_build_definitions.cmake b/libgav1/cmake/libgav1_build_definitions.cmake
new file mode 100644
index 0000000..930d8f5
--- /dev/null
+++ b/libgav1/cmake/libgav1_build_definitions.cmake
@@ -0,0 +1,149 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_LIBGAV1_BUILD_DEFINITIONS_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_LIBGAV1_BUILD_DEFINITIONS_CMAKE_
+set(LIBGAV1_CMAKE_LIBGAV1_BUILD_DEFINITIONS_CMAKE_ 1)
+
+macro(libgav1_set_build_definitions)
+  string(TOLOWER "${CMAKE_BUILD_TYPE}" build_type_lowercase)
+
+  libgav1_load_version_info()
+  set(LIBGAV1_SOVERSION 0)
+
+  list(APPEND libgav1_include_paths "${libgav1_root}" "${libgav1_root}/src"
+              "${libgav1_build}" "${libgav1_root}/third_party/abseil-cpp")
+  list(APPEND libgav1_gtest_include_paths
+              "third_party/googletest/googlemock/include"
+              "third_party/googletest/googletest/include"
+              "third_party/googletest/googletest")
+  list(APPEND libgav1_test_include_paths ${libgav1_include_paths}
+              ${libgav1_gtest_include_paths})
+  list(APPEND libgav1_defines "LIBGAV1_CMAKE=1"
+              "LIBGAV1_FLAGS_SRCDIR=\"${libgav1_root}\""
+              "LIBGAV1_FLAGS_TMPDIR=\"/tmp\"")
+
+  if(MSVC OR WIN32)
+    list(APPEND libgav1_defines "_CRT_SECURE_NO_DEPRECATE=1" "NOMINMAX=1")
+  endif()
+
+  if(ANDROID)
+    if(CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a")
+      set(CMAKE_ANDROID_ARM_MODE ON)
+    endif()
+
+    if(build_type_lowercase MATCHES "rel")
+      list(APPEND libgav1_base_cxx_flags "-fno-stack-protector")
+    endif()
+  endif()
+
+  list(APPEND libgav1_base_cxx_flags "-Wall" "-Wextra" "-Wmissing-declarations"
+              "-Wno-sign-compare" "-fvisibility=hidden"
+              "-fvisibility-inlines-hidden")
+
+  if(BUILD_SHARED_LIBS)
+    set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+    set(libgav1_dependency libgav1_shared)
+  else()
+    set(libgav1_dependency libgav1_static)
+  endif()
+
+  list(APPEND libgav1_clang_cxx_flags "-Wextra-semi" "-Wmissing-prototypes"
+              "-Wshorten-64-to-32")
+
+  if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+    if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6")
+      # Quiet warnings in copy-list-initialization where {} elision has always
+      # been allowed.
+      list(APPEND libgav1_clang_cxx_flags "-Wno-missing-braces")
+    endif()
+    if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8)
+      list(APPEND libgav1_clang_cxx_flags "-Wextra-semi-stmt")
+    endif()
+  endif()
+
+  if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+    if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "7")
+      # Quiet warnings due to potential snprintf() truncation in threadpool.cc.
+      list(APPEND libgav1_base_cxx_flags "-Wno-format-truncation")
+
+      if(CMAKE_SYSTEM_PROCESSOR STREQUAL "armv7")
+        # Quiet gcc 6 vs 7 abi warnings:
+        # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=77728
+        list(APPEND libgav1_base_cxx_flags "-Wno-psabi")
+        list(APPEND ABSL_GCC_FLAGS "-Wno-psabi")
+      endif()
+    endif()
+  endif()
+
+  if(build_type_lowercase MATCHES "rel")
+    # TODO(tomfinegan): this value is only a concern for the core library and
+    # can be made smaller if the test targets are avoided.
+    list(APPEND libgav1_base_cxx_flags "-Wstack-usage=196608")
+  endif()
+
+  list(APPEND libgav1_msvc_cxx_flags
+              # Warning level 3.
+              "/W3"
+              # Disable warning C4018:
+              # '<comparison operator>' signed/unsigned mismatch
+              "/wd4018"
+              # Disable warning C4244:
+              # 'argument': conversion from '<double/int>' to
+              # '<float/smaller int type>', possible loss of data
+              "/wd4244"
+              # Disable warning C4267:
+              # '=': conversion from '<double/int>' to
+              # '<float/smaller int type>', possible loss of data
+              "/wd4267"
+              # Disable warning C4309:
+              # 'argument': truncation of constant value
+              "/wd4309"
+              # Disable warning C4551:
+              # function call missing argument list
+              "/wd4551")
+
+  if(BUILD_SHARED_LIBS)
+    list(APPEND libgav1_msvc_cxx_flags
+                # Disable warning C4251:
+                # 'libgav1::DecoderImpl class member' needs to have
+                # dll-interface to be used by clients of class
+                # 'libgav1::Decoder'.
+                "/wd4251")
+  endif()
+
+  if(NOT LIBGAV1_MAX_BITDEPTH)
+    set(LIBGAV1_MAX_BITDEPTH 10)
+  elseif(NOT LIBGAV1_MAX_BITDEPTH EQUAL 8 AND NOT LIBGAV1_MAX_BITDEPTH EQUAL 10)
+    libgav1_die("LIBGAV1_MAX_BITDEPTH must be 8 or 10.")
+  endif()
+
+  list(APPEND libgav1_defines "LIBGAV1_MAX_BITDEPTH=${LIBGAV1_MAX_BITDEPTH}")
+
+  if(DEFINED LIBGAV1_THREADPOOL_USE_STD_MUTEX)
+    if(NOT LIBGAV1_THREADPOOL_USE_STD_MUTEX EQUAL 0
+       AND NOT LIBGAV1_THREADPOOL_USE_STD_MUTEX EQUAL 1)
+      libgav1_die("LIBGAV1_THREADPOOL_USE_STD_MUTEX must be 0 or 1.")
+    endif()
+
+    list(APPEND libgav1_defines
+         "LIBGAV1_THREADPOOL_USE_STD_MUTEX=${LIBGAV1_THREADPOOL_USE_STD_MUTEX}")
+  endif()
+
+  # Source file names ending in these suffixes will have the appropriate
+  # compiler flags added to their compile commands to enable intrinsics.
+  set(libgav1_neon_source_file_suffix "neon.cc")
+  set(libgav1_sse4_source_file_suffix "sse4.cc")
+endmacro()
diff --git a/libgav1/cmake/libgav1_cpu_detection.cmake b/libgav1/cmake/libgav1_cpu_detection.cmake
new file mode 100644
index 0000000..6972d34
--- /dev/null
+++ b/libgav1/cmake/libgav1_cpu_detection.cmake
@@ -0,0 +1,42 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_LIBGAV1_CPU_DETECTION_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_LIBGAV1_CPU_DETECTION_CMAKE_
+set(LIBGAV1_CMAKE_LIBGAV1_CPU_DETECTION_CMAKE_ 1)
+
+# Detect optimizations available for the current target CPU.
+macro(libgav1_optimization_detect)
+  if(LIBGAV1_ENABLE_OPTIMIZATIONS)
+    string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" cpu_lowercase)
+    if(cpu_lowercase MATCHES "^arm|^aarch64")
+      set(libgav1_have_neon ON)
+    elseif(cpu_lowercase MATCHES "^x86|amd64")
+      set(libgav1_have_sse4 ON)
+    endif()
+  endif()
+
+  if(libgav1_have_neon AND LIBGAV1_ENABLE_NEON)
+    list(APPEND libgav1_defines "LIBGAV1_ENABLE_NEON=1")
+  else()
+    list(APPEND libgav1_defines "LIBGAV1_ENABLE_NEON=0")
+  endif()
+
+  if(libgav1_have_sse4 AND LIBGAV1_ENABLE_SSE4_1)
+    list(APPEND libgav1_defines "LIBGAV1_ENABLE_SSE4_1=1")
+  else()
+    list(APPEND libgav1_defines "LIBGAV1_ENABLE_SSE4_1=0")
+  endif()
+endmacro()
diff --git a/libgav1/cmake/libgav1_flags.cmake b/libgav1/cmake/libgav1_flags.cmake
new file mode 100644
index 0000000..0b8df60
--- /dev/null
+++ b/libgav1/cmake/libgav1_flags.cmake
@@ -0,0 +1,245 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_LIBGAV1_FLAGS_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_LIBGAV1_FLAGS_CMAKE_
+set(LIBGAV1_CMAKE_LIBGAV1_FLAGS_CMAKE_ 1)
+
+include(CheckCXXCompilerFlag)
+include(CheckCXXSourceCompiles)
+
+# Adds compiler flags specified by FLAGS to the sources specified by SOURCES:
+#
+# libgav1_set_compiler_flags_for_sources(SOURCES <sources> FLAGS <flags>)
+macro(libgav1_set_compiler_flags_for_sources)
+  unset(compiler_SOURCES)
+  unset(compiler_FLAGS)
+  unset(optional_args)
+  unset(single_value_args)
+  set(multi_value_args SOURCES FLAGS)
+  cmake_parse_arguments(compiler "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+
+  if(NOT (compiler_SOURCES AND compiler_FLAGS))
+    libgav1_die("libgav1_set_compiler_flags_for_sources: SOURCES and "
+                "FLAGS required.")
+  endif()
+
+  set_source_files_properties(${compiler_SOURCES} PROPERTIES COMPILE_FLAGS
+                              ${compiler_FLAGS})
+
+  if(LIBGAV1_VERBOSE GREATER 1)
+    foreach(source ${compiler_SOURCES})
+      foreach(flag ${compiler_FLAGS})
+        message("libgav1_set_compiler_flags_for_sources: source:${source} "
+                "flag:${flag}")
+      endforeach()
+    endforeach()
+  endif()
+endmacro()
+
+# Tests compiler flags stored in list(s) specified by FLAG_LIST_VAR_NAMES, adds
+# flags to $LIBGAV1_CXX_FLAGS when tests pass. Terminates configuration if
+# FLAG_REQUIRED is specified and any flag check fails.
+#
+# ~~~
+# libgav1_test_cxx_flag(<FLAG_LIST_VAR_NAMES <flag list variable(s)>>
+#                       [FLAG_REQUIRED])
+# ~~~
+macro(libgav1_test_cxx_flag)
+  unset(cxx_test_FLAG_LIST_VAR_NAMES)
+  unset(cxx_test_FLAG_REQUIRED)
+  unset(single_value_args)
+  set(optional_args FLAG_REQUIRED)
+  set(multi_value_args FLAG_LIST_VAR_NAMES)
+  cmake_parse_arguments(cxx_test "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+
+  if(NOT cxx_test_FLAG_LIST_VAR_NAMES)
+    libgav1_die("libgav1_test_cxx_flag: FLAG_LIST_VAR_NAMES required")
+  endif()
+
+  unset(cxx_flags)
+  foreach(list_var ${cxx_test_FLAG_LIST_VAR_NAMES})
+    if(LIBGAV1_VERBOSE)
+      message("libgav1_test_cxx_flag: adding ${list_var} to cxx_flags")
+    endif()
+    list(APPEND cxx_flags ${${list_var}})
+  endforeach()
+
+  if(LIBGAV1_VERBOSE)
+    message("CXX test: all flags: ${cxx_flags}")
+  endif()
+
+  unset(all_cxx_flags)
+  list(APPEND all_cxx_flags ${LIBGAV1_CXX_FLAGS} ${cxx_flags})
+
+  # Turn off output from check_cxx_source_compiles. Print status directly
+  # instead since the logging messages from check_cxx_source_compiles can be
+  # quite confusing.
+  set(CMAKE_REQUIRED_QUIET TRUE)
+
+  # Run the actual compile test.
+  unset(libgav1_all_cxx_flags_pass CACHE)
+  message("--- Running combined CXX flags test, flags: ${all_cxx_flags}")
+  check_cxx_compiler_flag("${all_cxx_flags}" libgav1_all_cxx_flags_pass)
+
+  if(cxx_test_FLAG_REQUIRED AND NOT libgav1_all_cxx_flags_pass)
+    libgav1_die("Flag test failed for required flag(s): "
+                "${all_cxx_flags} and FLAG_REQUIRED specified.")
+  endif()
+
+  if(libgav1_all_cxx_flags_pass)
+    # Test passed: update the global flag list used by the libgav1 target
+    # creation wrappers.
+    set(LIBGAV1_CXX_FLAGS ${cxx_flags})
+    list(REMOVE_DUPLICATES LIBGAV1_CXX_FLAGS)
+
+    if(LIBGAV1_VERBOSE)
+      message("LIBGAV1_CXX_FLAGS=${LIBGAV1_CXX_FLAGS}")
+    endif()
+
+    message("--- Passed combined CXX flags test")
+  else()
+    message("--- Failed combined CXX flags test, testing flags individually.")
+
+    if(cxx_flags)
+      message("--- Testing flags from $cxx_flags: " "${cxx_flags}")
+      foreach(cxx_flag ${cxx_flags})
+        unset(cxx_flag_test_passed CACHE)
+        message("--- Testing flag: ${cxx_flag}")
+        check_cxx_compiler_flag("${cxx_flag}" cxx_flag_test_passed)
+
+        if(cxx_flag_test_passed)
+          message("--- Passed test for ${cxx_flag}")
+        else()
+          list(REMOVE_ITEM cxx_flags ${cxx_flag})
+          message("--- Failed test for ${cxx_flag}, flag removed.")
+        endif()
+      endforeach()
+
+      set(LIBGAV1_CXX_FLAGS ${cxx_flags})
+    endif()
+  endif()
+
+  if(LIBGAV1_CXX_FLAGS)
+    list(REMOVE_DUPLICATES LIBGAV1_CXX_FLAGS)
+  endif()
+endmacro()
+
+# Tests executable linker flags stored in list specified by FLAG_LIST_VAR_NAME,
+# adds flags to $LIBGAV1_EXE_LINKER_FLAGS when test passes. Terminates
+# configuration when flag check fails. libgav1_set_cxx_flags() must be called
+# before calling this macro because it assumes $LIBGAV1_CXX_FLAGS contains only
+# valid CXX flags.
+#
+# libgav1_test_exe_linker_flag(<FLAG_LIST_VAR_NAME <flag list variable)>)
+macro(libgav1_test_exe_linker_flag)
+  unset(link_FLAG_LIST_VAR_NAME)
+  unset(optional_args)
+  unset(multi_value_args)
+  set(single_value_args FLAG_LIST_VAR_NAME)
+  cmake_parse_arguments(link "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+
+  if(NOT link_FLAG_LIST_VAR_NAME)
+    libgav1_die("libgav1_test_link_flag: FLAG_LIST_VAR_NAME required")
+  endif()
+
+  libgav1_set_and_stringify(DEST linker_flags SOURCE_VARS
+                            ${link_FLAG_LIST_VAR_NAME})
+
+  if(LIBGAV1_VERBOSE)
+    message("EXE LINKER test: all flags: ${linker_flags}")
+  endif()
+
+  # Tests of $LIBGAV1_CXX_FLAGS have already passed. Include them with the
+  # linker test.
+  libgav1_set_and_stringify(DEST CMAKE_REQUIRED_FLAGS SOURCE_VARS
+                            LIBGAV1_CXX_FLAGS)
+
+  # Cache the global exe linker flags.
+  if(CMAKE_EXE_LINKER_FLAGS)
+    set(cached_CMAKE_EXE_LINKER_FLAGS ${CMAKE_EXE_LINKER_FLAGS})
+    libgav1_set_and_stringify(DEST CMAKE_EXE_LINKER_FLAGS SOURCE
+                              ${linker_flags})
+  endif()
+
+  libgav1_set_and_stringify(DEST CMAKE_EXE_LINKER_FLAGS SOURCE ${linker_flags}
+                            ${CMAKE_EXE_LINKER_FLAGS})
+
+  # Turn off output from check_cxx_source_compiles. Print status directly
+  # instead since the logging messages from check_cxx_source_compiles can be
+  # quite confusing.
+  set(CMAKE_REQUIRED_QUIET TRUE)
+
+  message("--- Running EXE LINKER test for flags: ${linker_flags}")
+
+  unset(linker_flag_test_passed CACHE)
+  set(libgav1_cxx_main "\nint main() { return 0; }")
+  check_cxx_source_compiles("${libgav1_cxx_main}" linker_flag_test_passed)
+
+  if(NOT linker_flag_test_passed)
+    libgav1_die("EXE LINKER test failed.")
+  endif()
+
+  message("--- Passed EXE LINKER flag test.")
+
+  # Restore cached global exe linker flags.
+  if(cached_CMAKE_EXE_LINKER_FLAGS)
+    set(CMAKE_EXE_LINKER_FLAGS cached_CMAKE_EXE_LINKER_FLAGS)
+  else()
+    unset(CMAKE_EXE_LINKER_FLAGS)
+  endif()
+endmacro()
+
+# Runs the libgav1 compiler tests. This macro builds up the list of list var(s)
+# that is passed to libgav1_test_cxx_flag().
+#
+# Note: libgav1_set_build_definitions() must be called before this macro.
+macro(libgav1_set_cxx_flags)
+  unset(cxx_flag_lists)
+
+  if(CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU")
+    list(APPEND cxx_flag_lists libgav1_base_cxx_flags)
+  endif()
+
+  # Append clang flags after the base set to allow -Wno* overrides to take
+  # effect. Some of the base flags may enable a large set of warnings, e.g.,
+  # -Wall.
+  if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+    list(APPEND cxx_flag_lists libgav1_clang_cxx_flags)
+  endif()
+
+  if(MSVC)
+    list(APPEND cxx_flag_lists libgav1_msvc_cxx_flags)
+  endif()
+
+  if(LIBGAV1_VERBOSE)
+    if(cxx_flag_lists)
+      libgav1_set_and_stringify(DEST cxx_flags SOURCE_VARS ${cxx_flag_lists})
+      message("libgav1_set_cxx_flags: internal CXX flags: ${cxx_flags}")
+    endif()
+  endif()
+
+  if(LIBGAV1_CXX_FLAGS)
+    list(APPEND cxx_flag_lists LIBGAV1_CXX_FLAGS)
+    if(LIBGAV1_VERBOSE)
+      message("libgav1_set_cxx_flags: user CXX flags: ${LIBGAV1_CXX_FLAGS}")
+    endif()
+  endif()
+
+  libgav1_test_cxx_flag(FLAG_LIST_VAR_NAMES ${cxx_flag_lists})
+endmacro()
diff --git a/libgav1/cmake/libgav1_helpers.cmake b/libgav1/cmake/libgav1_helpers.cmake
new file mode 100644
index 0000000..76d8d67
--- /dev/null
+++ b/libgav1/cmake/libgav1_helpers.cmake
@@ -0,0 +1,134 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_LIBGAV1_HELPERS_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_LIBGAV1_HELPERS_CMAKE_
+set(LIBGAV1_CMAKE_LIBGAV1_HELPERS_CMAKE_ 1)
+
+# Kills build generation using message(FATAL_ERROR) and outputs all data passed
+# to the console via use of $ARGN.
+macro(libgav1_die)
+  message(FATAL_ERROR ${ARGN})
+endmacro()
+
+# Converts semi-colon delimited list variable(s) to string. Output is written to
+# variable supplied via the DEST parameter. Input is from an expanded variable
+# referenced by SOURCE and/or variable(s) referenced by SOURCE_VARS.
+macro(libgav1_set_and_stringify)
+  set(optional_args)
+  set(single_value_args DEST SOURCE_VAR)
+  set(multi_value_args SOURCE SOURCE_VARS)
+  cmake_parse_arguments(sas "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+
+  if(NOT sas_DEST OR NOT (sas_SOURCE OR sas_SOURCE_VARS))
+    libgav1_die("libgav1_set_and_stringify: DEST and at least one of SOURCE "
+                "SOURCE_VARS required.")
+  endif()
+
+  unset(${sas_DEST})
+
+  if(sas_SOURCE)
+    # $sas_SOURCE is one or more expanded variables, just copy the values to
+    # $sas_DEST.
+    set(${sas_DEST} "${sas_SOURCE}")
+  endif()
+
+  if(sas_SOURCE_VARS)
+    # $sas_SOURCE_VARS is one or more variable names. Each iteration expands a
+    # variable and appends it to $sas_DEST.
+    foreach(source_var ${sas_SOURCE_VARS})
+      set(${sas_DEST} "${${sas_DEST}} ${${source_var}}")
+    endforeach()
+
+    # Because $sas_DEST can be empty when entering this scope leading whitespace
+    # can be introduced to $sas_DEST on the first iteration of the above loop.
+    # Remove it:
+    string(STRIP "${${sas_DEST}}" ${sas_DEST})
+  endif()
+
+  # Lists in CMake are simply semicolon delimited strings, so stringification is
+  # just a find and replace of the semicolon.
+  string(REPLACE ";" " " ${sas_DEST} "${${sas_DEST}}")
+
+  if(LIBGAV1_VERBOSE GREATER 1)
+    message("libgav1_set_and_stringify: ${sas_DEST}=${${sas_DEST}}")
+  endif()
+endmacro()
+
+# Creates a dummy source file in $LIBGAV1_GENERATED_SOURCES_DIRECTORY and adds
+# it to the specified target. Optionally adds its path to a list variable.
+#
+# libgav1_create_dummy_source_file(<TARGET <target> BASENAME <basename of file>>
+# [LISTVAR <list variable>])
+macro(libgav1_create_dummy_source_file)
+  set(optional_args)
+  set(single_value_args TARGET BASENAME LISTVAR)
+  set(multi_value_args)
+  cmake_parse_arguments(cdsf "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+
+  if(NOT cdsf_TARGET OR NOT cdsf_BASENAME)
+    libgav1_die(
+      "libgav1_create_dummy_source_file: TARGET and BASENAME required.")
+  endif()
+
+  if(NOT LIBGAV1_GENERATED_SOURCES_DIRECTORY)
+    set(LIBGAV1_GENERATED_SOURCES_DIRECTORY "${libgav1_build}/gen_src")
+  endif()
+
+  set(dummy_source_dir "${LIBGAV1_GENERATED_SOURCES_DIRECTORY}")
+  set(dummy_source_file
+      "${dummy_source_dir}/libgav1_${cdsf_TARGET}_${cdsf_BASENAME}.cc")
+  set(dummy_source_code
+      "// Generated file. DO NOT EDIT!\n"
+      "// C++ source file created for target ${cdsf_TARGET}. \n"
+      "void libgav1_${cdsf_TARGET}_${cdsf_BASENAME}_dummy_function(void);\n"
+      "void libgav1_${cdsf_TARGET}_${cdsf_BASENAME}_dummy_function(void) {}\n")
+  file(WRITE "${dummy_source_file}" "${dummy_source_code}")
+
+  target_sources(${cdsf_TARGET} PRIVATE ${dummy_source_file})
+
+  if(cdsf_LISTVAR)
+    list(APPEND ${cdsf_LISTVAR} "${dummy_source_file}")
+  endif()
+endmacro()
+
+# Loads the version components from $libgav1_source/gav1/version.h and sets the
+# corresponding CMake variables:
+# - LIBGAV1_MAJOR_VERSION
+# - LIBGAV1_MINOR_VERSION
+# - LIBGAV1_PATCH_VERSION
+# - LIBGAV1_VERSION, which is:
+#   - $LIBGAV1_MAJOR_VERSION.$LIBGAV1_MINOR_VERSION.$LIBGAV1_PATCH_VERSION
+macro(libgav1_load_version_info)
+  file(STRINGS "${libgav1_source}/gav1/version.h" version_file_strings)
+  foreach(str ${version_file_strings})
+    if(str MATCHES "#define LIBGAV1_")
+      if(str MATCHES "#define LIBGAV1_MAJOR_VERSION ")
+        string(REPLACE "#define LIBGAV1_MAJOR_VERSION " "" LIBGAV1_MAJOR_VERSION
+                       "${str}")
+      elseif(str MATCHES "#define LIBGAV1_MINOR_VERSION ")
+        string(REPLACE "#define LIBGAV1_MINOR_VERSION " "" LIBGAV1_MINOR_VERSION
+                       "${str}")
+      elseif(str MATCHES "#define LIBGAV1_PATCH_VERSION ")
+        string(REPLACE "#define LIBGAV1_PATCH_VERSION " "" LIBGAV1_PATCH_VERSION
+                       "${str}")
+      endif()
+    endif()
+  endforeach()
+  set(LIBGAV1_VERSION "${LIBGAV1_MAJOR_VERSION}.${LIBGAV1_MINOR_VERSION}")
+  set(LIBGAV1_VERSION "${LIBGAV1_VERSION}.${LIBGAV1_PATCH_VERSION}")
+endmacro()
diff --git a/libgav1/cmake/libgav1_install.cmake b/libgav1/cmake/libgav1_install.cmake
new file mode 100644
index 0000000..b7f6006
--- /dev/null
+++ b/libgav1/cmake/libgav1_install.cmake
@@ -0,0 +1,60 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_LIBGAV1_INSTALL_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_LIBGAV1_INSTALL_CMAKE_
+set(LIBGAV1_CMAKE_LIBGAV1_INSTALL_CMAKE_ 1)
+
+# Sets up the Libgav1 install targets. Must be called after the static library
+# target is created.
+macro(libgav1_setup_install_target)
+  if(NOT (MSVC OR XCODE))
+    include(GNUInstallDirs)
+
+    # pkg-config: libgav1.pc
+    set(prefix "${CMAKE_INSTALL_PREFIX}")
+    set(exec_prefix "\${prefix}")
+    set(libdir "\${prefix}/${CMAKE_INSTALL_LIBDIR}")
+    set(includedir "\${prefix}/${CMAKE_INSTALL_INCLUDEDIR}")
+    set(libgav1_lib_name "libgav1")
+
+    configure_file("${libgav1_root}/cmake/libgav1.pc.template"
+                   "${libgav1_build}/libgav1.pc" @ONLY NEWLINE_STYLE UNIX)
+    install(FILES "${libgav1_build}/libgav1.pc"
+            DESTINATION "${prefix}/${CMAKE_INSTALL_LIBDIR}/pkgconfig")
+
+    # CMake config: libgav1-config.cmake
+    set(LIBGAV1_INCLUDE_DIRS "${prefix}/${CMAKE_INSTALL_INCLUDEDIR}")
+    configure_file("${libgav1_root}/cmake/libgav1-config.cmake.template"
+                   "${libgav1_build}/libgav1-config.cmake" @ONLY
+                   NEWLINE_STYLE UNIX)
+    install(
+      FILES "${libgav1_build}/libgav1-config.cmake"
+      DESTINATION "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_DATAROOTDIR}/cmake")
+
+    install(
+      FILES ${libgav1_api_includes}
+      DESTINATION "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/gav1")
+
+    install(TARGETS gav1_decode DESTINATION
+                    "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_BINDIR}")
+    install(TARGETS libgav1_static DESTINATION
+                    "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}")
+    if(BUILD_SHARED_LIBS)
+      install(TARGETS libgav1_shared DESTINATION
+                      "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}")
+    endif()
+  endif()
+endmacro()
diff --git a/libgav1/cmake/libgav1_intrinsics.cmake b/libgav1/cmake/libgav1_intrinsics.cmake
new file mode 100644
index 0000000..039ef35
--- /dev/null
+++ b/libgav1/cmake/libgav1_intrinsics.cmake
@@ -0,0 +1,110 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_LIBGAV1_INTRINSICS_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_LIBGAV1_INTRINSICS_CMAKE_
+set(LIBGAV1_CMAKE_LIBGAV1_INTRINSICS_CMAKE_ 1)
+
+# Returns the compiler flag for the SIMD intrinsics suffix specified by the
+# SUFFIX argument via the variable specified by the VARIABLE argument:
+# libgav1_get_intrinsics_flag_for_suffix(SUFFIX <suffix> VARIABLE <var name>)
+macro(libgav1_get_intrinsics_flag_for_suffix)
+  unset(intrinsics_SUFFIX)
+  unset(intrinsics_VARIABLE)
+  unset(optional_args)
+  unset(multi_value_args)
+  set(single_value_args SUFFIX VARIABLE)
+  cmake_parse_arguments(intrinsics "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+
+  if(NOT (intrinsics_SUFFIX AND intrinsics_VARIABLE))
+    message(FATAL_ERROR "libgav1_get_intrinsics_flag_for_suffix: SUFFIX and "
+                        "VARIABLE required.")
+  endif()
+
+  if(intrinsics_SUFFIX MATCHES "neon")
+    if(NOT MSVC)
+      set(${intrinsics_VARIABLE} "${LIBGAV1_NEON_INTRINSICS_FLAG}")
+    endif()
+  elseif(intrinsics_SUFFIX MATCHES "sse4")
+    if(NOT MSVC)
+      set(${intrinsics_VARIABLE} "-msse4.1")
+    endif()
+  else()
+    message(FATAL_ERROR "libgav1_get_intrinsics_flag_for_suffix: Unknown "
+                        "instrinics suffix: ${intrinsics_SUFFIX}")
+  endif()
+
+  if(LIBGAV1_VERBOSE GREATER 1)
+    message("libgav1_get_intrinsics_flag_for_suffix: "
+            "suffix:${intrinsics_SUFFIX} flag:${${intrinsics_VARIABLE}}")
+  endif()
+endmacro()
+
+# Processes source files specified by SOURCES and adds intrinsics flags as
+# necessary: libgav1_process_intrinsics_sources(SOURCES <sources>)
+#
+# Detects requirement for intrinsics flags using source file name suffix.
+# Currently supports only SSE4.1.
+macro(libgav1_process_intrinsics_sources)
+  unset(arg_TARGET)
+  unset(arg_SOURCES)
+  unset(optional_args)
+  set(single_value_args TARGET)
+  set(multi_value_args SOURCES)
+  cmake_parse_arguments(arg "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+  if(NOT (arg_TARGET AND arg_SOURCES))
+    message(FATAL_ERROR "libgav1_process_intrinsics_sources: TARGET and "
+                        "SOURCES required.")
+  endif()
+
+  if(LIBGAV1_ENABLE_SSE4_1 AND libgav1_have_sse4)
+    unset(sse4_sources)
+    list(APPEND sse4_sources ${arg_SOURCES})
+
+    list(FILTER sse4_sources INCLUDE REGEX
+         "${libgav1_sse4_source_file_suffix}$")
+
+    if(sse4_sources)
+      unset(sse4_flags)
+      libgav1_get_intrinsics_flag_for_suffix(SUFFIX
+                                             ${libgav1_sse4_source_file_suffix}
+                                             VARIABLE sse4_flags)
+      if(sse4_flags)
+        libgav1_set_compiler_flags_for_sources(SOURCES ${sse4_sources} FLAGS
+                                               ${sse4_flags})
+      endif()
+    endif()
+  endif()
+
+  if(LIBGAV1_ENABLE_NEON AND libgav1_have_neon)
+    unset(neon_sources)
+    list(APPEND neon_sources ${arg_SOURCES})
+    list(FILTER neon_sources INCLUDE REGEX
+         "${libgav1_neon_source_file_suffix}$")
+
+    if(neon_sources AND LIBGAV1_NEON_INTRINSICS_FLAG)
+      unset(neon_flags)
+      libgav1_get_intrinsics_flag_for_suffix(SUFFIX
+                                             ${libgav1_neon_source_file_suffix}
+                                             VARIABLE neon_flags)
+      if(neon_flags)
+        libgav1_set_compiler_flags_for_sources(SOURCES ${neon_sources} FLAGS
+                                               ${neon_flags})
+      endif()
+    endif()
+  endif()
+endmacro()
diff --git a/libgav1/cmake/libgav1_options.cmake b/libgav1/cmake/libgav1_options.cmake
new file mode 100644
index 0000000..6327bee
--- /dev/null
+++ b/libgav1/cmake/libgav1_options.cmake
@@ -0,0 +1,55 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_LIBGAV1_OPTIONS_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_LIBGAV1_OPTIONS_CMAKE_
+set(LIBGAV1_CMAKE_LIBGAV1_OPTIONS_CMAKE_)
+
+# Simple wrapper for CMake's builtin option command that tracks libgav1's build
+# options in the list variable $libgav1_options.
+macro(libgav1_option)
+  unset(option_NAME)
+  unset(option_HELPSTRING)
+  unset(option_VALUE)
+  unset(optional_args)
+  unset(multi_value_args)
+  set(single_value_args NAME HELPSTRING VALUE)
+  cmake_parse_arguments(option "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+
+  if(NOT (option_NAME AND option_HELPSTRING AND DEFINED option_VALUE))
+    message(FATAL_ERROR "libgav1_option: NAME HELPSTRING and VALUE required.")
+  endif()
+
+  option(${option_NAME} ${option_HELPSTRING} ${option_VALUE})
+
+  if(LIBGAV1_VERBOSE GREATER 2)
+    message("--------- libgav1_option ---------\n"
+            "option_NAME=${option_NAME}\n"
+            "option_HELPSTRING=${option_HELPSTRING}\n"
+            "option_VALUE=${option_VALUE}\n"
+            "------------------------------------------\n")
+  endif()
+
+  list(APPEND libgav1_options ${option_NAME})
+  list(REMOVE_DUPLICATES libgav1_options)
+endmacro()
+
+# Dumps the $libgav1_options list via CMake message command.
+macro(libgav1_dump_options)
+  foreach(option_name ${libgav1_options})
+    message("${option_name}: ${${option_name}}")
+  endforeach()
+endmacro()
diff --git a/libgav1/cmake/libgav1_sanitizer.cmake b/libgav1/cmake/libgav1_sanitizer.cmake
new file mode 100644
index 0000000..4bb2263
--- /dev/null
+++ b/libgav1/cmake/libgav1_sanitizer.cmake
@@ -0,0 +1,45 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_LIBGAV1_SANITIZER_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_LIBGAV1_SANITIZER_CMAKE_
+set(LIBGAV1_CMAKE_LIBGAV1_SANITIZER_CMAKE_ 1)
+
+macro(libgav1_configure_sanitizer)
+  if(LIBGAV1_SANITIZE AND NOT MSVC)
+    if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+      if(LIBGAV1_SANITIZE MATCHES "cfi")
+        list(APPEND LIBGAV1_CXX_FLAGS "-flto" "-fno-sanitize-trap=cfi")
+        list(APPEND LIBGAV1_EXE_LINKER_FLAGS "-flto" "-fno-sanitize-trap=cfi"
+                    "-fuse-ld=gold")
+      endif()
+
+      if(${CMAKE_SIZEOF_VOID_P} EQUAL 4
+         AND LIBGAV1_SANITIZE MATCHES "integer|undefined")
+        list(APPEND LIBGAV1_EXE_LINKER_FLAGS "--rtlib=compiler-rt" "-lgcc_s")
+      endif()
+    endif()
+
+    list(APPEND LIBGAV1_CXX_FLAGS "-fsanitize=${LIBGAV1_SANITIZE}")
+    list(APPEND LIBGAV1_EXE_LINKER_FLAGS "-fsanitize=${LIBGAV1_SANITIZE}")
+
+    # Make sanitizer callstacks accurate.
+    list(APPEND LIBGAV1_CXX_FLAGS "-fno-omit-frame-pointer"
+                "-fno-optimize-sibling-calls")
+
+    libgav1_test_cxx_flag(FLAG_LIST_VAR_NAMES LIBGAV1_CXX_FLAGS FLAG_REQUIRED)
+    libgav1_test_exe_linker_flag(FLAG_LIST_VAR_NAME LIBGAV1_EXE_LINKER_FLAGS)
+  endif()
+endmacro()
diff --git a/libgav1/cmake/libgav1_targets.cmake b/libgav1/cmake/libgav1_targets.cmake
new file mode 100644
index 0000000..78b4865
--- /dev/null
+++ b/libgav1/cmake/libgav1_targets.cmake
@@ -0,0 +1,347 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_GAV1_TARGETS_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_GAV1_TARGETS_CMAKE_
+set(LIBGAV1_CMAKE_GAV1_TARGETS_CMAKE_ 1)
+
+# Resets list variables used to track libgav1 targets.
+macro(libgav1_reset_target_lists)
+  unset(libgav1_targets)
+  unset(libgav1_exe_targets)
+  unset(libgav1_lib_targets)
+  unset(libgav1_objlib_targets)
+  unset(libgav1_sources)
+  unset(libgav1_test_targets)
+endmacro()
+
+# Creates an executable target. The target name is passed as a parameter to the
+# NAME argument, and the sources passed as a parameter to the SOURCES argument:
+# libgav1_add_test(NAME <name> SOURCES <sources> [optional args])
+#
+# Optional args:
+# cmake-format: off
+#   - OUTPUT_NAME: Override output file basename. Target basename defaults to
+#     NAME.
+#   - TEST: Flag. Presence means treat executable as a test.
+#   - DEFINES: List of preprocessor macro definitions.
+#   - INCLUDES: list of include directories for the target.
+#   - COMPILE_FLAGS: list of compiler flags for the target.
+#   - LINK_FLAGS: List of linker flags for the target.
+#   - OBJLIB_DEPS: List of CMake object library target dependencies.
+#   - LIB_DEPS: List of CMake library dependencies.
+# cmake-format: on
+#
+# Sources passed to this macro are added to $libgav1_test_sources when TEST is
+# specified. Otherwise sources are added to $libgav1_sources.
+#
+# Targets passed to this macro are always added $libgav1_targets. When TEST is
+# specified targets are also added to list $libgav1_test_targets. Otherwise
+# targets are added to $libgav1_exe_targets.
+macro(libgav1_add_executable)
+  unset(exe_TEST)
+  unset(exe_TEST_DEFINES_MAIN)
+  unset(exe_NAME)
+  unset(exe_OUTPUT_NAME)
+  unset(exe_SOURCES)
+  unset(exe_DEFINES)
+  unset(exe_INCLUDES)
+  unset(exe_COMPILE_FLAGS)
+  unset(exe_LINK_FLAGS)
+  unset(exe_OBJLIB_DEPS)
+  unset(exe_LIB_DEPS)
+  set(optional_args TEST)
+  set(single_value_args NAME OUTPUT_NAME)
+  set(multi_value_args SOURCES DEFINES INCLUDES COMPILE_FLAGS LINK_FLAGS
+                       OBJLIB_DEPS LIB_DEPS)
+
+  cmake_parse_arguments(exe "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+
+  if(LIBGAV1_VERBOSE GREATER 1)
+    message("--------- libgav1_add_executable ---------\n"
+            "exe_TEST=${exe_TEST}\n"
+            "exe_TEST_DEFINES_MAIN=${exe_TEST_DEFINES_MAIN}\n"
+            "exe_NAME=${exe_NAME}\n"
+            "exe_OUTPUT_NAME=${exe_OUTPUT_NAME}\n"
+            "exe_SOURCES=${exe_SOURCES}\n"
+            "exe_DEFINES=${exe_DEFINES}\n"
+            "exe_INCLUDES=${exe_INCLUDES}\n"
+            "exe_COMPILE_FLAGS=${exe_COMPILE_FLAGS}\n"
+            "exe_LINK_FLAGS=${exe_LINK_FLAGS}\n"
+            "exe_OBJLIB_DEPS=${exe_OBJLIB_DEPS}\n"
+            "exe_LIB_DEPS=${exe_LIB_DEPS}\n"
+            "------------------------------------------\n")
+  endif()
+
+  if(NOT (exe_NAME AND exe_SOURCES))
+    message(FATAL_ERROR "libgav1_add_executable: NAME and SOURCES required.")
+  endif()
+
+  list(APPEND libgav1_targets ${exe_NAME})
+  if(exe_TEST)
+    list(APPEND libgav1_test_targets ${exe_NAME})
+    list(APPEND libgav1_test_sources ${exe_SOURCES})
+  else()
+    list(APPEND libgav1_exe_targets ${exe_NAME})
+    list(APPEND libgav1_sources ${exe_SOURCES})
+  endif()
+
+  add_executable(${exe_NAME} ${exe_SOURCES})
+
+  if(exe_OUTPUT_NAME)
+    set_target_properties(${exe_NAME} PROPERTIES OUTPUT_NAME ${exe_OUTPUT_NAME})
+  endif()
+
+  libgav1_process_intrinsics_sources(TARGET ${exe_NAME} SOURCES ${exe_SOURCES})
+
+  if(exe_DEFINES)
+    target_compile_definitions(${exe_NAME} PRIVATE ${exe_DEFINES})
+  endif()
+
+  if(exe_INCLUDES)
+    target_include_directories(${exe_NAME} PRIVATE ${exe_INCLUDES})
+  endif()
+
+  if(exe_COMPILE_FLAGS OR LIBGAV1_CXX_FLAGS)
+    target_compile_options(${exe_NAME}
+                           PRIVATE ${exe_COMPILE_FLAGS} ${LIBGAV1_CXX_FLAGS})
+  endif()
+
+  if(exe_LINK_FLAGS OR LIBGAV1_EXE_LINKER_FLAGS)
+    set_target_properties(${exe_NAME}
+                          PROPERTIES LINK_FLAGS ${exe_LINK_FLAGS}
+                                     ${LIBGAV1_EXE_LINKER_FLAGS})
+  endif()
+
+  if(exe_OBJLIB_DEPS)
+    foreach(objlib_dep ${exe_OBJLIB_DEPS})
+      target_sources(${exe_NAME} PRIVATE $<TARGET_OBJECTS:${objlib_dep}>)
+    endforeach()
+  endif()
+
+  if(CMAKE_THREAD_LIBS_INIT)
+    list(APPEND exe_LIB_DEPS ${CMAKE_THREAD_LIBS_INIT})
+  endif()
+
+  if(BUILD_SHARED_LIBS AND (MSVC OR WIN32))
+    target_compile_definitions(${lib_NAME} PRIVATE "LIBGAV1_BUILDING_DLL=0")
+  endif()
+
+  if(exe_LIB_DEPS)
+    unset(exe_static)
+    if("${CMAKE_EXE_LINKER_FLAGS} ${LIBGAV1_EXE_LINKER_FLAGS}" MATCHES "static")
+      set(exe_static ON)
+    endif()
+
+    if(exe_static AND CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU")
+      # Third party dependencies can introduce dependencies on system and test
+      # libraries. Since the target created here is an executable, and CMake
+      # does not provide a method of controlling order of link dependencies,
+      # wrap all of the dependencies of this target in start/end group flags to
+      # ensure that dependencies of third party targets can be resolved when
+      # those dependencies happen to be resolved by dependencies of the current
+      # target.
+      list(INSERT exe_LIB_DEPS 0 -Wl,--start-group)
+      list(APPEND exe_LIB_DEPS -Wl,--end-group)
+    endif()
+    target_link_libraries(${exe_NAME} PRIVATE ${exe_LIB_DEPS})
+  endif()
+endmacro()
+
+# Creates a library target of the specified type. The target name is passed as a
+# parameter to the NAME argument, the type as a parameter to the TYPE argument,
+# and the sources passed as a parameter to the SOURCES argument:
+# libgav1_add_library(NAME <name> TYPE <type> SOURCES <sources> [optional args])
+#
+# Optional args:
+# cmake-format: off
+#   - OUTPUT_NAME: Override output file basename. Target basename defaults to
+#     NAME. OUTPUT_NAME is ignored when BUILD_SHARED_LIBS is enabled and CMake
+#     is generating a build for which MSVC or WIN32 are true. This is to avoid
+#     output basename collisions with DLL import libraries.
+#   - TEST: Flag. Presence means treat library as a test.
+#   - DEFINES: List of preprocessor macro definitions.
+#   - INCLUDES: list of include directories for the target.
+#   - COMPILE_FLAGS: list of compiler flags for the target.
+#   - LINK_FLAGS: List of linker flags for the target.
+#   - OBJLIB_DEPS: List of CMake object library target dependencies.
+#   - LIB_DEPS: List of CMake library dependencies.
+#   - PUBLIC_INCLUDES: List of include paths to export to dependents.
+# cmake-format: on
+#
+# Sources passed to the macro are added to the lists tracking libgav1 sources:
+# cmake-format: off
+#   - When TEST is specified sources are added to $libgav1_test_sources.
+#   - Otherwise sources are added to $libgav1_sources.
+# cmake-format: on
+#
+# Targets passed to this macro are added to the lists tracking libgav1 targets:
+# cmake-format: off
+#   - Targets are always added to $libgav1_targets.
+#   - When the TEST flag is specified, targets are added to
+#     $libgav1_test_targets.
+#   - When TEST is not specified:
+#     - Libraries of type SHARED are added to $libgav1_dylib_targets.
+#     - Libraries of type OBJECT are added to $libgav1_objlib_targets.
+#     - Libraries of type STATIC are added to $libgav1_lib_targets.
+# cmake-format: on
+macro(libgav1_add_library)
+  unset(lib_TEST)
+  unset(lib_NAME)
+  unset(lib_OUTPUT_NAME)
+  unset(lib_TYPE)
+  unset(lib_SOURCES)
+  unset(lib_DEFINES)
+  unset(lib_INCLUDES)
+  unset(lib_COMPILE_FLAGS)
+  unset(lib_LINK_FLAGS)
+  unset(lib_OBJLIB_DEPS)
+  unset(lib_LIB_DEPS)
+  unset(lib_PUBLIC_INCLUDES)
+  set(optional_args TEST)
+  set(single_value_args NAME OUTPUT_NAME TYPE)
+  set(multi_value_args SOURCES DEFINES INCLUDES COMPILE_FLAGS LINK_FLAGS
+                       OBJLIB_DEPS LIB_DEPS PUBLIC_INCLUDES)
+
+  cmake_parse_arguments(lib "${optional_args}" "${single_value_args}"
+                        "${multi_value_args}" ${ARGN})
+
+  if(LIBGAV1_VERBOSE GREATER 1)
+    message("--------- libgav1_add_library ---------\n"
+            "lib_TEST=${lib_TEST}\n"
+            "lib_NAME=${lib_NAME}\n"
+            "lib_OUTPUT_NAME=${lib_OUTPUT_NAME}\n"
+            "lib_TYPE=${lib_TYPE}\n"
+            "lib_SOURCES=${lib_SOURCES}\n"
+            "lib_DEFINES=${lib_DEFINES}\n"
+            "lib_INCLUDES=${lib_INCLUDES}\n"
+            "lib_COMPILE_FLAGS=${lib_COMPILE_FLAGS}\n"
+            "lib_LINK_FLAGS=${lib_LINK_FLAGS}\n"
+            "lib_OBJLIB_DEPS=${lib_OBJLIB_DEPS}\n"
+            "lib_LIB_DEPS=${lib_LIB_DEPS}\n"
+            "lib_PUBLIC_INCLUDES=${lib_PUBLIC_INCLUDES}\n"
+            "---------------------------------------\n")
+  endif()
+
+  if(NOT (lib_NAME AND lib_TYPE AND lib_SOURCES))
+    message(FATAL_ERROR "libgav1_add_library: NAME, TYPE and SOURCES required.")
+  endif()
+
+  list(APPEND libgav1_targets ${lib_NAME})
+  if(lib_TEST)
+    list(APPEND libgav1_test_targets ${lib_NAME})
+    list(APPEND libgav1_test_sources ${lib_SOURCES})
+  else()
+    list(APPEND libgav1_sources ${lib_SOURCES})
+    if(lib_TYPE STREQUAL OBJECT)
+      list(APPEND libgav1_objlib_targets ${lib_NAME})
+    elseif(lib_TYPE STREQUAL SHARED)
+      list(APPEND libgav1_dylib_targets ${lib_NAME})
+    elseif(lib_TYPE STREQUAL STATIC)
+      list(APPEND libgav1_lib_targets ${lib_NAME})
+    else()
+      message(WARNING "libgav1_add_library: Unhandled type: ${lib_TYPE}")
+    endif()
+  endif()
+
+  add_library(${lib_NAME} ${lib_TYPE} ${lib_SOURCES})
+  libgav1_process_intrinsics_sources(TARGET ${lib_NAME} SOURCES ${lib_SOURCES})
+
+  if(lib_OUTPUT_NAME)
+    if(NOT (BUILD_SHARED_LIBS AND (MSVC OR WIN32)))
+      set_target_properties(${lib_NAME}
+                            PROPERTIES OUTPUT_NAME ${lib_OUTPUT_NAME})
+    endif()
+  endif()
+
+  if(lib_DEFINES)
+    target_compile_definitions(${lib_NAME} PRIVATE ${lib_DEFINES})
+  endif()
+
+  if(lib_INCLUDES)
+    target_include_directories(${lib_NAME} PRIVATE ${lib_INCLUDES})
+  endif()
+
+  if(lib_PUBLIC_INCLUDES)
+    target_include_directories(${lib_NAME} PUBLIC ${lib_PUBLIC_INCLUDES})
+  endif()
+
+  if(lib_COMPILE_FLAGS OR LIBGAV1_CXX_FLAGS)
+    target_compile_options(${lib_NAME}
+                           PRIVATE ${lib_COMPILE_FLAGS} ${LIBGAV1_CXX_FLAGS})
+  endif()
+
+  if(lib_LINK_FLAGS)
+    set_target_properties(${lib_NAME} PROPERTIES LINK_FLAGS ${lib_LINK_FLAGS})
+  endif()
+
+  if(lib_OBJLIB_DEPS)
+    foreach(objlib_dep ${lib_OBJLIB_DEPS})
+      target_sources(${lib_NAME} PRIVATE $<TARGET_OBJECTS:${objlib_dep}>)
+    endforeach()
+  endif()
+
+  if(lib_LIB_DEPS)
+    if(lib_TYPE STREQUAL STATIC)
+      set(link_type PUBLIC)
+    else()
+      set(link_type PRIVATE)
+      if(lib_TYPE STREQUAL SHARED AND CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU")
+        # The libgav1 shared object uses the static libgav1 as input to turn it
+        # into a shared object. Include everything from the static library in
+        # the shared object.
+        if(APPLE)
+          list(INSERT lib_LIB_DEPS 0 -Wl,-force_load)
+        else()
+          list(INSERT lib_LIB_DEPS 0 -Wl,--whole-archive)
+          list(APPEND lib_LIB_DEPS -Wl,--no-whole-archive)
+        endif()
+      endif()
+    endif()
+    target_link_libraries(${lib_NAME} ${link_type} ${lib_LIB_DEPS})
+  endif()
+
+  if(NOT MSVC AND lib_NAME MATCHES "^lib")
+    # Non-MSVC generators prepend lib to static lib target file names. Libgav1
+    # already includes lib in its name. Avoid naming output files liblib*.
+    set_target_properties(${lib_NAME} PROPERTIES PREFIX "")
+  endif()
+
+  if(lib_TYPE STREQUAL SHARED AND NOT MSVC)
+    set_target_properties(${lib_NAME} PROPERTIES SOVERSION ${LIBGAV1_SOVERSION})
+  endif()
+
+  if(BUILD_SHARED_LIBS AND (MSVC OR WIN32))
+    if(lib_TYPE STREQUAL SHARED)
+      target_compile_definitions(${lib_NAME} PRIVATE "LIBGAV1_BUILDING_DLL=1")
+    else()
+      target_compile_definitions(${lib_NAME} PRIVATE "LIBGAV1_BUILDING_DLL=0")
+    endif()
+  endif()
+
+  # Determine if $lib_NAME is a header only target.
+  set(sources_list ${lib_SOURCES})
+  list(FILTER sources_list INCLUDE REGEX cc$)
+  if(NOT sources_list)
+    if(NOT XCODE)
+      # This is a header only target. Tell CMake the link language.
+      set_target_properties(${lib_NAME} PROPERTIES LINKER_LANGUAGE CXX)
+    else()
+      # The Xcode generator ignores LINKER_LANGUAGE. Add a dummy cc file.
+      libgav1_create_dummy_source_file(TARGET ${lib_NAME} BASENAME ${lib_NAME})
+    endif()
+  endif()
+endmacro()
diff --git a/libgav1/cmake/libgav1_variables.cmake b/libgav1/cmake/libgav1_variables.cmake
new file mode 100644
index 0000000..0dd0f37
--- /dev/null
+++ b/libgav1/cmake/libgav1_variables.cmake
@@ -0,0 +1,78 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_LIBGAV1_VARIABLES_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_LIBGAV1_VARIABLES_CMAKE_
+set(LIBGAV1_CMAKE_LIBGAV1_VARIABLES_CMAKE_ 1)
+
+# Halts generation when $variable_name does not refer to a directory that
+# exists.
+macro(libgav1_variable_must_be_directory variable_name)
+  if("${variable_name}" STREQUAL "")
+    message(
+      FATAL_ERROR
+        "Empty variable_name passed to libgav1_variable_must_be_directory.")
+  endif()
+
+  if("${${variable_name}}" STREQUAL "")
+    message(
+      FATAL_ERROR
+        "Empty variable ${variable_name} is required to build libgav1.")
+  endif()
+
+  if(NOT IS_DIRECTORY "${${variable_name}}")
+    message(
+      FATAL_ERROR
+        "${variable_name}, which is ${${variable_name}}, does not refer to a\n"
+        "directory.")
+  endif()
+endmacro()
+
+# Adds $var_name to the tracked variables list.
+macro(libgav1_track_configuration_variable var_name)
+  if(LIBGAV1_VERBOSE GREATER 2)
+    message("---- libgav1_track_configuration_variable ----\n"
+            "var_name=${var_name}\n"
+            "----------------------------------------------\n")
+  endif()
+
+  list(APPEND libgav1_configuration_variables ${var_name})
+  list(REMOVE_DUPLICATES libgav1_configuration_variables)
+endmacro()
+
+# Logs current C++ and executable linker flags via CMake's message command.
+macro(libgav1_dump_cmake_flag_variables)
+  unset(flag_variables)
+  list(APPEND flag_variables "CMAKE_CXX_FLAGS_INIT" "CMAKE_CXX_FLAGS"
+              "CMAKE_EXE_LINKER_FLAGS_INIT" "CMAKE_EXE_LINKER_FLAGS")
+  if(CMAKE_BUILD_TYPE)
+    list(APPEND flag_variables "CMAKE_BUILD_TYPE"
+                "CMAKE_CXX_FLAGS_${CMAKE_BUILD_TYPE}_INIT"
+                "CMAKE_CXX_FLAGS_${CMAKE_BUILD_TYPE}"
+                "CMAKE_EXE_LINKER_FLAGS_${CMAKE_BUILD_TYPE}_INIT"
+                "CMAKE_EXE_LINKER_FLAGS_${CMAKE_BUILD_TYPE}")
+  endif()
+  foreach(flag_variable ${flag_variables})
+    message("${flag_variable}:${${flag_variable}}")
+  endforeach()
+endmacro()
+
+# Dumps the variables tracked in $libgav1_configuration_variables via CMake's
+# message command.
+macro(libgav1_dump_tracked_configuration_variables)
+  foreach(config_variable ${libgav1_configuration_variables})
+    message("${config_variable}:${${config_variable}}")
+  endforeach()
+endmacro()
diff --git a/libgav1/cmake/toolchains/aarch64-linux-gnu.cmake b/libgav1/cmake/toolchains/aarch64-linux-gnu.cmake
new file mode 100644
index 0000000..7ffe397
--- /dev/null
+++ b/libgav1/cmake/toolchains/aarch64-linux-gnu.cmake
@@ -0,0 +1,28 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_TOOLCHAINS_AARCH64_LINUX_GNU_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_TOOLCHAINS_AARCH64_LINUX_GNU_CMAKE_
+set(LIBGAV1_CMAKE_TOOLCHAINS_AARCH64_LINUX_GNU_CMAKE_ 1)
+
+set(CMAKE_SYSTEM_NAME "Linux")
+
+if("${CROSS}" STREQUAL "")
+  set(CROSS aarch64-linux-gnu-)
+endif()
+
+set(CMAKE_CXX_COMPILER ${CROSS}g++)
+set(CMAKE_CXX_FLAGS_INIT "-march=armv8-a")
+set(CMAKE_SYSTEM_PROCESSOR "aarch64")
diff --git a/libgav1/cmake/toolchains/android.cmake b/libgav1/cmake/toolchains/android.cmake
new file mode 100644
index 0000000..492957b
--- /dev/null
+++ b/libgav1/cmake/toolchains/android.cmake
@@ -0,0 +1,53 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_TOOLCHAINS_ANDROID_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_TOOLCHAINS_ANDROID_CMAKE_
+
+# Additional ANDROID_* settings are available, see:
+# https://developer.android.com/ndk/guides/cmake#variables
+
+if(NOT ANDROID_PLATFORM)
+  set(ANDROID_PLATFORM android-21)
+endif()
+
+# Choose target architecture with:
+#
+# -DANDROID_ABI={armeabi-v7a,armeabi-v7a with NEON,arm64-v8a,x86,x86_64}
+if(NOT ANDROID_ABI)
+  set(ANDROID_ABI arm64-v8a)
+endif()
+
+# Force arm mode for 32-bit targets (instead of the default thumb) to improve
+# performance.
+if(NOT ANDROID_ARM_MODE)
+  set(ANDROID_ARM_MODE arm)
+endif()
+
+# Toolchain files don't have access to cached variables:
+# https://gitlab.kitware.com/cmake/cmake/issues/16170. Set an intermediate
+# environment variable when loaded the first time.
+if(LIBGAV1_ANDROID_NDK_PATH)
+  set(ENV{LIBGAV1_ANDROID_NDK_PATH} "${LIBGAV1_ANDROID_NDK_PATH}")
+else()
+  set(LIBGAV1_ANDROID_NDK_PATH "$ENV{LIBGAV1_ANDROID_NDK_PATH}")
+endif()
+
+if(NOT LIBGAV1_ANDROID_NDK_PATH)
+  message(FATAL_ERROR "LIBGAV1_ANDROID_NDK_PATH not set.")
+  return()
+endif()
+
+include("${LIBGAV1_ANDROID_NDK_PATH}/build/cmake/android.toolchain.cmake")
diff --git a/libgav1/cmake/toolchains/arm-linux-gnueabihf.cmake b/libgav1/cmake/toolchains/arm-linux-gnueabihf.cmake
new file mode 100644
index 0000000..8051f0d
--- /dev/null
+++ b/libgav1/cmake/toolchains/arm-linux-gnueabihf.cmake
@@ -0,0 +1,29 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_CMAKE_TOOLCHAINS_ARM_LINUX_GNUEABIHF_CMAKE_)
+  return()
+endif() # LIBGAV1_CMAKE_TOOLCHAINS_ARM_LINUX_GNUEABIHF_CMAKE_
+set(LIBGAV1_CMAKE_TOOLCHAINS_ARM_LINUX_GNUEABIHF_CMAKE_ 1)
+
+set(CMAKE_SYSTEM_NAME "Linux")
+
+if("${CROSS}" STREQUAL "")
+  set(CROSS arm-linux-gnueabihf-)
+endif()
+
+set(CMAKE_CXX_COMPILER ${CROSS}g++)
+set(CMAKE_CXX_FLAGS_INIT "-march=armv7-a -marm")
+set(CMAKE_SYSTEM_PROCESSOR "armv7")
+set(LIBGAV1_NEON_INTRINSICS_FLAG "-mfpu=neon")
diff --git a/libgav1/codereview.settings b/libgav1/codereview.settings
new file mode 100644
index 0000000..ccba2ee
--- /dev/null
+++ b/libgav1/codereview.settings
@@ -0,0 +1,4 @@
+# This file is used by git cl to get repository specific information.
+GERRIT_HOST: True
+CODE_REVIEW_SERVER: chromium-review.googlesource.com
+GERRIT_SQUASH_UPLOADS: False
diff --git a/libgav1/examples/file_reader.cc b/libgav1/examples/file_reader.cc
new file mode 100644
index 0000000..b096722
--- /dev/null
+++ b/libgav1/examples/file_reader.cc
@@ -0,0 +1,186 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "examples/file_reader.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <new>
+#include <string>
+#include <vector>
+
+#if defined(_WIN32)
+#include <fcntl.h>
+#include <io.h>
+#endif
+
+#include "examples/file_reader_constants.h"
+#include "examples/file_reader_factory.h"
+#include "examples/file_reader_interface.h"
+#include "examples/ivf_parser.h"
+#include "examples/logging.h"
+
+namespace libgav1 {
+namespace {
+
+FILE* SetBinaryMode(FILE* stream) {
+#if defined(_WIN32)
+  _setmode(_fileno(stream), _O_BINARY);
+#endif
+  return stream;
+}
+
+}  // namespace
+
+bool FileReader::registered_in_factory_ =
+    FileReaderFactory::RegisterReader(FileReader::Open);
+
+FileReader::~FileReader() {
+  if (owns_file_) fclose(file_);
+}
+
+std::unique_ptr<FileReaderInterface> FileReader::Open(
+    const std::string& file_name, const bool error_tolerant) {
+  if (file_name.empty()) return nullptr;
+
+  FILE* raw_file_ptr;
+
+  bool owns_file = true;
+  if (file_name == "-") {
+    raw_file_ptr = SetBinaryMode(stdin);
+    owns_file = false;  // stdin is owned by the Standard C Library.
+  } else {
+    raw_file_ptr = fopen(file_name.c_str(), "rb");
+  }
+
+  if (raw_file_ptr == nullptr) {
+    return nullptr;
+  }
+
+  std::unique_ptr<FileReader> file(
+      new (std::nothrow) FileReader(raw_file_ptr, owns_file, error_tolerant));
+  if (file == nullptr) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Out of memory");
+    if (owns_file) fclose(raw_file_ptr);
+    return nullptr;
+  }
+
+  if (!file->ReadIvfFileHeader()) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Unsupported file type");
+    return nullptr;
+  }
+
+  return file;
+}
+
+// IVF Frame Header format, from https://wiki.multimedia.cx/index.php/IVF
+// bytes 0-3    size of frame in bytes (not including the 12-byte header)
+// bytes 4-11   64-bit presentation timestamp
+// bytes 12..   frame data
+bool FileReader::ReadTemporalUnit(std::vector<uint8_t>* const tu_data,
+                                  int64_t* const timestamp) {
+  if (tu_data == nullptr) return false;
+  tu_data->clear();
+
+  uint8_t header_buffer[kIvfFrameHeaderSize];
+  const size_t num_read = fread(header_buffer, 1, kIvfFrameHeaderSize, file_);
+
+  if (IsEndOfFile()) {
+    if (num_read != 0) {
+      LIBGAV1_EXAMPLES_LOG_ERROR(
+          "Cannot read IVF frame header: Not enough data available");
+      return false;
+    }
+
+    return true;
+  }
+
+  IvfFrameHeader ivf_frame_header;
+  if (!ParseIvfFrameHeader(header_buffer, &ivf_frame_header)) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Could not parse IVF frame header");
+    if (error_tolerant_) {
+      ivf_frame_header.frame_size =
+          std::min(ivf_frame_header.frame_size, size_t{kMaxTemporalUnitSize});
+    } else {
+      return false;
+    }
+  }
+
+  if (timestamp != nullptr) *timestamp = ivf_frame_header.timestamp;
+
+  tu_data->resize(ivf_frame_header.frame_size);
+  const size_t size_read =
+      fread(tu_data->data(), 1, ivf_frame_header.frame_size, file_);
+  if (size_read != ivf_frame_header.frame_size) {
+    LIBGAV1_EXAMPLES_LOG_ERROR(
+        "Unexpected EOF or I/O error reading frame data");
+    if (error_tolerant_) {
+      tu_data->resize(size_read);
+    } else {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Attempt to read an IVF file header. Returns true for success, and false for
+// failure.
+//
+// IVF File Header format, from https://wiki.multimedia.cx/index.php/IVF
+// bytes 0-3    signature: 'DKIF'
+// bytes 4-5    version (should be 0)
+// bytes 6-7    length of header in bytes
+// bytes 8-11   codec FourCC (e.g., 'VP80')
+// bytes 12-13  width in pixels
+// bytes 14-15  height in pixels
+// bytes 16-19  frame rate
+// bytes 20-23  time scale
+// bytes 24-27  number of frames in file
+// bytes 28-31  unused
+//
+// Note: The rate and scale fields correspond to the numerator and denominator
+// of frame rate (fps) or time base (the reciprocal of frame rate) as follows:
+//
+// bytes 16-19  frame rate  timebase.den  framerate.numerator
+// bytes 20-23  time scale  timebase.num  framerate.denominator
+bool FileReader::ReadIvfFileHeader() {
+  uint8_t header_buffer[kIvfFileHeaderSize];
+  const size_t num_read = fread(header_buffer, 1, kIvfFileHeaderSize, file_);
+  if (num_read != kIvfFileHeaderSize) {
+    LIBGAV1_EXAMPLES_LOG_ERROR(
+        "Cannot read IVF header: Not enough data available");
+    return false;
+  }
+
+  IvfFileHeader ivf_file_header;
+  if (!ParseIvfFileHeader(header_buffer, &ivf_file_header)) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Could not parse IVF file header");
+    if (error_tolerant_) {
+      ivf_file_header = {};
+    } else {
+      return false;
+    }
+  }
+
+  width_ = ivf_file_header.width;
+  height_ = ivf_file_header.height;
+  frame_rate_ = ivf_file_header.frame_rate_numerator;
+  time_scale_ = ivf_file_header.frame_rate_denominator;
+  type_ = kFileTypeIvf;
+
+  return true;
+}
+
+}  // namespace libgav1
diff --git a/libgav1/examples/file_reader.h b/libgav1/examples/file_reader.h
new file mode 100644
index 0000000..c342a20
--- /dev/null
+++ b/libgav1/examples/file_reader.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_EXAMPLES_FILE_READER_H_
+#define LIBGAV1_EXAMPLES_FILE_READER_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "examples/file_reader_interface.h"
+
+namespace libgav1 {
+
+// Temporal Unit based file reader class. Currently supports only IVF files.
+class FileReader : public FileReaderInterface {
+ public:
+  enum FileType {
+    kFileTypeUnknown,
+    kFileTypeIvf,
+  };
+
+  // Creates and returns a FileReader that reads from |file_name|.
+  // If |error_tolerant| is true format and read errors are ignored,
+  // ReadTemporalUnit() may return truncated data.
+  // Returns nullptr when the file does not exist, cannot be read, or is not an
+  // IVF file.
+  static std::unique_ptr<FileReaderInterface> Open(const std::string& file_name,
+                                                   bool error_tolerant = false);
+
+  FileReader() = delete;
+  FileReader(const FileReader&) = delete;
+  FileReader& operator=(const FileReader&) = delete;
+
+  // Closes |file_|.
+  ~FileReader() override;
+
+  // Reads a temporal unit from |file_| and writes the data to |tu_data|.
+  // Returns true when:
+  // - A temporal unit is read successfully, or
+  // - At end of file.
+  // When ReadTemporalUnit() is called at the end of the file, it will return
+  // true without writing any data to |tu_data|.
+  //
+  // The |timestamp| pointer is optional: callers not interested in timestamps
+  // can pass nullptr. When |timestamp| is not a nullptr, this function returns
+  // the presentation timestamp from the IVF frame header.
+  /*LIBGAV1_MUST_USE_RESULT*/ bool ReadTemporalUnit(
+      std::vector<uint8_t>* tu_data, int64_t* timestamp) override;
+
+  /*LIBGAV1_MUST_USE_RESULT*/ bool IsEndOfFile() const override {
+    return feof(file_) != 0;
+  }
+
+  // The values returned by these accessors are strictly informative. No
+  // validation is performed when they are read from the IVF file header.
+  size_t width() const override { return width_; }
+  size_t height() const override { return height_; }
+  size_t frame_rate() const override { return frame_rate_; }
+  size_t time_scale() const override { return time_scale_; }
+
+ private:
+  FileReader(FILE* file, bool owns_file, bool error_tolerant)
+      : file_(file), owns_file_(owns_file), error_tolerant_(error_tolerant) {}
+
+  bool ReadIvfFileHeader();
+
+  FILE* file_ = nullptr;
+  size_t width_ = 0;
+  size_t height_ = 0;
+  size_t frame_rate_ = 0;
+  size_t time_scale_ = 0;
+  FileType type_ = kFileTypeUnknown;
+  // True if this object owns file_ and is responsible for closing it when
+  // done.
+  const bool owns_file_;
+  const bool error_tolerant_;
+
+  static bool registered_in_factory_;
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_EXAMPLES_FILE_READER_H_
diff --git a/libgav1/src/decoder_scratch_buffer.cc b/libgav1/examples/file_reader_constants.cc
similarity index 75%
copy from libgav1/src/decoder_scratch_buffer.cc
copy to libgav1/examples/file_reader_constants.cc
index bb9b5f2..8439071 100644
--- a/libgav1/src/decoder_scratch_buffer.cc
+++ b/libgav1/examples/file_reader_constants.cc
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/decoder_scratch_buffer.h"
+#include "examples/file_reader_constants.h"
 
 namespace libgav1 {
 
-// static
-constexpr int DecoderScratchBuffer::kBlockDecodedStride;
-constexpr int DecoderScratchBuffer::kPixelSize;
+const char kIvfSignature[4] = {'D', 'K', 'I', 'F'};
+const char kAv1FourCcUpper[4] = {'A', 'V', '0', '1'};
+const char kAv1FourCcLower[4] = {'a', 'v', '0', '1'};
 
 }  // namespace libgav1
diff --git a/libgav1/examples/file_reader_constants.h b/libgav1/examples/file_reader_constants.h
new file mode 100644
index 0000000..00922b4
--- /dev/null
+++ b/libgav1/examples/file_reader_constants.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_EXAMPLES_FILE_READER_CONSTANTS_H_
+#define LIBGAV1_EXAMPLES_FILE_READER_CONSTANTS_H_
+
+namespace libgav1 {
+
+enum {
+  kIvfHeaderVersion = 0,
+  kIvfFrameHeaderSize = 12,
+  kIvfFileHeaderSize = 32,
+#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+  kMaxTemporalUnitSize = 512 * 1024,
+#else
+  kMaxTemporalUnitSize = 256 * 1024 * 1024,
+#endif
+};
+
+extern const char kIvfSignature[4];
+extern const char kAv1FourCcUpper[4];
+extern const char kAv1FourCcLower[4];
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_EXAMPLES_FILE_READER_CONSTANTS_H_
diff --git a/libgav1/examples/file_reader_factory.cc b/libgav1/examples/file_reader_factory.cc
new file mode 100644
index 0000000..d5260eb
--- /dev/null
+++ b/libgav1/examples/file_reader_factory.cc
@@ -0,0 +1,51 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "examples/file_reader_factory.h"
+
+#include <new>
+
+#include "examples/logging.h"
+
+namespace libgav1 {
+namespace {
+
+std::vector<FileReaderFactory::OpenFunction>* GetFileReaderOpenFunctions() {
+  static auto* open_functions =
+      new (std::nothrow) std::vector<FileReaderFactory::OpenFunction>();
+  return open_functions;
+}
+
+}  // namespace
+
+bool FileReaderFactory::RegisterReader(OpenFunction open_function) {
+  if (open_function == nullptr) return false;
+  auto* open_functions = GetFileReaderOpenFunctions();
+  const size_t num_readers = open_functions->size();
+  open_functions->push_back(open_function);
+  return open_functions->size() == num_readers + 1;
+}
+
+std::unique_ptr<FileReaderInterface> FileReaderFactory::OpenReader(
+    const std::string& file_name, const bool error_tolerant /*= false*/) {
+  for (auto* open_function : *GetFileReaderOpenFunctions()) {
+    auto reader = open_function(file_name, error_tolerant);
+    if (reader == nullptr) continue;
+    return reader;
+  }
+  LIBGAV1_EXAMPLES_LOG_ERROR("No file reader able to open input");
+  return nullptr;
+}
+
+}  // namespace libgav1
diff --git a/libgav1/examples/file_reader_factory.h b/libgav1/examples/file_reader_factory.h
new file mode 100644
index 0000000..0f53484
--- /dev/null
+++ b/libgav1/examples/file_reader_factory.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_EXAMPLES_FILE_READER_FACTORY_H_
+#define LIBGAV1_EXAMPLES_FILE_READER_FACTORY_H_
+
+#include <memory>
+#include <string>
+
+#include "examples/file_reader_interface.h"
+
+namespace libgav1 {
+
+class FileReaderFactory {
+ public:
+  using OpenFunction = std::unique_ptr<FileReaderInterface> (*)(
+      const std::string& file_name, bool error_tolerant);
+
+  FileReaderFactory() = delete;
+  FileReaderFactory(const FileReaderFactory&) = delete;
+  FileReaderFactory& operator=(const FileReaderFactory&) = delete;
+  ~FileReaderFactory() = default;
+
+  // Registers the OpenFunction for a FileReaderInterface and returns true when
+  // registration succeeds.
+  static bool RegisterReader(OpenFunction open_function);
+
+  // Passes |file_name| to each OpenFunction until one succeeds. Returns nullptr
+  // when no reader is found for |file_name|. Otherwise a FileReaderInterface is
+  // returned. If |error_tolerant| is true and the reader supports it, some
+  // format and read errors may be ignored and partial data returned.
+  static std::unique_ptr<FileReaderInterface> OpenReader(
+      const std::string& file_name, bool error_tolerant = false);
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_EXAMPLES_FILE_READER_FACTORY_H_
diff --git a/libgav1/examples/file_reader_interface.h b/libgav1/examples/file_reader_interface.h
new file mode 100644
index 0000000..d8f7030
--- /dev/null
+++ b/libgav1/examples/file_reader_interface.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_EXAMPLES_FILE_READER_INTERFACE_H_
+#define LIBGAV1_EXAMPLES_FILE_READER_INTERFACE_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+namespace libgav1 {
+
+class FileReaderInterface {
+ public:
+  FileReaderInterface() = default;
+  FileReaderInterface(const FileReaderInterface&) = delete;
+  FileReaderInterface& operator=(const FileReaderInterface&) = delete;
+
+  FileReaderInterface(FileReaderInterface&&) = default;
+  FileReaderInterface& operator=(FileReaderInterface&&) = default;
+
+  // Closes the file.
+  virtual ~FileReaderInterface() = default;
+
+  // Reads a temporal unit from the file and writes the data to |tu_data|.
+  // Returns true when:
+  // - A temporal unit is read successfully, or
+  // - At end of file.
+  // When ReadTemporalUnit() is called at the end of the file, it will return
+  // true without writing any data to |tu_data|.
+  //
+  // The |timestamp| pointer is optional: callers not interested in timestamps
+  // can pass nullptr. When |timestamp| is not a nullptr, this function returns
+  // the presentation timestamp of the temporal unit.
+  /*LIBGAV1_MUST_USE_RESULT*/ virtual bool ReadTemporalUnit(
+      std::vector<uint8_t>* tu_data, int64_t* timestamp) = 0;
+
+  /*LIBGAV1_MUST_USE_RESULT*/ virtual bool IsEndOfFile() const = 0;
+
+  // The values returned by these accessors are strictly informative. No
+  // validation is performed when they are read from file.
+  virtual size_t width() const = 0;
+  virtual size_t height() const = 0;
+  virtual size_t frame_rate() const = 0;
+  virtual size_t time_scale() const = 0;
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_EXAMPLES_FILE_READER_INTERFACE_H_
diff --git a/libgav1/examples/file_writer.cc b/libgav1/examples/file_writer.cc
new file mode 100644
index 0000000..54afe14
--- /dev/null
+++ b/libgav1/examples/file_writer.cc
@@ -0,0 +1,183 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "examples/file_writer.h"
+
+#include <cerrno>
+#include <cstdio>
+#include <cstring>
+#include <new>
+#include <string>
+
+#if defined(_WIN32)
+#include <fcntl.h>
+#include <io.h>
+#endif
+
+#include "examples/logging.h"
+
+namespace libgav1 {
+namespace {
+
+FILE* SetBinaryMode(FILE* stream) {
+#if defined(_WIN32)
+  _setmode(_fileno(stream), _O_BINARY);
+#endif
+  return stream;
+}
+
+std::string GetY4mColorSpaceString(
+    const FileWriter::Y4mParameters& y4m_parameters) {
+  std::string color_space_string;
+  switch (y4m_parameters.image_format) {
+    case kImageFormatMonochrome400:
+      color_space_string = "mono";
+      break;
+    case kImageFormatYuv420:
+      if (y4m_parameters.bitdepth == 8) {
+        if (y4m_parameters.chroma_sample_position ==
+            kChromaSamplePositionVertical) {
+          color_space_string = "420mpeg2";
+        } else if (y4m_parameters.chroma_sample_position ==
+                   kChromaSamplePositionColocated) {
+          color_space_string = "420";
+        } else {
+          color_space_string = "420jpeg";
+        }
+      } else {
+        color_space_string = "420";
+      }
+      break;
+    case kImageFormatYuv422:
+      color_space_string = "422";
+      break;
+    case kImageFormatYuv444:
+      color_space_string = "444";
+      break;
+  }
+
+  if (y4m_parameters.bitdepth > 8) {
+    const bool monochrome =
+        y4m_parameters.image_format == kImageFormatMonochrome400;
+    if (!monochrome) color_space_string += "p";
+    color_space_string += std::to_string(y4m_parameters.bitdepth);
+  }
+
+  return color_space_string;
+}
+
+}  // namespace
+
+FileWriter::~FileWriter() { fclose(file_); }
+
+std::unique_ptr<FileWriter> FileWriter::Open(
+    const std::string& file_name, FileType file_type,
+    const Y4mParameters* const y4m_parameters) {
+  if (file_name.empty() ||
+      (file_type == kFileTypeY4m && y4m_parameters == nullptr) ||
+      (file_type != kFileTypeRaw && file_type != kFileTypeY4m)) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Invalid parameters");
+    return nullptr;
+  }
+
+  FILE* raw_file_ptr;
+
+  if (file_name == "-") {
+    raw_file_ptr = SetBinaryMode(stdout);
+  } else {
+    raw_file_ptr = fopen(file_name.c_str(), "wb");
+  }
+
+  if (raw_file_ptr == nullptr) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Unable to open output file");
+    return nullptr;
+  }
+
+  std::unique_ptr<FileWriter> file(new (std::nothrow) FileWriter(raw_file_ptr));
+  if (file == nullptr) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Out of memory");
+    fclose(raw_file_ptr);
+    return nullptr;
+  }
+
+  if (file_type == kFileTypeY4m && !file->WriteY4mFileHeader(*y4m_parameters)) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Error writing Y4M file header");
+    return nullptr;
+  }
+
+  file->file_type_ = file_type;
+  return file;
+}
+
+bool FileWriter::WriteFrame(const DecoderBuffer& frame_buffer) {
+  if (file_type_ == kFileTypeY4m) {
+    const char kY4mFrameHeader[] = "FRAME\n";
+    if (fwrite(kY4mFrameHeader, 1, strlen(kY4mFrameHeader), file_) !=
+        strlen(kY4mFrameHeader)) {
+      LIBGAV1_EXAMPLES_LOG_ERROR("Error writing Y4M frame header");
+      return false;
+    }
+  }
+
+  const size_t pixel_size =
+      (frame_buffer.bitdepth == 8) ? sizeof(uint8_t) : sizeof(uint16_t);
+  for (int plane_index = 0; plane_index < frame_buffer.NumPlanes();
+       ++plane_index) {
+    const int height = frame_buffer.displayed_height[plane_index];
+    const int width = frame_buffer.displayed_width[plane_index];
+    const int stride = frame_buffer.stride[plane_index];
+    const uint8_t* const plane_pointer = frame_buffer.plane[plane_index];
+    for (int row = 0; row < height; ++row) {
+      const uint8_t* const row_pointer = &plane_pointer[row * stride];
+      if (fwrite(row_pointer, pixel_size, width, file_) !=
+          static_cast<size_t>(width)) {
+        char error_string[256];
+        snprintf(error_string, sizeof(error_string),
+                 "File write failed: %s (errno=%d)", strerror(errno), errno);
+        LIBGAV1_EXAMPLES_LOG_ERROR(error_string);
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+// Writes Y4M file header to |file_| and returns true when successful.
+//
+// A Y4M file begins with a plaintext file signature of 'YUV4MPEG2 '.
+//
+// Following the signature is any number of optional parameters preceded by a
+// space. We always write:
+//
+// Width: 'W' followed by image width in pixels.
+// Height: 'H' followed by image height in pixels.
+// Frame Rate: 'F' followed frames/second in the form numerator:denominator.
+// Interlacing: 'I' followed by 'p' for progressive.
+// Color space: 'C' followed by a string representation of the color space.
+//
+// More info here: https://wiki.multimedia.cx/index.php/YUV4MPEG2
+bool FileWriter::WriteY4mFileHeader(const Y4mParameters& y4m_parameters) {
+  std::string y4m_header = "YUV4MPEG2";
+  y4m_header += " W" + std::to_string(y4m_parameters.width);
+  y4m_header += " H" + std::to_string(y4m_parameters.height);
+  y4m_header += " F" + std::to_string(y4m_parameters.frame_rate_numerator) +
+                ":" + std::to_string(y4m_parameters.frame_rate_denominator);
+  y4m_header += " Ip C" + GetY4mColorSpaceString(y4m_parameters);
+  y4m_header += "\n";
+  return fwrite(y4m_header.c_str(), 1, y4m_header.length(), file_) ==
+         y4m_header.length();
+}
+
+}  // namespace libgav1
diff --git a/libgav1/examples/file_writer.h b/libgav1/examples/file_writer.h
new file mode 100644
index 0000000..00f6cc3
--- /dev/null
+++ b/libgav1/examples/file_writer.h
@@ -0,0 +1,102 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_EXAMPLES_FILE_WRITER_H_
+#define LIBGAV1_EXAMPLES_FILE_WRITER_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <memory>
+#include <string>
+
+#include "gav1/decoder_buffer.h"
+
+namespace libgav1 {
+
+// Frame based file writer class. Supports only Y4M (YUV4MPEG2) and RAW output.
+class FileWriter {
+ public:
+  enum FileType : uint8_t {
+    kFileTypeRaw,
+    kFileTypeY4m,
+  };
+
+  struct Y4mParameters {
+    Y4mParameters() = default;
+    Y4mParameters(size_t width, size_t height, size_t frame_rate_numerator,
+                  size_t frame_rate_denominator,
+                  ChromaSamplePosition chroma_sample_position,
+                  ImageFormat image_format, size_t bitdepth)
+        : width(width),
+          height(height),
+          frame_rate_numerator(frame_rate_numerator),
+          frame_rate_denominator(frame_rate_denominator),
+          chroma_sample_position(chroma_sample_position),
+          image_format(image_format),
+          bitdepth(bitdepth) {}
+
+    Y4mParameters(const Y4mParameters& rhs) = default;
+    Y4mParameters& operator=(const Y4mParameters& rhs) = default;
+    Y4mParameters(Y4mParameters&& rhs) = default;
+    Y4mParameters& operator=(Y4mParameters&& rhs) = default;
+
+    size_t width = 0;
+    size_t height = 0;
+    size_t frame_rate_numerator = 30;
+    size_t frame_rate_denominator = 1;
+    ChromaSamplePosition chroma_sample_position = kChromaSamplePositionUnknown;
+    ImageFormat image_format = kImageFormatYuv420;
+    size_t bitdepth = 8;
+  };
+
+  // Opens |file_name|. When |file_type| is kFileTypeY4m the Y4M file header is
+  // written out to |file_| before this method returns.
+  //
+  // Returns a FileWriter instance after the file is opened successfully for
+  // kFileTypeRaw files, and after the Y4M file header bytes are written for
+  // kFileTypeY4m files. Returns nullptr upon failure.
+  static std::unique_ptr<FileWriter> Open(const std::string& file_name,
+                                          FileType type,
+                                          const Y4mParameters* y4m_parameters);
+
+  FileWriter() = delete;
+  FileWriter(const FileWriter&) = delete;
+  FileWriter& operator=(const FileWriter&) = delete;
+
+  FileWriter(FileWriter&&) = default;
+  FileWriter& operator=(FileWriter&&) = default;
+
+  // Closes |file_|.
+  ~FileWriter();
+
+  // Writes the frame data in |frame_buffer| to |file_|. Returns true after
+  // successful write of |frame_buffer| data.
+  /*LIBGAV1_MUST_USE_RESULT*/ bool WriteFrame(
+      const DecoderBuffer& frame_buffer);
+
+ private:
+  explicit FileWriter(FILE* file) : file_(file) {}
+
+  bool WriteY4mFileHeader(const Y4mParameters& y4m_parameters);
+
+  FILE* file_ = nullptr;
+  FileType file_type_ = kFileTypeRaw;
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_EXAMPLES_FILE_WRITER_H_
diff --git a/libgav1/examples/gav1_decode.cc b/libgav1/examples/gav1_decode.cc
new file mode 100644
index 0000000..e7d3246
--- /dev/null
+++ b/libgav1/examples/gav1_decode.cc
@@ -0,0 +1,453 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cerrno>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <deque>
+#include <memory>
+#include <new>
+#include <vector>
+
+#include "absl/strings/numbers.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "examples/file_reader_factory.h"
+#include "examples/file_reader_interface.h"
+#include "examples/file_writer.h"
+#include "gav1/decoder.h"
+
+#ifdef GAV1_DECODE_USE_CV_PIXEL_BUFFER_POOL
+#include "examples/gav1_decode_cv_pixel_buffer_pool.h"
+#endif
+
+namespace {
+
+struct Options {
+  const char* input_file_name = nullptr;
+  const char* output_file_name = nullptr;
+  const char* frame_timing_file_name = nullptr;
+  libgav1::FileWriter::FileType output_file_type =
+      libgav1::FileWriter::kFileTypeRaw;
+  uint8_t post_filter_mask = 0x1f;
+  int threads = 1;
+  bool frame_parallel = false;
+  bool output_all_layers = false;
+  int operating_point = 0;
+  int limit = 0;
+  int skip = 0;
+  int verbose = 0;
+};
+
+struct Timing {
+  absl::Duration input;
+  absl::Duration dequeue;
+};
+
+struct FrameTiming {
+  absl::Time enqueue;
+  absl::Time dequeue;
+};
+
+void PrintHelp(FILE* const fout) {
+  fprintf(fout,
+          "Usage: gav1_decode [options] <input file>"
+          " [-o <output file>]\n");
+  fprintf(fout, "\n");
+  fprintf(fout, "Options:\n");
+  fprintf(fout, "  -h, --help This help message.\n");
+  fprintf(fout, "  --threads <positive integer> (Default 1).\n");
+  fprintf(fout, "  --frame_parallel.\n");
+  fprintf(fout,
+          "  --limit <integer> Stop decoding after N frames (0 = all).\n");
+  fprintf(fout, "  --skip <integer> Skip initial N frames (Default 0).\n");
+  fprintf(fout, "  --version.\n");
+  fprintf(fout, "  --y4m (Default false).\n");
+  fprintf(fout, "  --raw (Default true).\n");
+  fprintf(fout, "  -v logging verbosity, can be used multiple times.\n");
+  fprintf(fout, "  --all_layers.\n");
+  fprintf(fout,
+          "  --operating_point <integer between 0 and 31> (Default 0).\n");
+  fprintf(fout,
+          "  --frame_timing <file> Output per-frame timing to <file> in tsv"
+          " format.\n   Yields meaningful results only when frame parallel is"
+          " off.\n");
+  fprintf(fout, "\nAdvanced settings:\n");
+  fprintf(fout, "  --post_filter_mask <integer> (Default 0x1f).\n");
+  fprintf(fout,
+          "   Mask indicating which post filters should be applied to the"
+          " reconstructed\n   frame. This may be given as octal, decimal or"
+          " hexadecimal. From LSB:\n");
+  fprintf(fout, "     Bit 0: Loop filter (deblocking filter)\n");
+  fprintf(fout, "     Bit 1: Cdef\n");
+  fprintf(fout, "     Bit 2: SuperRes\n");
+  fprintf(fout, "     Bit 3: Loop Restoration\n");
+  fprintf(fout, "     Bit 4: Film Grain Synthesis\n");
+}
+
+void ParseOptions(int argc, char* argv[], Options* const options) {
+  for (int i = 1; i < argc; ++i) {
+    int32_t value;
+    if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
+      PrintHelp(stdout);
+      exit(EXIT_SUCCESS);
+    } else if (strcmp(argv[i], "-o") == 0) {
+      if (++i >= argc) {
+        fprintf(stderr, "Missing argument for '-o'\n");
+        PrintHelp(stderr);
+        exit(EXIT_FAILURE);
+      }
+      options->output_file_name = argv[i];
+    } else if (strcmp(argv[i], "--frame_timing") == 0) {
+      if (++i >= argc) {
+        fprintf(stderr, "Missing argument for '--frame_timing'\n");
+        PrintHelp(stderr);
+        exit(EXIT_FAILURE);
+      }
+      options->frame_timing_file_name = argv[i];
+    } else if (strcmp(argv[i], "--version") == 0) {
+      printf("gav1_decode, a libgav1 based AV1 decoder\n");
+      printf("libgav1 %s\n", libgav1::GetVersionString());
+      printf("max bitdepth: %d\n", libgav1::Decoder::GetMaxBitdepth());
+      printf("build configuration: %s\n", libgav1::GetBuildConfiguration());
+      exit(EXIT_SUCCESS);
+    } else if (strcmp(argv[i], "-v") == 0) {
+      ++options->verbose;
+    } else if (strcmp(argv[i], "--raw") == 0) {
+      options->output_file_type = libgav1::FileWriter::kFileTypeRaw;
+    } else if (strcmp(argv[i], "--y4m") == 0) {
+      options->output_file_type = libgav1::FileWriter::kFileTypeY4m;
+    } else if (strcmp(argv[i], "--threads") == 0) {
+      if (++i >= argc || !absl::SimpleAtoi(argv[i], &value)) {
+        fprintf(stderr, "Missing/Invalid value for --threads.\n");
+        PrintHelp(stderr);
+        exit(EXIT_FAILURE);
+      }
+      options->threads = value;
+    } else if (strcmp(argv[i], "--frame_parallel") == 0) {
+      options->frame_parallel = true;
+    } else if (strcmp(argv[i], "--all_layers") == 0) {
+      options->output_all_layers = true;
+    } else if (strcmp(argv[i], "--operating_point") == 0) {
+      if (++i >= argc || !absl::SimpleAtoi(argv[i], &value) || value < 0 ||
+          value >= 32) {
+        fprintf(stderr, "Missing/Invalid value for --operating_point.\n");
+        PrintHelp(stderr);
+        exit(EXIT_FAILURE);
+      }
+      options->operating_point = value;
+    } else if (strcmp(argv[i], "--limit") == 0) {
+      if (++i >= argc || !absl::SimpleAtoi(argv[i], &value) || value < 0) {
+        fprintf(stderr, "Missing/Invalid value for --limit.\n");
+        PrintHelp(stderr);
+        exit(EXIT_FAILURE);
+      }
+      options->limit = value;
+    } else if (strcmp(argv[i], "--skip") == 0) {
+      if (++i >= argc || !absl::SimpleAtoi(argv[i], &value) || value < 0) {
+        fprintf(stderr, "Missing/Invalid value for --skip.\n");
+        PrintHelp(stderr);
+        exit(EXIT_FAILURE);
+      }
+      options->skip = value;
+    } else if (strcmp(argv[i], "--post_filter_mask") == 0) {
+      errno = 0;
+      char* endptr = nullptr;
+      value = (++i >= argc) ? -1
+                            // NOLINTNEXTLINE(runtime/deprecated_fn)
+                            : static_cast<int32_t>(strtol(argv[i], &endptr, 0));
+      // Only the last 5 bits of the mask can be set.
+      if ((value & ~31) != 0 || errno != 0 || endptr == argv[i]) {
+        fprintf(stderr, "Invalid value for --post_filter_mask.\n");
+        PrintHelp(stderr);
+        exit(EXIT_FAILURE);
+      }
+      options->post_filter_mask = value;
+    } else if (strlen(argv[i]) > 1 && argv[i][0] == '-') {
+      fprintf(stderr, "Unknown option '%s'!\n", argv[i]);
+      exit(EXIT_FAILURE);
+    } else {
+      if (options->input_file_name == nullptr) {
+        options->input_file_name = argv[i];
+      } else {
+        fprintf(stderr, "Found invalid parameter: \"%s\".\n", argv[i]);
+        PrintHelp(stderr);
+        exit(EXIT_FAILURE);
+      }
+    }
+  }
+
+  if (argc < 2 || options->input_file_name == nullptr) {
+    fprintf(stderr, "Input file is required!\n");
+    PrintHelp(stderr);
+    exit(EXIT_FAILURE);
+  }
+}
+
+using InputBuffer = std::vector<uint8_t>;
+
+class InputBuffers {
+ public:
+  ~InputBuffers() {
+    for (auto buffer : free_buffers_) {
+      delete buffer;
+    }
+  }
+  InputBuffer* GetFreeBuffer() {
+    if (free_buffers_.empty()) {
+      auto* const buffer = new (std::nothrow) InputBuffer();
+      if (buffer == nullptr) {
+        fprintf(stderr, "Failed to create input buffer.\n");
+        return nullptr;
+      }
+      free_buffers_.push_back(buffer);
+    }
+    InputBuffer* const buffer = free_buffers_.front();
+    free_buffers_.pop_front();
+    return buffer;
+  }
+
+  void ReleaseInputBuffer(InputBuffer* buffer) {
+    free_buffers_.push_back(buffer);
+  }
+
+ private:
+  std::deque<InputBuffer*> free_buffers_;
+};
+
+void ReleaseInputBuffer(void* callback_private_data,
+                        void* buffer_private_data) {
+  auto* const input_buffers = static_cast<InputBuffers*>(callback_private_data);
+  input_buffers->ReleaseInputBuffer(
+      static_cast<InputBuffer*>(buffer_private_data));
+}
+
+int CloseFile(FILE* stream) { return (stream == nullptr) ? 0 : fclose(stream); }
+
+}  // namespace
+
+int main(int argc, char* argv[]) {
+  Options options;
+  ParseOptions(argc, argv, &options);
+
+  auto file_reader =
+      libgav1::FileReaderFactory::OpenReader(options.input_file_name);
+  if (file_reader == nullptr) {
+    fprintf(stderr, "Cannot open input file!\n");
+    return EXIT_FAILURE;
+  }
+
+  std::unique_ptr<FILE, decltype(&CloseFile)> frame_timing_file(nullptr,
+                                                                &CloseFile);
+  if (options.frame_timing_file_name != nullptr) {
+    frame_timing_file.reset(fopen(options.frame_timing_file_name, "wb"));
+    if (frame_timing_file == nullptr) {
+      fprintf(stderr, "Cannot open frame timing file '%s'!\n",
+              options.frame_timing_file_name);
+      return EXIT_FAILURE;
+    }
+  }
+
+#ifdef GAV1_DECODE_USE_CV_PIXEL_BUFFER_POOL
+  // Reference frames + 1 scratch frame (for either the current frame or the
+  // film grain frame).
+  constexpr int kNumBuffers = 8 + 1;
+  std::unique_ptr<Gav1DecodeCVPixelBufferPool> cv_pixel_buffers =
+      Gav1DecodeCVPixelBufferPool::Create(kNumBuffers);
+  if (cv_pixel_buffers == nullptr) {
+    fprintf(stderr, "Cannot create Gav1DecodeCVPixelBufferPool!\n");
+    return EXIT_FAILURE;
+  }
+#endif
+
+  InputBuffers input_buffers;
+  libgav1::Decoder decoder;
+  libgav1::DecoderSettings settings;
+  settings.post_filter_mask = options.post_filter_mask;
+  settings.threads = options.threads;
+  settings.frame_parallel = options.frame_parallel;
+  settings.output_all_layers = options.output_all_layers;
+  settings.operating_point = options.operating_point;
+  settings.blocking_dequeue = true;
+  settings.callback_private_data = &input_buffers;
+  settings.release_input_buffer = ReleaseInputBuffer;
+#ifdef GAV1_DECODE_USE_CV_PIXEL_BUFFER_POOL
+  settings.on_frame_buffer_size_changed = Gav1DecodeOnCVPixelBufferSizeChanged;
+  settings.get_frame_buffer = Gav1DecodeGetCVPixelBuffer;
+  settings.release_frame_buffer = Gav1DecodeReleaseCVPixelBuffer;
+  settings.callback_private_data = cv_pixel_buffers.get();
+  settings.release_input_buffer = nullptr;
+  // TODO(vigneshv): Support frame parallel mode to be used with
+  // CVPixelBufferPool.
+  settings.frame_parallel = false;
+#endif
+  libgav1::StatusCode status = decoder.Init(&settings);
+  if (status != libgav1::kStatusOk) {
+    fprintf(stderr, "Error initializing decoder: %s\n",
+            libgav1::GetErrorString(status));
+    return EXIT_FAILURE;
+  }
+
+  fprintf(stderr, "decoding '%s'\n", options.input_file_name);
+  if (options.verbose > 0 && options.skip > 0) {
+    fprintf(stderr, "skipping %d frame(s).\n", options.skip);
+  }
+
+  int input_frames = 0;
+  int decoded_frames = 0;
+  Timing timing = {};
+  std::vector<FrameTiming> frame_timing;
+  const bool record_frame_timing = frame_timing_file != nullptr;
+  std::unique_ptr<libgav1::FileWriter> file_writer;
+  InputBuffer* input_buffer = nullptr;
+  bool limit_reached = false;
+  bool dequeue_finished = false;
+  const absl::Time decode_loop_start = absl::Now();
+  do {
+    if (input_buffer == nullptr && !file_reader->IsEndOfFile() &&
+        !limit_reached) {
+      input_buffer = input_buffers.GetFreeBuffer();
+      if (input_buffer == nullptr) return EXIT_FAILURE;
+      const absl::Time read_start = absl::Now();
+      if (!file_reader->ReadTemporalUnit(input_buffer,
+                                         /*timestamp=*/nullptr)) {
+        fprintf(stderr, "Error reading input file.\n");
+        return EXIT_FAILURE;
+      }
+      timing.input += absl::Now() - read_start;
+    }
+
+    if (++input_frames <= options.skip) {
+      input_buffers.ReleaseInputBuffer(input_buffer);
+      input_buffer = nullptr;
+      continue;
+    }
+
+    if (input_buffer != nullptr) {
+      if (input_buffer->empty()) {
+        input_buffers.ReleaseInputBuffer(input_buffer);
+        input_buffer = nullptr;
+        continue;
+      }
+
+      const absl::Time enqueue_start = absl::Now();
+      status = decoder.EnqueueFrame(input_buffer->data(), input_buffer->size(),
+                                    static_cast<int64_t>(frame_timing.size()),
+                                    /*buffer_private_data=*/input_buffer);
+      if (status == libgav1::kStatusOk) {
+        if (options.verbose > 1) {
+          fprintf(stderr, "enqueue frame (length %zu)\n", input_buffer->size());
+        }
+        if (record_frame_timing) {
+          FrameTiming enqueue_time = {enqueue_start, absl::UnixEpoch()};
+          frame_timing.emplace_back(enqueue_time);
+        }
+
+        input_buffer = nullptr;
+        // Continue to enqueue frames until we get a kStatusTryAgain status.
+        continue;
+      }
+      if (status != libgav1::kStatusTryAgain) {
+        fprintf(stderr, "Unable to enqueue frame: %s\n",
+                libgav1::GetErrorString(status));
+        return EXIT_FAILURE;
+      }
+    }
+
+    const libgav1::DecoderBuffer* buffer;
+    status = decoder.DequeueFrame(&buffer);
+    if (status != libgav1::kStatusOk &&
+        status != libgav1::kStatusNothingToDequeue) {
+      fprintf(stderr, "Unable to dequeue frame: %s\n",
+              libgav1::GetErrorString(status));
+      return EXIT_FAILURE;
+    }
+    if (status == libgav1::kStatusNothingToDequeue) {
+      dequeue_finished = true;
+      continue;
+    }
+    dequeue_finished = false;
+    if (buffer == nullptr) continue;
+    ++decoded_frames;
+    if (options.verbose > 1) {
+      fprintf(stderr, "buffer dequeued\n");
+    }
+
+    if (record_frame_timing) {
+      frame_timing[static_cast<int>(buffer->user_private_data)].dequeue =
+          absl::Now();
+    }
+
+    if (options.output_file_name != nullptr && file_writer == nullptr) {
+      libgav1::FileWriter::Y4mParameters y4m_parameters;
+      y4m_parameters.width = buffer->displayed_width[0];
+      y4m_parameters.height = buffer->displayed_height[0];
+      y4m_parameters.frame_rate_numerator = file_reader->frame_rate();
+      y4m_parameters.frame_rate_denominator = file_reader->time_scale();
+      y4m_parameters.chroma_sample_position = buffer->chroma_sample_position;
+      y4m_parameters.image_format = buffer->image_format;
+      y4m_parameters.bitdepth = static_cast<size_t>(buffer->bitdepth);
+      file_writer = libgav1::FileWriter::Open(
+          options.output_file_name, options.output_file_type, &y4m_parameters);
+      if (file_writer == nullptr) {
+        fprintf(stderr, "Cannot open output file!\n");
+        return EXIT_FAILURE;
+      }
+    }
+
+    if (!limit_reached && file_writer != nullptr &&
+        !file_writer->WriteFrame(*buffer)) {
+      fprintf(stderr, "Error writing output file.\n");
+      return EXIT_FAILURE;
+    }
+    if (options.limit > 0 && options.limit == decoded_frames) {
+      limit_reached = true;
+      if (input_buffer != nullptr) {
+        input_buffers.ReleaseInputBuffer(input_buffer);
+      }
+      input_buffer = nullptr;
+    }
+  } while (input_buffer != nullptr ||
+           (!file_reader->IsEndOfFile() && !limit_reached) ||
+           !dequeue_finished);
+  timing.dequeue = absl::Now() - decode_loop_start - timing.input;
+
+  if (record_frame_timing) {
+    // Note timing for frame parallel will be skewed by the time spent queueing
+    // additional frames and in the output queue waiting for previous frames,
+    // the values reported won't be that meaningful.
+    fprintf(frame_timing_file.get(), "frame number\tdecode time us\n");
+    for (size_t i = 0; i < frame_timing.size(); ++i) {
+      const int decode_time_us = static_cast<int>(absl::ToInt64Microseconds(
+          frame_timing[i].dequeue - frame_timing[i].enqueue));
+      fprintf(frame_timing_file.get(), "%zu\t%d\n", i, decode_time_us);
+    }
+  }
+
+  if (options.verbose > 0) {
+    fprintf(stderr, "time to read input: %d us\n",
+            static_cast<int>(absl::ToInt64Microseconds(timing.input)));
+    const int decode_time_us =
+        static_cast<int>(absl::ToInt64Microseconds(timing.dequeue));
+    const double decode_fps =
+        (decode_time_us == 0) ? 0.0 : 1.0e6 * decoded_frames / decode_time_us;
+    fprintf(stderr, "time to decode input: %d us (%d frames, %.2f fps)\n",
+            decode_time_us, decoded_frames, decode_fps);
+  }
+
+  return EXIT_SUCCESS;
+}
diff --git a/libgav1/examples/gav1_decode_cv_pixel_buffer_pool.cc b/libgav1/examples/gav1_decode_cv_pixel_buffer_pool.cc
new file mode 100644
index 0000000..6aa4e61
--- /dev/null
+++ b/libgav1/examples/gav1_decode_cv_pixel_buffer_pool.cc
@@ -0,0 +1,278 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "examples/gav1_decode_cv_pixel_buffer_pool.h"
+
+#include <cassert>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <memory>
+#include <new>
+#include <type_traits>
+
+namespace {
+
+struct CFTypeDeleter {
+  void operator()(CFTypeRef cf) const { CFRelease(cf); }
+};
+
+using UniqueCFNumberRef =
+    std::unique_ptr<std::remove_pointer<CFNumberRef>::type, CFTypeDeleter>;
+
+using UniqueCFDictionaryRef =
+    std::unique_ptr<std::remove_pointer<CFDictionaryRef>::type, CFTypeDeleter>;
+
+}  // namespace
+
+extern "C" {
+
+libgav1::StatusCode Gav1DecodeOnCVPixelBufferSizeChanged(
+    void* callback_private_data, int bitdepth,
+    libgav1::ImageFormat image_format, int width, int height, int left_border,
+    int right_border, int top_border, int bottom_border, int stride_alignment) {
+  auto* buffer_pool =
+      static_cast<Gav1DecodeCVPixelBufferPool*>(callback_private_data);
+  return buffer_pool->OnCVPixelBufferSizeChanged(
+      bitdepth, image_format, width, height, left_border, right_border,
+      top_border, bottom_border, stride_alignment);
+}
+
+libgav1::StatusCode Gav1DecodeGetCVPixelBuffer(
+    void* callback_private_data, int bitdepth,
+    libgav1::ImageFormat image_format, int width, int height, int left_border,
+    int right_border, int top_border, int bottom_border, int stride_alignment,
+    libgav1::FrameBuffer* frame_buffer) {
+  auto* buffer_pool =
+      static_cast<Gav1DecodeCVPixelBufferPool*>(callback_private_data);
+  return buffer_pool->GetCVPixelBuffer(
+      bitdepth, image_format, width, height, left_border, right_border,
+      top_border, bottom_border, stride_alignment, frame_buffer);
+}
+
+void Gav1DecodeReleaseCVPixelBuffer(void* callback_private_data,
+                                    void* buffer_private_data) {
+  auto* buffer_pool =
+      static_cast<Gav1DecodeCVPixelBufferPool*>(callback_private_data);
+  buffer_pool->ReleaseCVPixelBuffer(buffer_private_data);
+}
+
+}  // extern "C"
+
+// static
+std::unique_ptr<Gav1DecodeCVPixelBufferPool>
+Gav1DecodeCVPixelBufferPool::Create(size_t num_buffers) {
+  std::unique_ptr<Gav1DecodeCVPixelBufferPool> buffer_pool(
+      new (std::nothrow) Gav1DecodeCVPixelBufferPool(num_buffers));
+  return buffer_pool;
+}
+
+Gav1DecodeCVPixelBufferPool::Gav1DecodeCVPixelBufferPool(size_t num_buffers)
+    : num_buffers_(static_cast<int>(num_buffers)) {}
+
+Gav1DecodeCVPixelBufferPool::~Gav1DecodeCVPixelBufferPool() {
+  CVPixelBufferPoolRelease(pool_);
+}
+
+libgav1::StatusCode Gav1DecodeCVPixelBufferPool::OnCVPixelBufferSizeChanged(
+    int bitdepth, libgav1::ImageFormat image_format, int width, int height,
+    int left_border, int right_border, int top_border, int bottom_border,
+    int stride_alignment) {
+  if (bitdepth != 8 || (image_format != libgav1::kImageFormatYuv420 &&
+                        image_format != libgav1::kImageFormatMonochrome400)) {
+    fprintf(stderr,
+            "Only bitdepth 8, 4:2:0 videos are supported: bitdepth %d, "
+            "image_format: %d.\n",
+            bitdepth, image_format);
+    return libgav1::kStatusUnimplemented;
+  }
+
+  // stride_alignment must be a power of 2.
+  assert((stride_alignment & (stride_alignment - 1)) == 0);
+
+  // The possible keys for CVPixelBufferPool are:
+  //   kCVPixelBufferPoolMinimumBufferCountKey
+  //   kCVPixelBufferPoolMaximumBufferAgeKey
+  //   kCVPixelBufferPoolAllocationThresholdKey
+  const void* pool_keys[] = {kCVPixelBufferPoolMinimumBufferCountKey};
+  const int min_buffer_count = 10;
+  UniqueCFNumberRef cf_min_buffer_count(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &min_buffer_count));
+  if (cf_min_buffer_count == nullptr) {
+    fprintf(stderr, "CFNumberCreate failed.\n");
+    return libgav1::kStatusUnknownError;
+  }
+  const void* pool_values[] = {cf_min_buffer_count.get()};
+  UniqueCFDictionaryRef pool_attributes(CFDictionaryCreate(
+      nullptr, pool_keys, pool_values, 1, &kCFTypeDictionaryKeyCallBacks,
+      &kCFTypeDictionaryValueCallBacks));
+  if (pool_attributes == nullptr) {
+    fprintf(stderr, "CFDictionaryCreate failed.\n");
+    return libgav1::kStatusUnknownError;
+  }
+
+  // The pixelBufferAttributes argument to CVPixelBufferPoolCreate() cannot be
+  // null and must contain the pixel format, width, and height, otherwise
+  // CVPixelBufferPoolCreate() fails with kCVReturnInvalidPixelBufferAttributes
+  // (-6682).
+
+  // I420: kCVPixelFormatType_420YpCbCr8Planar (video range).
+  const int pixel_format = (image_format == libgav1::kImageFormatYuv420)
+                               ? kCVPixelFormatType_420YpCbCr8PlanarFullRange
+                               : kCVPixelFormatType_OneComponent8;
+  UniqueCFNumberRef cf_pixel_format(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &pixel_format));
+  UniqueCFNumberRef cf_width(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &width));
+  UniqueCFNumberRef cf_height(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &height));
+  UniqueCFNumberRef cf_left_border(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &left_border));
+  UniqueCFNumberRef cf_right_border(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &right_border));
+  UniqueCFNumberRef cf_top_border(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &top_border));
+  UniqueCFNumberRef cf_bottom_border(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &bottom_border));
+  UniqueCFNumberRef cf_stride_alignment(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &stride_alignment));
+
+  const void* buffer_keys[] = {
+      kCVPixelBufferPixelFormatTypeKey,
+      kCVPixelBufferWidthKey,
+      kCVPixelBufferHeightKey,
+      kCVPixelBufferExtendedPixelsLeftKey,
+      kCVPixelBufferExtendedPixelsRightKey,
+      kCVPixelBufferExtendedPixelsTopKey,
+      kCVPixelBufferExtendedPixelsBottomKey,
+      kCVPixelBufferBytesPerRowAlignmentKey,
+  };
+  const void* buffer_values[] = {
+      cf_pixel_format.get(),  cf_width.get(),
+      cf_height.get(),        cf_left_border.get(),
+      cf_right_border.get(),  cf_top_border.get(),
+      cf_bottom_border.get(), cf_stride_alignment.get(),
+  };
+  UniqueCFDictionaryRef buffer_attributes(CFDictionaryCreate(
+      kCFAllocatorDefault, buffer_keys, buffer_values, 8,
+      &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks));
+  if (buffer_attributes == nullptr) {
+    fprintf(stderr, "CFDictionaryCreate of buffer_attributes failed.\n");
+    return libgav1::kStatusUnknownError;
+  }
+  CVPixelBufferPoolRef cv_pool;
+  CVReturn ret = CVPixelBufferPoolCreate(
+      /*allocator=*/nullptr, pool_attributes.get(), buffer_attributes.get(),
+      &cv_pool);
+  if (ret != kCVReturnSuccess) {
+    fprintf(stderr, "CVPixelBufferPoolCreate failed: %d.\n",
+            static_cast<int>(ret));
+    return libgav1::kStatusOutOfMemory;
+  }
+  CVPixelBufferPoolRelease(pool_);
+  pool_ = cv_pool;
+  return libgav1::kStatusOk;
+}
+
+libgav1::StatusCode Gav1DecodeCVPixelBufferPool::GetCVPixelBuffer(
+    int bitdepth, libgav1::ImageFormat image_format, int /*width*/,
+    int /*height*/, int /*left_border*/, int /*right_border*/,
+    int /*top_border*/, int /*bottom_border*/, int /*stride_alignment*/,
+    libgav1::FrameBuffer* frame_buffer) {
+  static_cast<void>(bitdepth);
+  assert(bitdepth == 8 && (image_format == libgav1::kImageFormatYuv420 ||
+                           image_format == libgav1::kImageFormatMonochrome400));
+  const bool is_monochrome =
+      (image_format == libgav1::kImageFormatMonochrome400);
+
+  // The dictionary must have kCVPixelBufferPoolAllocationThresholdKey,
+  // otherwise CVPixelBufferPoolCreatePixelBufferWithAuxAttributes() fails with
+  // kCVReturnWouldExceedAllocationThreshold (-6689).
+  UniqueCFNumberRef cf_num_buffers(
+      CFNumberCreate(kCFAllocatorDefault, kCFNumberIntType, &num_buffers_));
+
+  const void* buffer_keys[] = {
+      kCVPixelBufferPoolAllocationThresholdKey,
+  };
+  const void* buffer_values[] = {
+      cf_num_buffers.get(),
+  };
+  UniqueCFDictionaryRef aux_attributes(CFDictionaryCreate(
+      kCFAllocatorDefault, buffer_keys, buffer_values, 1,
+      &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks));
+  if (aux_attributes == nullptr) {
+    fprintf(stderr, "CFDictionaryCreate of aux_attributes failed.\n");
+    return libgav1::kStatusUnknownError;
+  }
+
+  CVPixelBufferRef pixel_buffer;
+  CVReturn ret = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes(
+      /*allocator=*/nullptr, pool_, aux_attributes.get(), &pixel_buffer);
+  if (ret != kCVReturnSuccess) {
+    fprintf(stderr,
+            "CVPixelBufferPoolCreatePixelBufferWithAuxAttributes failed: %d.\n",
+            static_cast<int>(ret));
+    return libgav1::kStatusOutOfMemory;
+  }
+
+  ret = CVPixelBufferLockBaseAddress(pixel_buffer, /*lockFlags=*/0);
+  if (ret != kCVReturnSuccess) {
+    fprintf(stderr, "CVPixelBufferLockBaseAddress failed: %d.\n",
+            static_cast<int>(ret));
+    CFRelease(pixel_buffer);
+    return libgav1::kStatusUnknownError;
+  }
+
+  // If the pixel format type is kCVPixelFormatType_OneComponent8, the pixel
+  // buffer is nonplanar (CVPixelBufferIsPlanar returns false and
+  // CVPixelBufferGetPlaneCount returns 0), but
+  // CVPixelBufferGetBytesPerRowOfPlane and CVPixelBufferGetBaseAddressOfPlane
+  // still work for plane index 0, even though the documentation says they
+  // return NULL for nonplanar pixel buffers.
+  frame_buffer->stride[0] =
+      static_cast<int>(CVPixelBufferGetBytesPerRowOfPlane(pixel_buffer, 0));
+  frame_buffer->plane[0] = static_cast<uint8_t*>(
+      CVPixelBufferGetBaseAddressOfPlane(pixel_buffer, 0));
+  if (is_monochrome) {
+    frame_buffer->stride[1] = 0;
+    frame_buffer->stride[2] = 0;
+    frame_buffer->plane[1] = nullptr;
+    frame_buffer->plane[2] = nullptr;
+  } else {
+    frame_buffer->stride[1] =
+        static_cast<int>(CVPixelBufferGetBytesPerRowOfPlane(pixel_buffer, 1));
+    frame_buffer->stride[2] =
+        static_cast<int>(CVPixelBufferGetBytesPerRowOfPlane(pixel_buffer, 2));
+    frame_buffer->plane[1] = static_cast<uint8_t*>(
+        CVPixelBufferGetBaseAddressOfPlane(pixel_buffer, 1));
+    frame_buffer->plane[2] = static_cast<uint8_t*>(
+        CVPixelBufferGetBaseAddressOfPlane(pixel_buffer, 2));
+  }
+  frame_buffer->private_data = pixel_buffer;
+
+  return libgav1::kStatusOk;
+}
+
+void Gav1DecodeCVPixelBufferPool::ReleaseCVPixelBuffer(
+    void* buffer_private_data) {
+  auto const pixel_buffer = static_cast<CVPixelBufferRef>(buffer_private_data);
+  CVReturn ret =
+      CVPixelBufferUnlockBaseAddress(pixel_buffer, /*unlockFlags=*/0);
+  if (ret != kCVReturnSuccess) {
+    fprintf(stderr, "%s:%d: CVPixelBufferUnlockBaseAddress failed: %d.\n",
+            __FILE__, __LINE__, static_cast<int>(ret));
+    abort();
+  }
+  CFRelease(pixel_buffer);
+}
diff --git a/libgav1/examples/gav1_decode_cv_pixel_buffer_pool.h b/libgav1/examples/gav1_decode_cv_pixel_buffer_pool.h
new file mode 100644
index 0000000..7aee324
--- /dev/null
+++ b/libgav1/examples/gav1_decode_cv_pixel_buffer_pool.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_EXAMPLES_GAV1_DECODE_CV_PIXEL_BUFFER_POOL_H_
+#define LIBGAV1_EXAMPLES_GAV1_DECODE_CV_PIXEL_BUFFER_POOL_H_
+
+#include <CoreVideo/CoreVideo.h>
+
+#include <cstddef>
+#include <memory>
+
+#include "gav1/frame_buffer.h"
+
+extern "C" libgav1::StatusCode Gav1DecodeOnCVPixelBufferSizeChanged(
+    void* callback_private_data, int bitdepth,
+    libgav1::ImageFormat image_format, int width, int height, int left_border,
+    int right_border, int top_border, int bottom_border, int stride_alignment);
+
+extern "C" libgav1::StatusCode Gav1DecodeGetCVPixelBuffer(
+    void* callback_private_data, int bitdepth,
+    libgav1::ImageFormat image_format, int width, int height, int left_border,
+    int right_border, int top_border, int bottom_border, int stride_alignment,
+    libgav1::FrameBuffer* frame_buffer);
+
+extern "C" void Gav1DecodeReleaseCVPixelBuffer(void* callback_private_data,
+                                               void* buffer_private_data);
+
+class Gav1DecodeCVPixelBufferPool {
+ public:
+  static std::unique_ptr<Gav1DecodeCVPixelBufferPool> Create(
+      size_t num_buffers);
+
+  // Not copyable or movable.
+  Gav1DecodeCVPixelBufferPool(const Gav1DecodeCVPixelBufferPool&) = delete;
+  Gav1DecodeCVPixelBufferPool& operator=(const Gav1DecodeCVPixelBufferPool&) =
+      delete;
+
+  ~Gav1DecodeCVPixelBufferPool();
+
+  libgav1::StatusCode OnCVPixelBufferSizeChanged(
+      int bitdepth, libgav1::ImageFormat image_format, int width, int height,
+      int left_border, int right_border, int top_border, int bottom_border,
+      int stride_alignment);
+
+  libgav1::StatusCode GetCVPixelBuffer(int bitdepth,
+                                       libgav1::ImageFormat image_format,
+                                       int width, int height, int left_border,
+                                       int right_border, int top_border,
+                                       int bottom_border, int stride_alignment,
+                                       libgav1::FrameBuffer* frame_buffer);
+  void ReleaseCVPixelBuffer(void* buffer_private_data);
+
+ private:
+  Gav1DecodeCVPixelBufferPool(size_t num_buffers);
+
+  CVPixelBufferPoolRef pool_ = nullptr;
+  const int num_buffers_;
+};
+
+#endif  // LIBGAV1_EXAMPLES_GAV1_DECODE_CV_PIXEL_BUFFER_POOL_H_
diff --git a/libgav1/examples/ivf_parser.cc b/libgav1/examples/ivf_parser.cc
new file mode 100644
index 0000000..f8adb14
--- /dev/null
+++ b/libgav1/examples/ivf_parser.cc
@@ -0,0 +1,96 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "examples/ivf_parser.h"
+
+#include <cstdio>
+#include <cstring>
+
+#include "examples/file_reader_constants.h"
+#include "examples/logging.h"
+
+namespace libgav1 {
+namespace {
+
+size_t ReadLittleEndian16(const uint8_t* const buffer) {
+  size_t value = buffer[1] << 8;
+  value |= buffer[0];
+  return value;
+}
+
+size_t ReadLittleEndian32(const uint8_t* const buffer) {
+  size_t value = buffer[3] << 24;
+  value |= buffer[2] << 16;
+  value |= buffer[1] << 8;
+  value |= buffer[0];
+  return value;
+}
+
+}  // namespace
+
+bool ParseIvfFileHeader(const uint8_t* const header_buffer,
+                        IvfFileHeader* const ivf_file_header) {
+  if (header_buffer == nullptr || ivf_file_header == nullptr) return false;
+
+  if (memcmp(kIvfSignature, header_buffer, 4) != 0) {
+    return false;
+  }
+
+  // Verify header version and length.
+  const size_t ivf_header_version = ReadLittleEndian16(&header_buffer[4]);
+  if (ivf_header_version != kIvfHeaderVersion) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Unexpected IVF version");
+  }
+
+  const size_t ivf_header_size = ReadLittleEndian16(&header_buffer[6]);
+  if (ivf_header_size != kIvfFileHeaderSize) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Invalid IVF file header size");
+    return false;
+  }
+
+  if (memcmp(kAv1FourCcLower, &header_buffer[8], 4) != 0 &&
+      memcmp(kAv1FourCcUpper, &header_buffer[8], 4) != 0) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Unsupported codec 4CC");
+    return false;
+  }
+
+  ivf_file_header->width = ReadLittleEndian16(&header_buffer[12]);
+  ivf_file_header->height = ReadLittleEndian16(&header_buffer[14]);
+  ivf_file_header->frame_rate_numerator =
+      ReadLittleEndian32(&header_buffer[16]);
+  ivf_file_header->frame_rate_denominator =
+      ReadLittleEndian32(&header_buffer[20]);
+
+  return true;
+}
+
+bool ParseIvfFrameHeader(const uint8_t* const header_buffer,
+                         IvfFrameHeader* const ivf_frame_header) {
+  if (header_buffer == nullptr || ivf_frame_header == nullptr) return false;
+
+  ivf_frame_header->frame_size = ReadLittleEndian32(header_buffer);
+  if (ivf_frame_header->frame_size > kMaxTemporalUnitSize) {
+    LIBGAV1_EXAMPLES_LOG_ERROR("Temporal Unit size exceeds maximum");
+    return false;
+  }
+
+  ivf_frame_header->timestamp = ReadLittleEndian32(&header_buffer[4]);
+  const uint64_t timestamp_hi =
+      static_cast<uint64_t>(ReadLittleEndian32(&header_buffer[8])) << 32;
+  ivf_frame_header->timestamp |= timestamp_hi;
+
+  return true;
+}
+
+}  // namespace libgav1
diff --git a/libgav1/examples/ivf_parser.h b/libgav1/examples/ivf_parser.h
new file mode 100644
index 0000000..b6bbc59
--- /dev/null
+++ b/libgav1/examples/ivf_parser.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_EXAMPLES_IVF_PARSER_H_
+#define LIBGAV1_EXAMPLES_IVF_PARSER_H_
+
+#include <cstddef>
+#include <cstdint>
+
+namespace libgav1 {
+
+struct IvfFileHeader {
+  IvfFileHeader() = default;
+  IvfFileHeader(const IvfFileHeader& rhs) = default;
+  IvfFileHeader& operator=(const IvfFileHeader& rhs) = default;
+  IvfFileHeader(IvfFileHeader&& rhs) = default;
+  IvfFileHeader& operator=(IvfFileHeader&& rhs) = default;
+
+  size_t width = 0;
+  size_t height = 0;
+  size_t frame_rate_numerator = 0;
+  size_t frame_rate_denominator = 0;
+};
+
+struct IvfFrameHeader {
+  IvfFrameHeader() = default;
+  IvfFrameHeader(const IvfFrameHeader& rhs) = default;
+  IvfFrameHeader& operator=(const IvfFrameHeader& rhs) = default;
+  IvfFrameHeader(IvfFrameHeader&& rhs) = default;
+  IvfFrameHeader& operator=(IvfFrameHeader&& rhs) = default;
+
+  size_t frame_size = 0;
+  int64_t timestamp = 0;
+};
+
+bool ParseIvfFileHeader(const uint8_t* header_buffer,
+                        IvfFileHeader* ivf_file_header);
+
+bool ParseIvfFrameHeader(const uint8_t* header_buffer,
+                         IvfFrameHeader* ivf_frame_header);
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_EXAMPLES_IVF_PARSER_H_
diff --git a/libgav1/examples/libgav1_examples.cmake b/libgav1/examples/libgav1_examples.cmake
new file mode 100644
index 0000000..1f949f3
--- /dev/null
+++ b/libgav1/examples/libgav1_examples.cmake
@@ -0,0 +1,63 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_EXAMPLES_LIBGAV1_EXAMPLES_CMAKE_)
+  return()
+endif() # LIBGAV1_EXAMPLES_LIBGAV1_EXAMPLES_CMAKE_
+set(LIBGAV1_EXAMPLES_LIBGAV1_EXAMPLES_CMAKE_ 1)
+
+set(libgav1_file_reader_sources "${libgav1_examples}/file_reader.cc"
+                                "${libgav1_examples}/file_reader.h"
+                                "${libgav1_examples}/file_reader_constants.cc"
+                                "${libgav1_examples}/file_reader_constants.h"
+                                "${libgav1_examples}/file_reader_factory.cc"
+                                "${libgav1_examples}/file_reader_factory.h"
+                                "${libgav1_examples}/file_reader_interface.h"
+                                "${libgav1_examples}/ivf_parser.cc"
+                                "${libgav1_examples}/ivf_parser.h"
+                                "${libgav1_examples}/logging.h")
+
+set(libgav1_file_writer_sources "${libgav1_examples}/file_writer.cc"
+                                "${libgav1_examples}/file_writer.h"
+                                "${libgav1_examples}/logging.h")
+
+set(libgav1_decode_sources "${libgav1_examples}/gav1_decode.cc")
+
+macro(libgav1_add_examples_targets)
+  libgav1_add_library(NAME libgav1_file_reader TYPE OBJECT SOURCES
+                      ${libgav1_file_reader_sources} DEFINES ${libgav1_defines}
+                      INCLUDES ${libgav1_include_paths})
+
+  libgav1_add_library(NAME libgav1_file_writer TYPE OBJECT SOURCES
+                      ${libgav1_file_writer_sources} DEFINES ${libgav1_defines}
+                      INCLUDES ${libgav1_include_paths})
+
+  libgav1_add_executable(NAME
+                         gav1_decode
+                         SOURCES
+                         ${libgav1_decode_sources}
+                         DEFINES
+                         ${libgav1_defines}
+                         INCLUDES
+                         ${libgav1_include_paths}
+                         ${libgav1_gtest_include_paths}
+                         OBJLIB_DEPS
+                         libgav1_file_reader
+                         libgav1_file_writer
+                         LIB_DEPS
+                         absl::strings
+                         absl::str_format_internal
+                         absl::time
+                         ${libgav1_dependency})
+endmacro()
diff --git a/libgav1/examples/logging.h b/libgav1/examples/logging.h
new file mode 100644
index 0000000..c0bcad7
--- /dev/null
+++ b/libgav1/examples/logging.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_EXAMPLES_LOGGING_H_
+#define LIBGAV1_EXAMPLES_LOGGING_H_
+
+#include <cstddef>
+#include <cstdio>
+
+namespace libgav1 {
+namespace examples {
+
+#if !defined(LIBGAV1_EXAMPLES_ENABLE_LOGGING)
+#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION)
+#define LIBGAV1_EXAMPLES_ENABLE_LOGGING 0
+#else
+#define LIBGAV1_EXAMPLES_ENABLE_LOGGING 1
+#endif
+#endif
+
+#if LIBGAV1_EXAMPLES_ENABLE_LOGGING
+
+// Compile-time function to get the 'base' file_name, that is, the part of
+// a file_name after the last '/' or '\' path separator. The search starts at
+// the end of the string; the second parameter is the length of the string.
+constexpr const char* Basename(const char* file_name, size_t offset) {
+  return (offset == 0 || file_name[offset - 1] == '/' ||
+          file_name[offset - 1] == '\\')
+             ? file_name + offset
+             : Basename(file_name, offset - 1);
+}
+
+#define LIBGAV1_EXAMPLES_LOG_ERROR(error_string)                              \
+  do {                                                                        \
+    constexpr const char* libgav1_examples_basename =                         \
+        ::libgav1::examples::Basename(__FILE__, sizeof(__FILE__) - 1);        \
+    fprintf(stderr, "%s:%d (%s): %s.\n", libgav1_examples_basename, __LINE__, \
+            __func__, error_string);                                          \
+  } while (false)
+
+#else  // !LIBGAV1_EXAMPLES_ENABLE_LOGGING
+
+#define LIBGAV1_EXAMPLES_LOG_ERROR(error_string) \
+  do {                                           \
+  } while (false)
+
+#endif  // LIBGAV1_EXAMPLES_ENABLE_LOGGING
+
+}  // namespace examples
+}  // namespace libgav1
+
+#endif  // LIBGAV1_EXAMPLES_LOGGING_H_
diff --git a/libgav1/src/buffer_pool.cc b/libgav1/src/buffer_pool.cc
index 63312ef..c1a5606 100644
--- a/libgav1/src/buffer_pool.cc
+++ b/libgav1/src/buffer_pool.cc
@@ -18,6 +18,7 @@
 #include <cstring>
 
 #include "src/utils/common.h"
+#include "src/utils/constants.h"
 #include "src/utils/logging.h"
 
 namespace libgav1 {
@@ -36,19 +37,28 @@
 
 }  // namespace
 
-RefCountedBuffer::RefCountedBuffer() {
-  memset(&raw_frame_buffer_, 0, sizeof(raw_frame_buffer_));
-}
+RefCountedBuffer::RefCountedBuffer() = default;
 
 RefCountedBuffer::~RefCountedBuffer() = default;
 
 bool RefCountedBuffer::Realloc(int bitdepth, bool is_monochrome, int width,
                                int height, int subsampling_x, int subsampling_y,
-                               int border, int byte_alignment) {
-  return yuv_buffer_.Realloc(bitdepth, is_monochrome, width, height,
-                             subsampling_x, subsampling_y, border,
-                             byte_alignment, pool_->get_frame_buffer_,
-                             pool_->callback_private_data_, &raw_frame_buffer_);
+                               int left_border, int right_border,
+                               int top_border, int bottom_border) {
+  // The YuvBuffer::Realloc() could call the get frame buffer callback which
+  // will need to be thread safe. So we ensure that we only call Realloc() once
+  // at any given time.
+  std::lock_guard<std::mutex> lock(pool_->mutex_);
+  assert(!buffer_private_data_valid_);
+  if (!yuv_buffer_.Realloc(
+          bitdepth, is_monochrome, width, height, subsampling_x, subsampling_y,
+          left_border, right_border, top_border, bottom_border,
+          pool_->get_frame_buffer_, pool_->callback_private_data_,
+          &buffer_private_data_)) {
+    return false;
+  }
+  buffer_private_data_valid_ = true;
+  return true;
 }
 
 bool RefCountedBuffer::SetFrameDimensions(const ObuFrameHeader& frame_header) {
@@ -59,13 +69,13 @@
   render_height_ = frame_header.render_height;
   rows4x4_ = frame_header.rows4x4;
   columns4x4_ = frame_header.columns4x4;
-  const int rows4x4_half = DivideBy2(rows4x4_);
-  const int columns4x4_half = DivideBy2(columns4x4_);
-  if (!motion_field_reference_frame_.Reset(rows4x4_half, columns4x4_half,
-                                           /*zero_initialize=*/false) ||
-      !motion_field_mv_.Reset(rows4x4_half, columns4x4_half,
-                              /*zero_initialize=*/false)) {
-    return false;
+  if (frame_header.refresh_frame_flags != 0 &&
+      !IsIntraFrame(frame_header.frame_type)) {
+    const int rows4x4_half = DivideBy2(rows4x4_);
+    const int columns4x4_half = DivideBy2(columns4x4_);
+    if (!reference_info_.Reset(rows4x4_half, columns4x4_half)) {
+      return false;
+    }
   }
   return segmentation_map_.Allocate(rows4x4_, columns4x4_);
 }
@@ -103,55 +113,105 @@
   ptr->pool_->ReturnUnusedBuffer(ptr);
 }
 
-// static
-constexpr int BufferPool::kNumBuffers;
-
-BufferPool::BufferPool(const DecoderSettings& settings) {
-  if (settings.get != nullptr && settings.release != nullptr) {
-    get_frame_buffer_ = settings.get;
-    release_frame_buffer_ = settings.release;
-    callback_private_data_ = settings.callback_private_data;
+BufferPool::BufferPool(
+    FrameBufferSizeChangedCallback on_frame_buffer_size_changed,
+    GetFrameBufferCallback get_frame_buffer,
+    ReleaseFrameBufferCallback release_frame_buffer,
+    void* callback_private_data) {
+  if (get_frame_buffer != nullptr) {
+    // on_frame_buffer_size_changed may be null.
+    assert(release_frame_buffer != nullptr);
+    on_frame_buffer_size_changed_ = on_frame_buffer_size_changed;
+    get_frame_buffer_ = get_frame_buffer;
+    release_frame_buffer_ = release_frame_buffer;
+    callback_private_data_ = callback_private_data;
   } else {
-    internal_frame_buffers_ = InternalFrameBufferList::Create(kNumBuffers);
-    // GetInternalFrameBuffer checks whether its private_data argument is null,
-    // so we don't need to check whether internal_frame_buffers_ is null here.
+    on_frame_buffer_size_changed_ = OnInternalFrameBufferSizeChanged;
     get_frame_buffer_ = GetInternalFrameBuffer;
     release_frame_buffer_ = ReleaseInternalFrameBuffer;
-    callback_private_data_ = internal_frame_buffers_.get();
-  }
-  for (RefCountedBuffer& buffer : buffers_) {
-    buffer.SetBufferPool(this);
+    callback_private_data_ = &internal_frame_buffers_;
   }
 }
 
 BufferPool::~BufferPool() {
-  for (const RefCountedBuffer& buffer : buffers_) {
-    if (buffer.in_use_) {
-      assert(0 && "RefCountedBuffer still in use at destruction time.");
+  for (const auto* buffer : buffers_) {
+    if (buffer->in_use_) {
+      assert(false && "RefCountedBuffer still in use at destruction time.");
       LIBGAV1_DLOG(ERROR, "RefCountedBuffer still in use at destruction time.");
     }
+    delete buffer;
   }
 }
 
+bool BufferPool::OnFrameBufferSizeChanged(int bitdepth,
+                                          Libgav1ImageFormat image_format,
+                                          int width, int height,
+                                          int left_border, int right_border,
+                                          int top_border, int bottom_border) {
+  if (on_frame_buffer_size_changed_ == nullptr) return true;
+  return on_frame_buffer_size_changed_(callback_private_data_, bitdepth,
+                                       image_format, width, height, left_border,
+                                       right_border, top_border, bottom_border,
+                                       /*stride_alignment=*/16) == kStatusOk;
+}
+
 RefCountedBufferPtr BufferPool::GetFreeBuffer() {
-  for (RefCountedBuffer& buffer : buffers_) {
-    if (!buffer.in_use_) {
-      buffer.in_use_ = true;
-      return RefCountedBufferPtr(&buffer, RefCountedBuffer::ReturnToBufferPool);
+  // In frame parallel mode, the GetFreeBuffer() calls from ObuParser all happen
+  // from the same thread serially, but the GetFreeBuffer() call in
+  // DecoderImpl::ApplyFilmGrain can happen from multiple threads at the same
+  // time. So this function has to be thread safe.
+  // TODO(b/142583029): Investigate if the GetFreeBuffer() call in
+  // DecoderImpl::ApplyFilmGrain() call can be serialized so that this function
+  // need not be thread safe.
+  std::unique_lock<std::mutex> lock(mutex_);
+  for (auto buffer : buffers_) {
+    if (!buffer->in_use_) {
+      buffer->in_use_ = true;
+      buffer->progress_row_ = -1;
+      buffer->frame_state_ = kFrameStateUnknown;
+      lock.unlock();
+      return RefCountedBufferPtr(buffer, RefCountedBuffer::ReturnToBufferPool);
     }
   }
+  lock.unlock();
+  auto* const buffer = new (std::nothrow) RefCountedBuffer();
+  if (buffer == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Failed to allocate a new reference counted buffer.");
+    return RefCountedBufferPtr();
+  }
+  buffer->SetBufferPool(this);
+  buffer->in_use_ = true;
+  buffer->progress_row_ = -1;
+  buffer->frame_state_ = kFrameStateUnknown;
+  lock.lock();
+  const bool ok = buffers_.push_back(buffer);
+  lock.unlock();
+  if (!ok) {
+    LIBGAV1_DLOG(
+        ERROR,
+        "Failed to push the new reference counted buffer into the vector.");
+    delete buffer;
+    return RefCountedBufferPtr();
+  }
+  return RefCountedBufferPtr(buffer, RefCountedBuffer::ReturnToBufferPool);
+}
 
-  // We should never run out of free buffers. If we reach here, there is a
-  // reference leak.
-  return RefCountedBufferPtr();
+void BufferPool::Abort() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  for (auto buffer : buffers_) {
+    if (buffer->in_use_) {
+      buffer->Abort();
+    }
+  }
 }
 
 void BufferPool::ReturnUnusedBuffer(RefCountedBuffer* buffer) {
+  std::lock_guard<std::mutex> lock(mutex_);
   assert(buffer->in_use_);
   buffer->in_use_ = false;
-  if (buffer->raw_frame_buffer_.data[0] != nullptr) {
-    release_frame_buffer_(callback_private_data_, &buffer->raw_frame_buffer_);
-    memset(&buffer->raw_frame_buffer_, 0, sizeof(buffer->raw_frame_buffer_));
+  if (buffer->buffer_private_data_valid_) {
+    release_frame_buffer_(callback_private_data_, buffer->buffer_private_data_);
+    buffer->buffer_private_data_valid_ = false;
   }
 }
 
diff --git a/libgav1/src/buffer_pool.h b/libgav1/src/buffer_pool.h
index 4a34e23..f35a633 100644
--- a/libgav1/src/buffer_pool.h
+++ b/libgav1/src/buffer_pool.h
@@ -18,27 +18,38 @@
 #define LIBGAV1_SRC_BUFFER_POOL_H_
 
 #include <array>
+#include <cassert>
+#include <climits>
+#include <condition_variable>  // NOLINT (unapproved c++11 header)
 #include <cstdint>
-#include <memory>
+#include <cstring>
+#include <mutex>  // NOLINT (unapproved c++11 header)
 
-#include "src/decoder_buffer.h"
-#include "src/decoder_settings.h"
 #include "src/dsp/common.h"
-#include "src/frame_buffer.h"
+#include "src/gav1/decoder_buffer.h"
+#include "src/gav1/frame_buffer.h"
 #include "src/internal_frame_buffer_list.h"
-#include "src/obu_parser.h"
 #include "src/symbol_decoder_context.h"
-#include "src/utils/array_2d.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
+#include "src/utils/reference_info.h"
 #include "src/utils/segmentation.h"
 #include "src/utils/segmentation_map.h"
 #include "src/utils/types.h"
+#include "src/utils/vector.h"
 #include "src/yuv_buffer.h"
 
 namespace libgav1 {
 
 class BufferPool;
 
+enum FrameState : uint8_t {
+  kFrameStateUnknown,
+  kFrameStateStarted,
+  kFrameStateParsed,
+  kFrameStateDecoded
+};
+
 // A reference-counted frame buffer. Clients should access it via
 // RefCountedBufferPtr, which manages reference counting transparently.
 class RefCountedBuffer {
@@ -48,34 +59,39 @@
   RefCountedBuffer& operator=(const RefCountedBuffer&) = delete;
 
   // Allocates the YUV buffer. Returns true on success. Returns false on
-  // failure.
+  // failure. This function ensures the thread safety of the |get_frame_buffer_|
+  // call (i.e.) only one |get_frame_buffer_| call will happen at a given time.
+  // TODO(b/142583029): In frame parallel mode, we can require the callbacks to
+  // be thread safe so that we can remove the thread safety of this function and
+  // applications can have fine grained locks.
   //
   // * |width| and |height| are the image dimensions in pixels.
   // * |subsampling_x| and |subsampling_y| (either 0 or 1) specify the
   //   subsampling of the width and height of the chroma planes, respectively.
-  // * |border| is the size of the borders (on all four sides) in pixels.
-  // * |byte_alignment| specifies the additional alignment requirement of the
-  //   data buffers of the Y, U, and V planes. If |byte_alignment| is 0, there
-  //   is no additional alignment requirement. Otherwise, |byte_alignment|
-  //   must be a power of 2 and greater than or equal to 16.
-  //   NOTE: The strides are a multiple of 16. Therefore only the first row in
-  //   each plane is aligned to |byte_alignment|. Subsequent rows are only
-  //   16-byte aligned.
+  // * |left_border|, |right_border|, |top_border|, and |bottom_border| are
+  //   the sizes (in pixels) of the borders on the left, right, top, and
+  //   bottom sides, respectively.
+  //
+  // NOTE: The strides are a multiple of 16. Since the first row in each plane
+  // is 16-byte aligned, subsequent rows are also 16-byte aligned.
   bool Realloc(int bitdepth, bool is_monochrome, int width, int height,
-               int subsampling_x, int subsampling_y, int border,
-               int byte_alignment);
+               int subsampling_x, int subsampling_y, int left_border,
+               int right_border, int top_border, int bottom_border);
 
   YuvBuffer* buffer() { return &yuv_buffer_; }
 
   // Returns the buffer private data set by the get frame buffer callback when
   // it allocated the YUV buffer.
-  void* buffer_private_data() const { return raw_frame_buffer_.private_data; }
+  void* buffer_private_data() const {
+    assert(buffer_private_data_valid_);
+    return buffer_private_data_;
+  }
 
   // NOTE: In the current frame, this is the frame_type syntax element in the
   // frame header. In a reference frame, this implements the RefFrameType array
   // in the spec.
   FrameType frame_type() const { return frame_type_; }
-  void set_frame_type(enum FrameType frame_type) { frame_type_ = frame_type; }
+  void set_frame_type(FrameType frame_type) { frame_type_ = frame_type; }
 
   // The sample position for subsampled streams. This is the
   // chroma_sample_position syntax element in the sequence header.
@@ -85,8 +101,7 @@
   ChromaSamplePosition chroma_sample_position() const {
     return chroma_sample_position_;
   }
-  void set_chroma_sample_position(
-      enum ChromaSamplePosition chroma_sample_position) {
+  void set_chroma_sample_position(ChromaSamplePosition chroma_sample_position) {
     chroma_sample_position_ = chroma_sample_position;
   }
 
@@ -94,19 +109,11 @@
   bool showable_frame() const { return showable_frame_; }
   void set_showable_frame(bool value) { showable_frame_ = value; }
 
-  uint8_t order_hint(ReferenceFrameType reference_frame) const {
-    return order_hint_[reference_frame];
-  }
-  void set_order_hint(ReferenceFrameType reference_frame, uint8_t order_hint) {
-    order_hint_[reference_frame] = order_hint;
-  }
-  void ClearOrderHints() { order_hint_.fill(0); }
-
   // Sets upscaled_width_, frame_width_, frame_height_, render_width_,
   // render_height_, rows4x4_ and columns4x4_ from the corresponding fields
-  // in frame_header. Allocates motion_field_reference_frame_,
-  // motion_field_mv_, and segmentation_map_. Returns true on success, false
-  // on failure.
+  // in frame_header. Allocates reference_info_.motion_field_reference_frame,
+  // reference_info_.motion_field_mv_, and segmentation_map_. Returns true on
+  // success, false on failure.
   bool SetFrameDimensions(const ObuFrameHeader& frame_header);
 
   int32_t upscaled_width() const { return upscaled_width_; }
@@ -119,17 +126,10 @@
   int32_t rows4x4() const { return rows4x4_; }
   int32_t columns4x4() const { return columns4x4_; }
 
-  // Entry at |row|, |column| corresponds to
-  // MfRefFrames[row * 2 + 1][column * 2 + 1] in the spec.
-  ReferenceFrameType* motion_field_reference_frame(int row, int column) {
-    return &motion_field_reference_frame_[row][column];
-  }
-
-  // Entry at |row|, |column| corresponds to
-  // MfMvs[row * 2 + 1][column * 2 + 1] in the spec.
-  MotionVector* motion_field_mv(int row, int column) {
-    return &motion_field_mv_[row][column];
-  }
+  int spatial_id() const { return spatial_id_; }
+  void set_spatial_id(int value) { spatial_id_ = value; }
+  int temporal_id() const { return temporal_id_; }
+  void set_temporal_id(int value) { temporal_id_ = value; }
 
   SegmentationMap* segmentation_map() { return &segmentation_map_; }
   const SegmentationMap* segmentation_map() const { return &segmentation_map_; }
@@ -180,6 +180,99 @@
     film_grain_params_ = params;
   }
 
+  const ReferenceInfo* reference_info() const { return &reference_info_; }
+  ReferenceInfo* reference_info() { return &reference_info_; }
+
+  // This will wake up the WaitUntil*() functions and make them return false.
+  void Abort() {
+    {
+      std::lock_guard<std::mutex> lock(mutex_);
+      abort_ = true;
+    }
+    parsed_condvar_.notify_all();
+    decoded_condvar_.notify_all();
+    progress_row_condvar_.notify_all();
+  }
+
+  void SetFrameState(FrameState frame_state) {
+    {
+      std::lock_guard<std::mutex> lock(mutex_);
+      frame_state_ = frame_state;
+    }
+    if (frame_state == kFrameStateParsed) {
+      parsed_condvar_.notify_all();
+    } else if (frame_state == kFrameStateDecoded) {
+      decoded_condvar_.notify_all();
+      progress_row_condvar_.notify_all();
+    }
+  }
+
+  // Sets the progress of this frame to |progress_row| and notifies any threads
+  // that may be waiting on rows <= |progress_row|.
+  void SetProgress(int progress_row) {
+    {
+      std::lock_guard<std::mutex> lock(mutex_);
+      if (progress_row_ >= progress_row) return;
+      progress_row_ = progress_row;
+    }
+    progress_row_condvar_.notify_all();
+  }
+
+  void MarkFrameAsStarted() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (frame_state_ != kFrameStateUnknown) return;
+    frame_state_ = kFrameStateStarted;
+  }
+
+  // All the WaitUntil* functions will return true if the desired wait state was
+  // reached successfully. If the return value is false, then the caller must
+  // assume that the wait was not successful and try to stop whatever they are
+  // doing as early as possible.
+
+  // Waits until the frame has been parsed.
+  bool WaitUntilParsed() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    while (frame_state_ < kFrameStateParsed && !abort_) {
+      parsed_condvar_.wait(lock);
+    }
+    return !abort_;
+  }
+
+  // Waits until the |progress_row| has been decoded (as indicated either by
+  // |progress_row_| or |frame_state_|). |progress_row_cache| must not be
+  // nullptr and will be populated with the value of |progress_row_| after the
+  // wait.
+  //
+  // Typical usage of |progress_row_cache| is as follows:
+  //  * Initialize |*progress_row_cache| to INT_MIN.
+  //  * Call WaitUntil only if |*progress_row_cache| < |progress_row|.
+  bool WaitUntil(int progress_row, int* progress_row_cache) {
+    // If |progress_row| is negative, it means that the wait is on the top
+    // border to be available. The top border will be available when row 0 has
+    // been decoded. So we can simply wait on row 0 instead.
+    progress_row = std::max(progress_row, 0);
+    std::unique_lock<std::mutex> lock(mutex_);
+    while (progress_row_ < progress_row && frame_state_ != kFrameStateDecoded &&
+           !abort_) {
+      progress_row_condvar_.wait(lock);
+    }
+    // Once |frame_state_| reaches kFrameStateDecoded, |progress_row_| may no
+    // longer be updated. So we set |*progress_row_cache| to INT_MAX in that
+    // case.
+    *progress_row_cache =
+        (frame_state_ != kFrameStateDecoded) ? progress_row_ : INT_MAX;
+    return !abort_;
+  }
+
+  // Waits until the entire frame has been decoded.
+  bool WaitUntilDecoded() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    while (frame_state_ != kFrameStateDecoded && !abort_) {
+      decoded_condvar_.wait(lock);
+    }
+    return !abort_;
+  }
+
  private:
   friend class BufferPool;
 
@@ -190,17 +283,26 @@
   static void ReturnToBufferPool(RefCountedBuffer* ptr);
 
   BufferPool* pool_ = nullptr;
-  FrameBuffer raw_frame_buffer_;
+  bool buffer_private_data_valid_ = false;
+  void* buffer_private_data_ = nullptr;
   YuvBuffer yuv_buffer_;
   bool in_use_ = false;  // Only used by BufferPool.
 
-  enum FrameType frame_type_ = kFrameKey;
-  enum ChromaSamplePosition chroma_sample_position_ =
-      kChromaSamplePositionUnknown;
-  bool showable_frame_ = false;
+  std::mutex mutex_;
+  FrameState frame_state_ = kFrameStateUnknown LIBGAV1_GUARDED_BY(mutex_);
+  int progress_row_ = -1 LIBGAV1_GUARDED_BY(mutex_);
+  // Signaled when progress_row_ is updated or when frame_state_ is set to
+  // kFrameStateDecoded.
+  std::condition_variable progress_row_condvar_;
+  // Signaled when the frame state is set to kFrameStateParsed.
+  std::condition_variable parsed_condvar_;
+  // Signaled when the frame state is set to kFrameStateDecoded.
+  std::condition_variable decoded_condvar_;
+  bool abort_ = false LIBGAV1_GUARDED_BY(mutex_);
 
-  // Note: order_hint_[0] (for kReferenceFrameIntra) is not used.
-  std::array<uint8_t, kNumReferenceFrameTypes> order_hint_ = {};
+  FrameType frame_type_ = kFrameKey;
+  ChromaSamplePosition chroma_sample_position_ = kChromaSamplePositionUnknown;
+  bool showable_frame_ = false;
 
   int32_t upscaled_width_ = 0;
   int32_t frame_width_ = 0;
@@ -209,13 +311,9 @@
   int32_t render_height_ = 0;
   int32_t columns4x4_ = 0;
   int32_t rows4x4_ = 0;
+  int spatial_id_ = 0;
+  int temporal_id_ = 0;
 
-  // Array of size (rows4x4 / 2) x (columns4x4 / 2). Entry at i, j corresponds
-  // to MfRefFrames[i * 2 + 1][j * 2 + 1] in the spec.
-  Array2D<ReferenceFrameType> motion_field_reference_frame_;
-  // Array of size (rows4x4 / 2) x (columns4x4 / 2). Entry at i, j corresponds
-  // to MfMvs[i * 2 + 1][j * 2 + 1] in the spec.
-  Array2D<MotionVector> motion_field_mv_;
   // segmentation_map_ contains a rows4x4_ by columns4x4_ 2D array.
   SegmentationMap segmentation_map_;
 
@@ -233,6 +331,7 @@
   // on feature_enabled only, we also save their values as an optimization.
   Segmentation segmentation_ = {};
   FilmGrainParams film_grain_params_ = {};
+  ReferenceInfo reference_info_;
 };
 
 // RefCountedBufferPtr contains a reference to a RefCountedBuffer.
@@ -247,7 +346,10 @@
 // BufferPool maintains a pool of RefCountedBuffers.
 class BufferPool {
  public:
-  explicit BufferPool(const DecoderSettings& settings);
+  BufferPool(FrameBufferSizeChangedCallback on_frame_buffer_size_changed,
+             GetFrameBufferCallback get_frame_buffer,
+             ReleaseFrameBufferCallback release_frame_buffer,
+             void* callback_private_data);
 
   // Not copyable or movable.
   BufferPool(const BufferPool&) = delete;
@@ -255,26 +357,37 @@
 
   ~BufferPool();
 
-  // Finds a free buffer in the buffer pool and returns a reference to the
-  // free buffer. If there is no free buffer, returns a null pointer.
+  LIBGAV1_MUST_USE_RESULT bool OnFrameBufferSizeChanged(
+      int bitdepth, Libgav1ImageFormat image_format, int width, int height,
+      int left_border, int right_border, int top_border, int bottom_border);
+
+  // Finds a free buffer in the buffer pool and returns a reference to the free
+  // buffer. If there is no free buffer, returns a null pointer. This function
+  // is thread safe.
   RefCountedBufferPtr GetFreeBuffer();
 
+  // Aborts all the buffers that are in use.
+  void Abort();
+
  private:
   friend class RefCountedBuffer;
 
-  // Reference frames + 1 scratch frame (for either the current frame or the
-  // film grain frame).
-  static constexpr int kNumBuffers = kNumReferenceFrameTypes + 1;
-
   // Returns an unused buffer to the buffer pool. Called by RefCountedBuffer
-  // only.
+  // only. This function is thread safe.
   void ReturnUnusedBuffer(RefCountedBuffer* buffer);
 
-  RefCountedBuffer buffers_[kNumBuffers];
+  // Used to make the following functions thread safe: GetFreeBuffer(),
+  // ReturnUnusedBuffer(), RefCountedBuffer::Realloc().
+  std::mutex mutex_;
 
-  std::unique_ptr<InternalFrameBufferList> internal_frame_buffers_;
+  // Storing a RefCountedBuffer object in a Vector is complicated because of the
+  // copy/move semantics. So the simplest way around that is to store a list of
+  // pointers in the vector.
+  Vector<RefCountedBuffer*> buffers_ LIBGAV1_GUARDED_BY(mutex_);
+  InternalFrameBufferList internal_frame_buffers_;
 
   // Frame buffer callbacks.
+  FrameBufferSizeChangedCallback on_frame_buffer_size_changed_;
   GetFrameBufferCallback get_frame_buffer_;
   ReleaseFrameBufferCallback release_frame_buffer_;
   // Private data associated with the frame buffer callbacks.
diff --git a/libgav1/src/decoder.cc b/libgav1/src/decoder.cc
index 9a38dd1..b9e43e0 100644
--- a/libgav1/src/decoder.cc
+++ b/libgav1/src/decoder.cc
@@ -12,10 +12,73 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/decoder.h"
+#include "src/gav1/decoder.h"
+
+#include <memory>
+#include <new>
 
 #include "src/decoder_impl.h"
 
+extern "C" {
+
+Libgav1StatusCode Libgav1DecoderCreate(const Libgav1DecoderSettings* settings,
+                                       Libgav1Decoder** decoder_out) {
+  std::unique_ptr<libgav1::Decoder> cxx_decoder(new (std::nothrow)
+                                                    libgav1::Decoder());
+  if (cxx_decoder == nullptr) return kLibgav1StatusOutOfMemory;
+
+  libgav1::DecoderSettings cxx_settings;
+  cxx_settings.threads = settings->threads;
+  cxx_settings.frame_parallel = settings->frame_parallel != 0;
+  cxx_settings.blocking_dequeue = settings->blocking_dequeue != 0;
+  cxx_settings.on_frame_buffer_size_changed =
+      settings->on_frame_buffer_size_changed;
+  cxx_settings.get_frame_buffer = settings->get_frame_buffer;
+  cxx_settings.release_frame_buffer = settings->release_frame_buffer;
+  cxx_settings.release_input_buffer = settings->release_input_buffer;
+  cxx_settings.callback_private_data = settings->callback_private_data;
+  cxx_settings.output_all_layers = settings->output_all_layers != 0;
+  cxx_settings.operating_point = settings->operating_point;
+  cxx_settings.post_filter_mask = settings->post_filter_mask;
+
+  const Libgav1StatusCode status = cxx_decoder->Init(&cxx_settings);
+  if (status == kLibgav1StatusOk) {
+    *decoder_out = reinterpret_cast<Libgav1Decoder*>(cxx_decoder.release());
+  }
+  return status;
+}
+
+void Libgav1DecoderDestroy(Libgav1Decoder* decoder) {
+  auto* cxx_decoder = reinterpret_cast<libgav1::Decoder*>(decoder);
+  delete cxx_decoder;
+}
+
+Libgav1StatusCode Libgav1DecoderEnqueueFrame(Libgav1Decoder* decoder,
+                                             const uint8_t* data, size_t size,
+                                             int64_t user_private_data,
+                                             void* buffer_private_data) {
+  auto* cxx_decoder = reinterpret_cast<libgav1::Decoder*>(decoder);
+  return cxx_decoder->EnqueueFrame(data, size, user_private_data,
+                                   buffer_private_data);
+}
+
+Libgav1StatusCode Libgav1DecoderDequeueFrame(
+    Libgav1Decoder* decoder, const Libgav1DecoderBuffer** out_ptr) {
+  auto* cxx_decoder = reinterpret_cast<libgav1::Decoder*>(decoder);
+  return cxx_decoder->DequeueFrame(out_ptr);
+}
+
+Libgav1StatusCode Libgav1DecoderSignalEOS(Libgav1Decoder* decoder) {
+  auto* cxx_decoder = reinterpret_cast<libgav1::Decoder*>(decoder);
+  return cxx_decoder->SignalEOS();
+}
+
+int Libgav1DecoderGetMaxBitdepth() {
+  return libgav1::Decoder::GetMaxBitdepth();
+}
+
+}  // extern "C"
+
 namespace libgav1 {
 
 Decoder::Decoder() = default;
@@ -23,27 +86,31 @@
 Decoder::~Decoder() = default;
 
 StatusCode Decoder::Init(const DecoderSettings* const settings) {
-  if (initialized_) return kLibgav1StatusAlready;
+  if (impl_ != nullptr) return kStatusAlready;
   if (settings != nullptr) settings_ = *settings;
-  const StatusCode status = DecoderImpl::Create(&settings_, &impl_);
-  if (status != kLibgav1StatusOk) return status;
-  initialized_ = true;
-  return kLibgav1StatusOk;
+  return DecoderImpl::Create(&settings_, &impl_);
 }
 
 StatusCode Decoder::EnqueueFrame(const uint8_t* data, const size_t size,
-                                 int64_t user_private_data) {
-  if (!initialized_) return kLibgav1StatusNotInitialized;
-  return impl_->EnqueueFrame(data, size, user_private_data);
+                                 int64_t user_private_data,
+                                 void* buffer_private_data) {
+  if (impl_ == nullptr) return kStatusNotInitialized;
+  return impl_->EnqueueFrame(data, size, user_private_data,
+                             buffer_private_data);
 }
 
 StatusCode Decoder::DequeueFrame(const DecoderBuffer** out_ptr) {
-  if (!initialized_) return kLibgav1StatusNotInitialized;
+  if (impl_ == nullptr) return kStatusNotInitialized;
   return impl_->DequeueFrame(out_ptr);
 }
 
-int Decoder::GetMaxAllowedFrames() const {
-  return settings_.frame_parallel ? settings_.threads : 1;
+StatusCode Decoder::SignalEOS() {
+  if (impl_ == nullptr) return kStatusNotInitialized;
+  // In non-frame-parallel mode, we have to release all the references. This
+  // simply means replacing the |impl_| with a new instance so that all the
+  // existing references are released and the state is cleared.
+  impl_ = nullptr;
+  return DecoderImpl::Create(&settings_, &impl_);
 }
 
 // static.
diff --git a/libgav1/src/decoder.h b/libgav1/src/decoder.h
deleted file mode 100644
index 1e3ac1a..0000000
--- a/libgav1/src/decoder.h
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Copyright 2019 The libgav1 Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBGAV1_SRC_DECODER_H_
-#define LIBGAV1_SRC_DECODER_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-
-#include "src/decoder_buffer.h"
-#include "src/decoder_settings.h"
-#include "src/status_code.h"
-#include "src/symbol_visibility.h"
-
-namespace libgav1 {
-
-// Forward declaration.
-class DecoderImpl;
-
-class LIBGAV1_PUBLIC Decoder {
- public:
-  Decoder();
-  ~Decoder();
-
-  // Init must be called exactly once per instance. Subsequent calls will do
-  // nothing. If |settings| is nullptr, the decoder will be initialized with
-  // default settings. Returns kLibgav1StatusOk on success, an error status
-  // otherwise.
-  StatusCode Init(const DecoderSettings* settings);
-
-  // Enqueues a compressed frame to be decoded. Applications can continue
-  // enqueue'ing up to |GetMaxAllowedFrames()|. The decoder can be thought of as
-  // a queue of size |GetMaxAllowedFrames()|. Returns kLibgav1StatusOk on
-  // success and an error status otherwise. Returning an error status here isn't
-  // a fatal error and the decoder can continue decoding further frames. To
-  // signal EOF, call this function with |data| as nullptr and |size| as 0. That
-  // will release all the frames held by the decoder.
-  //
-  // |user_private_data| may be used to asssociate application specific private
-  // data with the compressed frame. It will be copied to the user_private_data
-  // field of the DecoderBuffer returned by the corresponding |DequeueFrame()|
-  // call.
-  //
-  // NOTE: |EnqueueFrame()| does not copy the data. Therefore, after a
-  // successful |EnqueueFrame()| call, the caller must keep the |data| buffer
-  // alive until the corresponding |DequeueFrame()| call returns.
-  StatusCode EnqueueFrame(const uint8_t* data, size_t size,
-                          int64_t user_private_data);
-
-  // Dequeues a decompressed frame. If there are enqueued compressed frames,
-  // decodes one and sets |*out_ptr| to the last displayable frame in the
-  // compressed frame. If there are no displayable frames available, sets
-  // |*out_ptr| to nullptr. Returns an error status if there is an error.
-  StatusCode DequeueFrame(const DecoderBuffer** out_ptr);
-
-  // Returns the maximum number of frames allowed to be enqueued at a time. The
-  // decoder will reject frames beyond this count. If |settings_.frame_parallel|
-  // is false, then this function will always return 1.
-  int GetMaxAllowedFrames() const;
-
-  // Returns the maximum bitdepth that is supported by this decoder.
-  static int GetMaxBitdepth();
-
- private:
-  bool initialized_ = false;
-  DecoderSettings settings_;
-  std::unique_ptr<DecoderImpl> impl_;
-};
-
-}  // namespace libgav1
-
-#endif  // LIBGAV1_SRC_DECODER_H_
diff --git a/libgav1/src/decoder_buffer.h b/libgav1/src/decoder_buffer.h
deleted file mode 100644
index ecd133d..0000000
--- a/libgav1/src/decoder_buffer.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Copyright 2019 The libgav1 Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBGAV1_SRC_DECODER_BUFFER_H_
-#define LIBGAV1_SRC_DECODER_BUFFER_H_
-
-#include <cstdint>
-
-#include "src/frame_buffer.h"
-#include "src/symbol_visibility.h"
-
-// All the declarations in this file are part of the public ABI.
-
-namespace libgav1 {
-
-enum ChromaSamplePosition : uint8_t {
-  kChromaSamplePositionUnknown,
-  kChromaSamplePositionVertical,
-  kChromaSamplePositionColocated,
-  kChromaSamplePositionReserved
-};
-
-enum ImageFormat : uint8_t {
-  kImageFormatYuv420,
-  kImageFormatYuv422,
-  kImageFormatYuv444,
-  kImageFormatMonochrome400
-};
-
-struct LIBGAV1_PUBLIC DecoderBuffer {
-  int NumPlanes() const {
-    return (image_format == kImageFormatMonochrome400) ? 1 : 3;
-  }
-
-  ChromaSamplePosition chroma_sample_position;
-  ImageFormat image_format;
-
-  // TODO(wtc): Add the following members:
-  // - color range
-  //   * studio range: Y [16..235], UV [16..240]
-  //   * full range: (YUV/RGB [0..255]
-  // - CICP Color Primaries (cp)
-  // - CICP Transfer Characteristics (tc)
-  // - CICP Matrix Coefficients (mc)
-
-  // Image storage dimensions.
-  // NOTE: These fields are named w and h in vpx_image_t and aom_image_t.
-  // uint32_t width;  // Stored image width.
-  // uint32_t height;  // Stored image height.
-  int bitdepth;  // Stored image bitdepth.
-
-  // Image display dimensions.
-  // NOTES:
-  // 1. These fields are named d_w and d_h in vpx_image_t and aom_image_t.
-  // 2. libvpx and libaom clients use d_w and d_h much more often than w and h.
-  // 3. These fields can just be stored for the Y plane and the clients can
-  //    calculate the values for the U and V planes if the image format or
-  //    subsampling is exposed.
-  int displayed_width[3];   // Displayed image width.
-  int displayed_height[3];  // Displayed image height.
-
-  int stride[3];
-  uint8_t* plane[3];
-
-  // The |user_private_data| argument passed to Decoder::EnqueueFrame().
-  int64_t user_private_data;
-  // The |private_data| field of FrameBuffer. Set by the get frame buffer
-  // callback when it allocates a frame buffer.
-  void* buffer_private_data;
-};
-
-}  // namespace libgav1
-
-#endif  // LIBGAV1_SRC_DECODER_BUFFER_H_
diff --git a/libgav1/src/decoder_impl.cc b/libgav1/src/decoder_impl.cc
index 5c61993..e40c692 100644
--- a/libgav1/src/decoder_impl.cc
+++ b/libgav1/src/decoder_impl.cc
@@ -24,13 +24,18 @@
 #include "src/dsp/common.h"
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
-#include "src/loop_filter_mask.h"
+#include "src/film_grain.h"
+#include "src/frame_buffer_utils.h"
+#include "src/frame_scratch_buffer.h"
 #include "src/loop_restoration_info.h"
+#include "src/obu_parser.h"
 #include "src/post_filter.h"
 #include "src/prediction_mask.h"
 #include "src/quantizer.h"
+#include "src/threading_strategy.h"
 #include "src/utils/blocking_counter.h"
 #include "src/utils/common.h"
+#include "src/utils/constants.h"
 #include "src/utils/logging.h"
 #include "src/utils/parameter_tree.h"
 #include "src/utils/raw_bit_reader.h"
@@ -44,275 +49,1066 @@
 constexpr int kMaxBlockWidth4x4 = 32;
 constexpr int kMaxBlockHeight4x4 = 32;
 
-// A cleanup helper class that releases the frame buffer reference held in
-// |frame| in the destructor.
-class RefCountedBufferPtrCleanup {
+// Computes the bottom border size in pixels. If CDEF, loop restoration or
+// SuperRes is enabled, adds extra border pixels to facilitate those steps to
+// happen nearly in-place (a few extra rows instead of an entire frame buffer).
+// The logic in this function should match the corresponding logic for
+// |vertical_shift| in the PostFilter constructor.
+int GetBottomBorderPixels(const bool do_cdef, const bool do_restoration,
+                          const bool do_superres, const int subsampling_y) {
+  int extra_border = 0;
+  if (do_cdef) {
+    extra_border += kCdefBorder;
+  } else if (do_restoration) {
+    // If CDEF is enabled, loop restoration is safe without extra border.
+    extra_border += kRestorationVerticalBorder;
+  }
+  if (do_superres) extra_border += kSuperResVerticalBorder;
+  // Double the number of extra bottom border pixels if the bottom border will
+  // be subsampled.
+  extra_border <<= subsampling_y;
+  return Align(kBorderPixels + extra_border, 2);  // Must be a multiple of 2.
+}
+
+// Sets |frame_scratch_buffer->tile_decoding_failed| to true (while holding on
+// to |frame_scratch_buffer->superblock_row_mutex|) and notifies the first
+// |count| condition variables in
+// |frame_scratch_buffer->superblock_row_progress_condvar|.
+void SetFailureAndNotifyAll(FrameScratchBuffer* const frame_scratch_buffer,
+                            int count) {
+  {
+    std::lock_guard<std::mutex> lock(
+        frame_scratch_buffer->superblock_row_mutex);
+    frame_scratch_buffer->tile_decoding_failed = true;
+  }
+  std::condition_variable* const condvars =
+      frame_scratch_buffer->superblock_row_progress_condvar.get();
+  for (int i = 0; i < count; ++i) {
+    condvars[i].notify_one();
+  }
+}
+
+// Helper class that releases the frame scratch buffer in the destructor.
+class FrameScratchBufferReleaser {
  public:
-  explicit RefCountedBufferPtrCleanup(RefCountedBufferPtr* frame)
-      : frame_(*frame) {}
-
-  // Not copyable or movable.
-  RefCountedBufferPtrCleanup(const RefCountedBufferPtrCleanup&) = delete;
-  RefCountedBufferPtrCleanup& operator=(const RefCountedBufferPtrCleanup&) =
-      delete;
-
-  ~RefCountedBufferPtrCleanup() { frame_ = nullptr; }
+  FrameScratchBufferReleaser(
+      FrameScratchBufferPool* frame_scratch_buffer_pool,
+      std::unique_ptr<FrameScratchBuffer>* frame_scratch_buffer)
+      : frame_scratch_buffer_pool_(frame_scratch_buffer_pool),
+        frame_scratch_buffer_(frame_scratch_buffer) {}
+  ~FrameScratchBufferReleaser() {
+    frame_scratch_buffer_pool_->Release(std::move(*frame_scratch_buffer_));
+  }
 
  private:
-  RefCountedBufferPtr& frame_;
+  FrameScratchBufferPool* const frame_scratch_buffer_pool_;
+  std::unique_ptr<FrameScratchBuffer>* const frame_scratch_buffer_;
 };
 
-}  // namespace
-
-void DecoderState::UpdateReferenceFrames(int refresh_frame_flags) {
-  for (int ref_index = 0, mask = refresh_frame_flags; mask != 0;
-       ++ref_index, mask >>= 1) {
-    if ((mask & 1) != 0) {
-      reference_valid[ref_index] = true;
-      reference_frame_id[ref_index] = current_frame_id;
-      reference_frame[ref_index] = current_frame;
-      reference_order_hint[ref_index] = order_hint;
+// Sets the |frame|'s segmentation map for two cases. The third case is handled
+// in Tile::DecodeBlock().
+void SetSegmentationMap(const ObuFrameHeader& frame_header,
+                        const SegmentationMap* prev_segment_ids,
+                        RefCountedBuffer* const frame) {
+  if (!frame_header.segmentation.enabled) {
+    // All segment_id's are 0.
+    frame->segmentation_map()->Clear();
+  } else if (!frame_header.segmentation.update_map) {
+    // Copy from prev_segment_ids.
+    if (prev_segment_ids == nullptr) {
+      // Treat a null prev_segment_ids pointer as if it pointed to a
+      // segmentation map containing all 0s.
+      frame->segmentation_map()->Clear();
+    } else {
+      frame->segmentation_map()->CopyFrom(*prev_segment_ids);
     }
   }
 }
 
-void DecoderState::ClearReferenceFrames() {
-  reference_valid = {};
-  reference_frame_id = {};
-  reference_order_hint = {};
-  for (int ref_index = 0; ref_index < kNumReferenceFrameTypes; ++ref_index) {
-    reference_frame[ref_index] = nullptr;
+StatusCode DecodeTilesNonFrameParallel(
+    const ObuSequenceHeader& sequence_header,
+    const ObuFrameHeader& frame_header,
+    const Vector<std::unique_ptr<Tile>>& tiles,
+    FrameScratchBuffer* const frame_scratch_buffer,
+    PostFilter* const post_filter) {
+  // Decode in superblock row order.
+  const int block_width4x4 = sequence_header.use_128x128_superblock ? 32 : 16;
+  std::unique_ptr<TileScratchBuffer> tile_scratch_buffer =
+      frame_scratch_buffer->tile_scratch_buffer_pool.Get();
+  if (tile_scratch_buffer == nullptr) return kLibgav1StatusOutOfMemory;
+  for (int row4x4 = 0; row4x4 < frame_header.rows4x4;
+       row4x4 += block_width4x4) {
+    for (const auto& tile_ptr : tiles) {
+      if (!tile_ptr->ProcessSuperBlockRow<kProcessingModeParseAndDecode, true>(
+              row4x4, tile_scratch_buffer.get())) {
+        return kLibgav1StatusUnknownError;
+      }
+    }
+    post_filter->ApplyFilteringForOneSuperBlockRow(
+        row4x4, block_width4x4, row4x4 + block_width4x4 >= frame_header.rows4x4,
+        /*do_deblock=*/true);
+  }
+  frame_scratch_buffer->tile_scratch_buffer_pool.Release(
+      std::move(tile_scratch_buffer));
+  return kStatusOk;
+}
+
+StatusCode DecodeTilesThreadedNonFrameParallel(
+    const Vector<std::unique_ptr<Tile>>& tiles,
+    FrameScratchBuffer* const frame_scratch_buffer,
+    PostFilter* const post_filter,
+    BlockingCounterWithStatus* const pending_tiles) {
+  ThreadingStrategy& threading_strategy =
+      frame_scratch_buffer->threading_strategy;
+  const int num_workers = threading_strategy.tile_thread_count();
+  BlockingCounterWithStatus pending_workers(num_workers);
+  std::atomic<int> tile_counter(0);
+  const int tile_count = static_cast<int>(tiles.size());
+  bool tile_decoding_failed = false;
+  // Submit tile decoding jobs to the thread pool.
+  for (int i = 0; i < num_workers; ++i) {
+    threading_strategy.tile_thread_pool()->Schedule([&tiles, tile_count,
+                                                     &tile_counter,
+                                                     &pending_workers,
+                                                     &pending_tiles]() {
+      bool failed = false;
+      int index;
+      while ((index = tile_counter.fetch_add(1, std::memory_order_relaxed)) <
+             tile_count) {
+        if (!failed) {
+          const auto& tile_ptr = tiles[index];
+          if (!tile_ptr->ParseAndDecode()) {
+            LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile_ptr->number());
+            failed = true;
+          }
+        } else {
+          pending_tiles->Decrement(false);
+        }
+      }
+      pending_workers.Decrement(!failed);
+    });
+  }
+  // Have the current thread partake in tile decoding.
+  int index;
+  while ((index = tile_counter.fetch_add(1, std::memory_order_relaxed)) <
+         tile_count) {
+    if (!tile_decoding_failed) {
+      const auto& tile_ptr = tiles[index];
+      if (!tile_ptr->ParseAndDecode()) {
+        LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile_ptr->number());
+        tile_decoding_failed = true;
+      }
+    } else {
+      pending_tiles->Decrement(false);
+    }
+  }
+  // Wait until all the workers are done. This ensures that all the tiles have
+  // been parsed.
+  tile_decoding_failed |= !pending_workers.Wait();
+  // Wait until all the tiles have been decoded.
+  tile_decoding_failed |= !pending_tiles->Wait();
+  if (tile_decoding_failed) return kStatusUnknownError;
+  assert(threading_strategy.post_filter_thread_pool() != nullptr);
+  post_filter->ApplyFilteringThreaded();
+  return kStatusOk;
+}
+
+StatusCode DecodeTilesFrameParallel(
+    const ObuSequenceHeader& sequence_header,
+    const ObuFrameHeader& frame_header,
+    const Vector<std::unique_ptr<Tile>>& tiles,
+    const SymbolDecoderContext& saved_symbol_decoder_context,
+    const SegmentationMap* const prev_segment_ids,
+    FrameScratchBuffer* const frame_scratch_buffer,
+    PostFilter* const post_filter, RefCountedBuffer* const current_frame) {
+  // Parse the frame.
+  for (const auto& tile : tiles) {
+    if (!tile->Parse()) {
+      LIBGAV1_DLOG(ERROR, "Failed to parse tile number: %d\n", tile->number());
+      return kStatusUnknownError;
+    }
+  }
+  if (frame_header.enable_frame_end_update_cdf) {
+    frame_scratch_buffer->symbol_decoder_context = saved_symbol_decoder_context;
+  }
+  current_frame->SetFrameContext(frame_scratch_buffer->symbol_decoder_context);
+  SetSegmentationMap(frame_header, prev_segment_ids, current_frame);
+  // Mark frame as parsed.
+  current_frame->SetFrameState(kFrameStateParsed);
+  std::unique_ptr<TileScratchBuffer> tile_scratch_buffer =
+      frame_scratch_buffer->tile_scratch_buffer_pool.Get();
+  if (tile_scratch_buffer == nullptr) {
+    return kStatusOutOfMemory;
+  }
+  const int block_width4x4 = sequence_header.use_128x128_superblock ? 32 : 16;
+  // Decode in superblock row order (inter prediction in the Tile class will
+  // block until the required superblocks in the reference frame are decoded).
+  for (int row4x4 = 0; row4x4 < frame_header.rows4x4;
+       row4x4 += block_width4x4) {
+    for (const auto& tile_ptr : tiles) {
+      if (!tile_ptr->ProcessSuperBlockRow<kProcessingModeDecodeOnly, false>(
+              row4x4, tile_scratch_buffer.get())) {
+        LIBGAV1_DLOG(ERROR, "Failed to decode tile number: %d\n",
+                     tile_ptr->number());
+        return kStatusUnknownError;
+      }
+    }
+    const int progress_row = post_filter->ApplyFilteringForOneSuperBlockRow(
+        row4x4, block_width4x4, row4x4 + block_width4x4 >= frame_header.rows4x4,
+        /*do_deblock=*/true);
+    if (progress_row >= 0) {
+      current_frame->SetProgress(progress_row);
+    }
+  }
+  // Mark frame as decoded (we no longer care about row-level progress since the
+  // entire frame has been decoded).
+  current_frame->SetFrameState(kFrameStateDecoded);
+  frame_scratch_buffer->tile_scratch_buffer_pool.Release(
+      std::move(tile_scratch_buffer));
+  return kStatusOk;
+}
+
+// Helper function used by DecodeTilesThreadedFrameParallel. Applies the
+// deblocking filter for tile boundaries for the superblock row at |row4x4|.
+void ApplyDeblockingFilterForTileBoundaries(
+    PostFilter* const post_filter, const std::unique_ptr<Tile>* tile_row_base,
+    const ObuFrameHeader& frame_header, int row4x4, int block_width4x4,
+    int tile_columns, bool decode_entire_tiles_in_worker_threads) {
+  // Apply vertical deblock filtering for the first 64 columns of each tile.
+  for (int tile_column = 0; tile_column < tile_columns; ++tile_column) {
+    const Tile& tile = *tile_row_base[tile_column];
+    post_filter->ApplyDeblockFilter(
+        kLoopFilterTypeVertical, row4x4, tile.column4x4_start(),
+        tile.column4x4_start() + kNum4x4InLoopFilterUnit, block_width4x4);
+  }
+  if (decode_entire_tiles_in_worker_threads &&
+      row4x4 == tile_row_base[0]->row4x4_start()) {
+    // This is the first superblock row of a tile row. In this case, apply
+    // horizontal deblock filtering for the entire superblock row.
+    post_filter->ApplyDeblockFilter(kLoopFilterTypeHorizontal, row4x4, 0,
+                                    frame_header.columns4x4, block_width4x4);
+  } else {
+    // Apply horizontal deblock filtering for the first 64 columns of the
+    // first tile.
+    const Tile& first_tile = *tile_row_base[0];
+    post_filter->ApplyDeblockFilter(
+        kLoopFilterTypeHorizontal, row4x4, first_tile.column4x4_start(),
+        first_tile.column4x4_start() + kNum4x4InLoopFilterUnit, block_width4x4);
+    // Apply horizontal deblock filtering for the last 64 columns of the
+    // previous tile and the first 64 columns of the current tile.
+    for (int tile_column = 1; tile_column < tile_columns; ++tile_column) {
+      const Tile& tile = *tile_row_base[tile_column];
+      // If the previous tile has more than 64 columns, then include those
+      // for the horizontal deblock.
+      const Tile& previous_tile = *tile_row_base[tile_column - 1];
+      const int column4x4_start =
+          tile.column4x4_start() -
+          ((tile.column4x4_start() - kNum4x4InLoopFilterUnit !=
+            previous_tile.column4x4_start())
+               ? kNum4x4InLoopFilterUnit
+               : 0);
+      post_filter->ApplyDeblockFilter(
+          kLoopFilterTypeHorizontal, row4x4, column4x4_start,
+          tile.column4x4_start() + kNum4x4InLoopFilterUnit, block_width4x4);
+    }
+    // Apply horizontal deblock filtering for the last 64 columns of the
+    // last tile.
+    const Tile& last_tile = *tile_row_base[tile_columns - 1];
+    // Identify the last column4x4 value and do horizontal filtering for
+    // that column4x4. The value of last column4x4 is the nearest multiple
+    // of 16 that is before tile.column4x4_end().
+    const int column4x4_start = (last_tile.column4x4_end() - 1) & ~15;
+    // If column4x4_start is the same as tile.column4x4_start() then it
+    // means that the last tile has <= 64 columns. So there is nothing left
+    // to deblock (since it was already deblocked in the loop above).
+    if (column4x4_start != last_tile.column4x4_start()) {
+      post_filter->ApplyDeblockFilter(
+          kLoopFilterTypeHorizontal, row4x4, column4x4_start,
+          last_tile.column4x4_end(), block_width4x4);
+    }
   }
 }
 
+// Helper function used by DecodeTilesThreadedFrameParallel. Decodes the
+// superblock row starting at |row4x4| for tile at index |tile_index| in the
+// list of tiles |tiles|. If the decoding is successful, then it does the
+// following:
+//   * Schedule the next superblock row in the current tile column for decoding
+//     (the next superblock row may be in a different tile than the current
+//     one).
+//   * If an entire superblock row of the frame has been decoded, it notifies
+//     the waiters (if there are any).
+void DecodeSuperBlockRowInTile(
+    const Vector<std::unique_ptr<Tile>>& tiles, size_t tile_index, int row4x4,
+    const int superblock_size4x4, const int tile_columns,
+    const int superblock_rows, FrameScratchBuffer* const frame_scratch_buffer,
+    PostFilter* const post_filter, BlockingCounter* const pending_jobs) {
+  std::unique_ptr<TileScratchBuffer> scratch_buffer =
+      frame_scratch_buffer->tile_scratch_buffer_pool.Get();
+  if (scratch_buffer == nullptr) {
+    SetFailureAndNotifyAll(frame_scratch_buffer, superblock_rows);
+    return;
+  }
+  Tile& tile = *tiles[tile_index];
+  const bool ok = tile.ProcessSuperBlockRow<kProcessingModeDecodeOnly, false>(
+      row4x4, scratch_buffer.get());
+  frame_scratch_buffer->tile_scratch_buffer_pool.Release(
+      std::move(scratch_buffer));
+  if (!ok) {
+    SetFailureAndNotifyAll(frame_scratch_buffer, superblock_rows);
+    return;
+  }
+  if (post_filter->DoDeblock()) {
+    // Apply vertical deblock filtering for all the columns in this tile except
+    // for the first 64 columns.
+    post_filter->ApplyDeblockFilter(
+        kLoopFilterTypeVertical, row4x4,
+        tile.column4x4_start() + kNum4x4InLoopFilterUnit, tile.column4x4_end(),
+        superblock_size4x4);
+    // Apply horizontal deblock filtering for all the columns in this tile
+    // except for the first and the last 64 columns.
+    // Note about the last tile of each row: For the last tile, column4x4_end
+    // may not be a multiple of 16. In that case it is still okay to simply
+    // subtract 16 since ApplyDeblockFilter() will only do the filters in
+    // increments of 64 columns (or 32 columns for chroma with subsampling).
+    post_filter->ApplyDeblockFilter(
+        kLoopFilterTypeHorizontal, row4x4,
+        tile.column4x4_start() + kNum4x4InLoopFilterUnit,
+        tile.column4x4_end() - kNum4x4InLoopFilterUnit, superblock_size4x4);
+  }
+  const int superblock_size4x4_log2 = FloorLog2(superblock_size4x4);
+  const int index = row4x4 >> superblock_size4x4_log2;
+  int* const superblock_row_progress =
+      frame_scratch_buffer->superblock_row_progress.get();
+  std::condition_variable* const superblock_row_progress_condvar =
+      frame_scratch_buffer->superblock_row_progress_condvar.get();
+  bool notify;
+  {
+    std::lock_guard<std::mutex> lock(
+        frame_scratch_buffer->superblock_row_mutex);
+    notify = ++superblock_row_progress[index] == tile_columns;
+  }
+  if (notify) {
+    // We are done decoding this superblock row. Notify the post filtering
+    // thread.
+    superblock_row_progress_condvar[index].notify_one();
+  }
+  // Schedule the next superblock row (if one exists).
+  ThreadPool& thread_pool =
+      *frame_scratch_buffer->threading_strategy.thread_pool();
+  const int next_row4x4 = row4x4 + superblock_size4x4;
+  if (!tile.IsRow4x4Inside(next_row4x4)) {
+    tile_index += tile_columns;
+  }
+  if (tile_index >= tiles.size()) return;
+  pending_jobs->IncrementBy(1);
+  thread_pool.Schedule([&tiles, tile_index, next_row4x4, superblock_size4x4,
+                        tile_columns, superblock_rows, frame_scratch_buffer,
+                        post_filter, pending_jobs]() {
+    DecodeSuperBlockRowInTile(tiles, tile_index, next_row4x4,
+                              superblock_size4x4, tile_columns, superblock_rows,
+                              frame_scratch_buffer, post_filter, pending_jobs);
+    pending_jobs->Decrement();
+  });
+}
+
+StatusCode DecodeTilesThreadedFrameParallel(
+    const ObuSequenceHeader& sequence_header,
+    const ObuFrameHeader& frame_header,
+    const Vector<std::unique_ptr<Tile>>& tiles,
+    const SymbolDecoderContext& saved_symbol_decoder_context,
+    const SegmentationMap* const prev_segment_ids,
+    FrameScratchBuffer* const frame_scratch_buffer,
+    PostFilter* const post_filter, RefCountedBuffer* const current_frame) {
+  // Parse the frame.
+  ThreadPool& thread_pool =
+      *frame_scratch_buffer->threading_strategy.thread_pool();
+  std::atomic<int> tile_counter(0);
+  const int tile_count = static_cast<int>(tiles.size());
+  const int num_workers = thread_pool.num_threads();
+  BlockingCounterWithStatus parse_workers(num_workers);
+  // Submit tile parsing jobs to the thread pool.
+  for (int i = 0; i < num_workers; ++i) {
+    thread_pool.Schedule([&tiles, tile_count, &tile_counter, &parse_workers]() {
+      bool failed = false;
+      int index;
+      while ((index = tile_counter.fetch_add(1, std::memory_order_relaxed)) <
+             tile_count) {
+        if (!failed) {
+          const auto& tile_ptr = tiles[index];
+          if (!tile_ptr->Parse()) {
+            LIBGAV1_DLOG(ERROR, "Error parsing tile #%d", tile_ptr->number());
+            failed = true;
+          }
+        }
+      }
+      parse_workers.Decrement(!failed);
+    });
+  }
+
+  // Have the current thread participate in parsing.
+  bool failed = false;
+  int index;
+  while ((index = tile_counter.fetch_add(1, std::memory_order_relaxed)) <
+         tile_count) {
+    if (!failed) {
+      const auto& tile_ptr = tiles[index];
+      if (!tile_ptr->Parse()) {
+        LIBGAV1_DLOG(ERROR, "Error parsing tile #%d", tile_ptr->number());
+        failed = true;
+      }
+    }
+  }
+
+  // Wait until all the parse workers are done. This ensures that all the tiles
+  // have been parsed.
+  if (!parse_workers.Wait() || failed) {
+    return kLibgav1StatusUnknownError;
+  }
+  if (frame_header.enable_frame_end_update_cdf) {
+    frame_scratch_buffer->symbol_decoder_context = saved_symbol_decoder_context;
+  }
+  current_frame->SetFrameContext(frame_scratch_buffer->symbol_decoder_context);
+  SetSegmentationMap(frame_header, prev_segment_ids, current_frame);
+  current_frame->SetFrameState(kFrameStateParsed);
+
+  // Decode the frame.
+  const int block_width4x4 = sequence_header.use_128x128_superblock ? 32 : 16;
+  const int block_width4x4_log2 =
+      sequence_header.use_128x128_superblock ? 5 : 4;
+  const int superblock_rows =
+      (frame_header.rows4x4 + block_width4x4 - 1) >> block_width4x4_log2;
+  if (!frame_scratch_buffer->superblock_row_progress.Resize(superblock_rows) ||
+      !frame_scratch_buffer->superblock_row_progress_condvar.Resize(
+          superblock_rows)) {
+    return kLibgav1StatusOutOfMemory;
+  }
+  int* const superblock_row_progress =
+      frame_scratch_buffer->superblock_row_progress.get();
+  memset(superblock_row_progress, 0,
+         superblock_rows * sizeof(superblock_row_progress[0]));
+  frame_scratch_buffer->tile_decoding_failed = false;
+  const int tile_columns = frame_header.tile_info.tile_columns;
+  const bool decode_entire_tiles_in_worker_threads =
+      num_workers >= tile_columns;
+  BlockingCounter pending_jobs(
+      decode_entire_tiles_in_worker_threads ? num_workers : tile_columns);
+  if (decode_entire_tiles_in_worker_threads) {
+    // Submit tile decoding jobs to the thread pool.
+    tile_counter = 0;
+    for (int i = 0; i < num_workers; ++i) {
+      thread_pool.Schedule([&tiles, tile_count, &tile_counter, &pending_jobs,
+                            frame_scratch_buffer, superblock_rows]() {
+        bool failed = false;
+        int index;
+        while ((index = tile_counter.fetch_add(1, std::memory_order_relaxed)) <
+               tile_count) {
+          if (failed) continue;
+          const auto& tile_ptr = tiles[index];
+          if (!tile_ptr->Decode(
+                  &frame_scratch_buffer->superblock_row_mutex,
+                  frame_scratch_buffer->superblock_row_progress.get(),
+                  frame_scratch_buffer->superblock_row_progress_condvar
+                      .get())) {
+            LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile_ptr->number());
+            failed = true;
+            SetFailureAndNotifyAll(frame_scratch_buffer, superblock_rows);
+          }
+        }
+        pending_jobs.Decrement();
+      });
+    }
+  } else {
+    // Schedule the jobs for first tile row.
+    for (int tile_index = 0; tile_index < tile_columns; ++tile_index) {
+      thread_pool.Schedule([&tiles, tile_index, block_width4x4, tile_columns,
+                            superblock_rows, frame_scratch_buffer, post_filter,
+                            &pending_jobs]() {
+        DecodeSuperBlockRowInTile(
+            tiles, tile_index, 0, block_width4x4, tile_columns, superblock_rows,
+            frame_scratch_buffer, post_filter, &pending_jobs);
+        pending_jobs.Decrement();
+      });
+    }
+  }
+
+  // Current thread will do the post filters.
+  std::condition_variable* const superblock_row_progress_condvar =
+      frame_scratch_buffer->superblock_row_progress_condvar.get();
+  const std::unique_ptr<Tile>* tile_row_base = &tiles[0];
+  for (int row4x4 = 0, index = 0; row4x4 < frame_header.rows4x4;
+       row4x4 += block_width4x4, ++index) {
+    if (!tile_row_base[0]->IsRow4x4Inside(row4x4)) {
+      tile_row_base += tile_columns;
+    }
+    {
+      std::unique_lock<std::mutex> lock(
+          frame_scratch_buffer->superblock_row_mutex);
+      while (superblock_row_progress[index] != tile_columns &&
+             !frame_scratch_buffer->tile_decoding_failed) {
+        superblock_row_progress_condvar[index].wait(lock);
+      }
+      if (frame_scratch_buffer->tile_decoding_failed) break;
+    }
+    if (post_filter->DoDeblock()) {
+      // Apply deblocking filter for the tile boundaries of this superblock row.
+      // The deblocking filter for the internal blocks will be applied in the
+      // tile worker threads. In this thread, we will only have to apply
+      // deblocking filter for the tile boundaries.
+      ApplyDeblockingFilterForTileBoundaries(
+          post_filter, tile_row_base, frame_header, row4x4, block_width4x4,
+          tile_columns, decode_entire_tiles_in_worker_threads);
+    }
+    // Apply all the post filters other than deblocking.
+    const int progress_row = post_filter->ApplyFilteringForOneSuperBlockRow(
+        row4x4, block_width4x4, row4x4 + block_width4x4 >= frame_header.rows4x4,
+        /*do_deblock=*/false);
+    if (progress_row >= 0) {
+      current_frame->SetProgress(progress_row);
+    }
+  }
+  // Wait until all the pending jobs are done. This ensures that all the tiles
+  // have been decoded and wrapped up.
+  pending_jobs.Wait();
+  {
+    std::lock_guard<std::mutex> lock(
+        frame_scratch_buffer->superblock_row_mutex);
+    if (frame_scratch_buffer->tile_decoding_failed) {
+      return kLibgav1StatusUnknownError;
+    }
+  }
+
+  current_frame->SetFrameState(kFrameStateDecoded);
+  return kStatusOk;
+}
+
+}  // namespace
+
 // static
 StatusCode DecoderImpl::Create(const DecoderSettings* settings,
                                std::unique_ptr<DecoderImpl>* output) {
   if (settings->threads <= 0) {
     LIBGAV1_DLOG(ERROR, "Invalid settings->threads: %d.", settings->threads);
-    return kLibgav1StatusInvalidArgument;
+    return kStatusInvalidArgument;
+  }
+  if (settings->frame_parallel) {
+    if (settings->release_input_buffer == nullptr) {
+      LIBGAV1_DLOG(ERROR,
+                   "release_input_buffer callback must not be null when "
+                   "frame_parallel is true.");
+      return kStatusInvalidArgument;
+    }
   }
   std::unique_ptr<DecoderImpl> impl(new (std::nothrow) DecoderImpl(settings));
   if (impl == nullptr) {
     LIBGAV1_DLOG(ERROR, "Failed to allocate DecoderImpl.");
-    return kLibgav1StatusOutOfMemory;
+    return kStatusOutOfMemory;
   }
   const StatusCode status = impl->Init();
-  if (status != kLibgav1StatusOk) return status;
+  if (status != kStatusOk) return status;
   *output = std::move(impl);
-  return kLibgav1StatusOk;
+  return kStatusOk;
 }
 
 DecoderImpl::DecoderImpl(const DecoderSettings* settings)
-    : buffer_pool_(*settings), settings_(*settings) {
+    : buffer_pool_(settings->on_frame_buffer_size_changed,
+                   settings->get_frame_buffer, settings->release_frame_buffer,
+                   settings->callback_private_data),
+      settings_(*settings) {
   dsp::DspInit();
-  GenerateWedgeMask(state_.wedge_master_mask.data(), state_.wedge_masks.data());
 }
 
 DecoderImpl::~DecoderImpl() {
-  // The frame buffer references need to be released before |buffer_pool_| is
-  // destroyed.
+  // Clean up and wait until all the threads have stopped. We just have to pass
+  // in a dummy status that is not kStatusOk or kStatusTryAgain to trigger the
+  // path that clears all the threads and structs.
+  SignalFailure(kStatusUnknownError);
+  // Release any other frame buffer references that we may be holding on to.
   ReleaseOutputFrame();
-  assert(state_.current_frame == nullptr);
+  output_frame_queue_.Clear();
   for (auto& reference_frame : state_.reference_frame) {
     reference_frame = nullptr;
   }
 }
 
 StatusCode DecoderImpl::Init() {
-  const int max_allowed_frames =
-      settings_.frame_parallel ? settings_.threads : 1;
-  assert(max_allowed_frames > 0);
-  if (!encoded_frames_.Init(max_allowed_frames)) {
-    LIBGAV1_DLOG(ERROR, "encoded_frames_.Init() failed.");
-    return kLibgav1StatusOutOfMemory;
+  if (!GenerateWedgeMask(&wedge_masks_)) {
+    LIBGAV1_DLOG(ERROR, "GenerateWedgeMask() failed.");
+    return kStatusOutOfMemory;
   }
-  return kLibgav1StatusOk;
+  if (!output_frame_queue_.Init(kMaxLayers)) {
+    LIBGAV1_DLOG(ERROR, "output_frame_queue_.Init() failed.");
+    return kStatusOutOfMemory;
+  }
+  return kStatusOk;
+}
+
+StatusCode DecoderImpl::InitializeFrameThreadPoolAndTemporalUnitQueue(
+    const uint8_t* data, size_t size) {
+  is_frame_parallel_ = false;
+  if (settings_.frame_parallel) {
+    DecoderState state;
+    std::unique_ptr<ObuParser> obu(new (std::nothrow) ObuParser(
+        data, size, settings_.operating_point, &buffer_pool_, &state));
+    if (obu == nullptr) {
+      LIBGAV1_DLOG(ERROR, "Failed to allocate OBU parser.");
+      return kStatusOutOfMemory;
+    }
+    RefCountedBufferPtr current_frame;
+    const StatusCode status = obu->ParseOneFrame(&current_frame);
+    if (status != kStatusOk) {
+      LIBGAV1_DLOG(ERROR, "Failed to parse OBU.");
+      return status;
+    }
+    current_frame = nullptr;
+    // We assume that the first frame that was parsed will contain the frame
+    // header. This assumption is usually true in practice. So we will simply
+    // not use frame parallel mode if this is not the case.
+    if (settings_.threads > 1 &&
+        !InitializeThreadPoolsForFrameParallel(
+            settings_.threads, obu->frame_header().tile_info.tile_count,
+            obu->frame_header().tile_info.tile_columns, &frame_thread_pool_,
+            &frame_scratch_buffer_pool_)) {
+      return kStatusOutOfMemory;
+    }
+  }
+  const int max_allowed_frames =
+      (frame_thread_pool_ != nullptr) ? frame_thread_pool_->num_threads() : 1;
+  assert(max_allowed_frames > 0);
+  if (!temporal_units_.Init(max_allowed_frames)) {
+    LIBGAV1_DLOG(ERROR, "temporal_units_.Init() failed.");
+    return kStatusOutOfMemory;
+  }
+  is_frame_parallel_ = frame_thread_pool_ != nullptr;
+  return kStatusOk;
 }
 
 StatusCode DecoderImpl::EnqueueFrame(const uint8_t* data, size_t size,
-                                     int64_t user_private_data) {
-  if (data == nullptr) {
-    // This has to actually flush the decoder.
-    return kLibgav1StatusOk;
+                                     int64_t user_private_data,
+                                     void* buffer_private_data) {
+  if (data == nullptr || size == 0) return kStatusInvalidArgument;
+  if (HasFailure()) return kStatusUnknownError;
+  if (!seen_first_frame_) {
+    seen_first_frame_ = true;
+    const StatusCode status =
+        InitializeFrameThreadPoolAndTemporalUnitQueue(data, size);
+    if (status != kStatusOk) {
+      return SignalFailure(status);
+    }
   }
-  if (encoded_frames_.Full()) {
-    return kLibgav1StatusResourceExhausted;
+  if (temporal_units_.Full()) {
+    return kStatusTryAgain;
   }
-  encoded_frames_.Push(EncodedFrame(data, size, user_private_data));
-  return kLibgav1StatusOk;
+  if (is_frame_parallel_) {
+    return ParseAndSchedule(data, size, user_private_data, buffer_private_data);
+  }
+  TemporalUnit temporal_unit(data, size, user_private_data,
+                             buffer_private_data);
+  temporal_units_.Push(std::move(temporal_unit));
+  return kStatusOk;
+}
+
+StatusCode DecoderImpl::SignalFailure(StatusCode status) {
+  if (status == kStatusOk || status == kStatusTryAgain) return status;
+  // Set the |failure_status_| first so that any pending jobs in
+  // |frame_thread_pool_| will exit right away when the thread pool is being
+  // released below.
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    failure_status_ = status;
+  }
+  // Make sure all waiting threads exit.
+  buffer_pool_.Abort();
+  frame_thread_pool_ = nullptr;
+  while (!temporal_units_.Empty()) {
+    if (settings_.release_input_buffer != nullptr) {
+      settings_.release_input_buffer(
+          settings_.callback_private_data,
+          temporal_units_.Front().buffer_private_data);
+    }
+    temporal_units_.Pop();
+  }
+  return status;
 }
 
 // DequeueFrame() follows the following policy to avoid holding unnecessary
-// frame buffer references in state_.current_frame and output_frame_.
-//
-// 1. state_.current_frame must be null when DequeueFrame() returns (success
-// or failure).
-//
-// 2. output_frame_ must be null when DequeueFrame() returns false.
+// frame buffer references in output_frame_: output_frame_ must be null when
+// DequeueFrame() returns false.
 StatusCode DecoderImpl::DequeueFrame(const DecoderBuffer** out_ptr) {
   if (out_ptr == nullptr) {
     LIBGAV1_DLOG(ERROR, "Invalid argument: out_ptr == nullptr.");
-    return kLibgav1StatusInvalidArgument;
+    return kStatusInvalidArgument;
   }
-  assert(state_.current_frame == nullptr);
   // We assume a call to DequeueFrame() indicates that the caller is no longer
   // using the previous output frame, so we can release it.
   ReleaseOutputFrame();
-  if (encoded_frames_.Empty()) {
-    // No encoded frame to decode. Not an error.
+  if (temporal_units_.Empty()) {
+    // No input frames to decode.
     *out_ptr = nullptr;
-    return kLibgav1StatusOk;
+    return kStatusNothingToDequeue;
   }
-  const EncodedFrame encoded_frame = encoded_frames_.Pop();
-  std::unique_ptr<ObuParser> obu(new (std::nothrow) ObuParser(
-      encoded_frame.data, encoded_frame.size, &state_));
-  if (obu == nullptr) {
-    LIBGAV1_DLOG(ERROR, "Failed to initialize OBU parser.");
-    return kLibgav1StatusOutOfMemory;
-  }
-  if (state_.has_sequence_header) {
-    obu->set_sequence_header(state_.sequence_header);
-  }
-  RefCountedBufferPtrCleanup current_frame_cleanup(&state_.current_frame);
-  RefCountedBufferPtr displayable_frame;
-  StatusCode status;
-  while (obu->HasData()) {
-    state_.current_frame = buffer_pool_.GetFreeBuffer();
-    if (state_.current_frame == nullptr) {
-      LIBGAV1_DLOG(ERROR, "Could not get current_frame from the buffer pool.");
-      return kLibgav1StatusResourceExhausted;
+  TemporalUnit& temporal_unit = temporal_units_.Front();
+  if (!is_frame_parallel_) {
+    // If |output_frame_queue_| is not empty, then return the first frame from
+    // that queue.
+    if (!output_frame_queue_.Empty()) {
+      RefCountedBufferPtr frame = std::move(output_frame_queue_.Front());
+      output_frame_queue_.Pop();
+      buffer_.user_private_data = temporal_unit.user_private_data;
+      if (output_frame_queue_.Empty()) {
+        temporal_units_.Pop();
+      }
+      const StatusCode status = CopyFrameToOutputBuffer(frame);
+      if (status != kStatusOk) {
+        return status;
+      }
+      *out_ptr = &buffer_;
+      return kStatusOk;
     }
+    // Decode the next available temporal unit and return.
+    const StatusCode status = DecodeTemporalUnit(temporal_unit, out_ptr);
+    if (status != kStatusOk) {
+      // In case of failure, discard all the output frames that we may be
+      // holding on references to.
+      output_frame_queue_.Clear();
+    }
+    if (settings_.release_input_buffer != nullptr) {
+      settings_.release_input_buffer(settings_.callback_private_data,
+                                     temporal_unit.buffer_private_data);
+    }
+    if (output_frame_queue_.Empty()) {
+      temporal_units_.Pop();
+    }
+    return status;
+  }
+  {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (settings_.blocking_dequeue) {
+      while (!temporal_unit.decoded && failure_status_ == kStatusOk) {
+        decoded_condvar_.wait(lock);
+      }
+    } else {
+      if (!temporal_unit.decoded && failure_status_ == kStatusOk) {
+        return kStatusTryAgain;
+      }
+    }
+    if (failure_status_ != kStatusOk) {
+      const StatusCode failure_status = failure_status_;
+      lock.unlock();
+      return SignalFailure(failure_status);
+    }
+  }
+  if (settings_.release_input_buffer != nullptr &&
+      !temporal_unit.released_input_buffer) {
+    temporal_unit.released_input_buffer = true;
+    settings_.release_input_buffer(settings_.callback_private_data,
+                                   temporal_unit.buffer_private_data);
+  }
+  if (temporal_unit.status != kStatusOk) {
+    temporal_units_.Pop();
+    return SignalFailure(temporal_unit.status);
+  }
+  if (!temporal_unit.has_displayable_frame) {
+    *out_ptr = nullptr;
+    temporal_units_.Pop();
+    return kStatusOk;
+  }
+  assert(temporal_unit.output_layer_count > 0);
+  StatusCode status = CopyFrameToOutputBuffer(
+      temporal_unit.output_layers[temporal_unit.output_layer_count - 1].frame);
+  temporal_unit.output_layers[temporal_unit.output_layer_count - 1].frame =
+      nullptr;
+  if (status != kStatusOk) {
+    temporal_units_.Pop();
+    return SignalFailure(status);
+  }
+  buffer_.user_private_data = temporal_unit.user_private_data;
+  *out_ptr = &buffer_;
+  if (--temporal_unit.output_layer_count == 0) {
+    temporal_units_.Pop();
+  }
+  return kStatusOk;
+}
 
-    if (!obu->ParseOneFrame()) {
+StatusCode DecoderImpl::ParseAndSchedule(const uint8_t* data, size_t size,
+                                         int64_t user_private_data,
+                                         void* buffer_private_data) {
+  TemporalUnit temporal_unit(data, size, user_private_data,
+                             buffer_private_data);
+  std::unique_ptr<ObuParser> obu(new (std::nothrow) ObuParser(
+      temporal_unit.data, temporal_unit.size, settings_.operating_point,
+      &buffer_pool_, &state_));
+  if (obu == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Failed to allocate OBU parser.");
+    return kStatusOutOfMemory;
+  }
+  if (has_sequence_header_) {
+    obu->set_sequence_header(sequence_header_);
+  }
+  StatusCode status;
+  int position_in_temporal_unit = 0;
+  while (obu->HasData()) {
+    RefCountedBufferPtr current_frame;
+    status = obu->ParseOneFrame(&current_frame);
+    if (status != kStatusOk) {
       LIBGAV1_DLOG(ERROR, "Failed to parse OBU.");
-      return kLibgav1StatusUnknownError;
+      return status;
     }
-    if (std::find_if(obu->obu_headers().begin(), obu->obu_headers().end(),
-                     [](const ObuHeader& obu_header) {
-                       return obu_header.type == kObuSequenceHeader;
-                     }) != obu->obu_headers().end()) {
-      state_.sequence_header = obu->sequence_header();
-      state_.has_sequence_header = true;
+    if (IsNewSequenceHeader(*obu)) {
+      const ObuSequenceHeader& sequence_header = obu->sequence_header();
+      const Libgav1ImageFormat image_format =
+          ComposeImageFormat(sequence_header.color_config.is_monochrome,
+                             sequence_header.color_config.subsampling_x,
+                             sequence_header.color_config.subsampling_y);
+      const int max_bottom_border = GetBottomBorderPixels(
+          /*do_cdef=*/true, /*do_restoration=*/true,
+          /*do_superres=*/true, sequence_header.color_config.subsampling_y);
+      // TODO(vigneshv): This may not be the right place to call this callback
+      // for the frame parallel case. Investigate and fix it.
+      if (!buffer_pool_.OnFrameBufferSizeChanged(
+              sequence_header.color_config.bitdepth, image_format,
+              sequence_header.max_frame_width, sequence_header.max_frame_height,
+              kBorderPixels, kBorderPixels, kBorderPixels, max_bottom_border)) {
+        LIBGAV1_DLOG(ERROR, "buffer_pool_.OnFrameBufferSizeChanged failed.");
+        return kStatusUnknownError;
+      }
+    }
+    // This can happen when there are multiple spatial/temporal layers and if
+    // all the layers are outside the current operating point.
+    if (current_frame == nullptr) {
+      continue;
+    }
+    // Note that we cannot set EncodedFrame.temporal_unit here. It will be set
+    // in the code below after |temporal_unit| is std::move'd into the
+    // |temporal_units_| queue.
+    if (!temporal_unit.frames.emplace_back(obu.get(), state_, current_frame,
+                                           position_in_temporal_unit++)) {
+      LIBGAV1_DLOG(ERROR, "temporal_unit.frames.emplace_back failed.");
+      return kStatusOutOfMemory;
+    }
+    state_.UpdateReferenceFrames(current_frame,
+                                 obu->frame_header().refresh_frame_flags);
+  }
+  // This function cannot fail after this point. So it is okay to move the
+  // |temporal_unit| into |temporal_units_| queue.
+  temporal_units_.Push(std::move(temporal_unit));
+  if (temporal_units_.Back().frames.empty()) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    temporal_units_.Back().has_displayable_frame = false;
+    temporal_units_.Back().decoded = true;
+    return kStatusOk;
+  }
+  for (auto& frame : temporal_units_.Back().frames) {
+    EncodedFrame* const encoded_frame = &frame;
+    encoded_frame->temporal_unit = &temporal_units_.Back();
+    frame_thread_pool_->Schedule([this, encoded_frame]() {
+      if (HasFailure()) return;
+      const StatusCode status = DecodeFrame(encoded_frame);
+      encoded_frame->state = {};
+      encoded_frame->frame = nullptr;
+      TemporalUnit& temporal_unit = *encoded_frame->temporal_unit;
+      std::lock_guard<std::mutex> lock(mutex_);
+      if (failure_status_ != kStatusOk) return;
+      // temporal_unit's status defaults to kStatusOk. So we need to set it only
+      // on error. If |failure_status_| is not kStatusOk at this point, it means
+      // that there has already been a failure. So we don't care about this
+      // subsequent failure.  We will simply return the error code of the first
+      // failure.
+      if (status != kStatusOk) {
+        temporal_unit.status = status;
+        if (failure_status_ == kStatusOk) {
+          failure_status_ = status;
+        }
+      }
+      temporal_unit.decoded =
+          ++temporal_unit.decoded_count == temporal_unit.frames.size();
+      if (temporal_unit.decoded && settings_.output_all_layers &&
+          temporal_unit.output_layer_count > 1) {
+        std::sort(
+            temporal_unit.output_layers,
+            temporal_unit.output_layers + temporal_unit.output_layer_count);
+      }
+      if (temporal_unit.decoded || failure_status_ != kStatusOk) {
+        decoded_condvar_.notify_one();
+      }
+    });
+  }
+  return kStatusOk;
+}
+
+StatusCode DecoderImpl::DecodeFrame(EncodedFrame* const encoded_frame) {
+  const ObuSequenceHeader& sequence_header = encoded_frame->sequence_header;
+  const ObuFrameHeader& frame_header = encoded_frame->frame_header;
+  RefCountedBufferPtr current_frame = std::move(encoded_frame->frame);
+
+  std::unique_ptr<FrameScratchBuffer> frame_scratch_buffer =
+      frame_scratch_buffer_pool_.Get();
+  if (frame_scratch_buffer == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Error when getting FrameScratchBuffer.");
+    return kStatusOutOfMemory;
+  }
+  // |frame_scratch_buffer| will be released when this local variable goes out
+  // of scope (i.e.) on any return path in this function.
+  FrameScratchBufferReleaser frame_scratch_buffer_releaser(
+      &frame_scratch_buffer_pool_, &frame_scratch_buffer);
+
+  StatusCode status;
+  if (!frame_header.show_existing_frame) {
+    if (encoded_frame->tile_buffers.empty()) {
+      // This means that the last call to ParseOneFrame() did not actually
+      // have any tile groups. This could happen in rare cases (for example,
+      // if there is a Metadata OBU after the TileGroup OBU). We currently do
+      // not have a reason to handle those cases, so we simply continue.
+      return kStatusOk;
+    }
+    status = DecodeTiles(sequence_header, frame_header,
+                         encoded_frame->tile_buffers, encoded_frame->state,
+                         frame_scratch_buffer.get(), current_frame.get());
+    if (status != kStatusOk) {
+      return status;
+    }
+  } else {
+    if (!current_frame->WaitUntilDecoded()) {
+      return kStatusUnknownError;
+    }
+  }
+  if (!frame_header.show_frame && !frame_header.show_existing_frame) {
+    // This frame is not displayable. Not an error.
+    return kStatusOk;
+  }
+  RefCountedBufferPtr film_grain_frame;
+  status = ApplyFilmGrain(
+      sequence_header, frame_header, current_frame, &film_grain_frame,
+      frame_scratch_buffer->threading_strategy.thread_pool());
+  if (status != kStatusOk) {
+    return status;
+  }
+
+  TemporalUnit& temporal_unit = *encoded_frame->temporal_unit;
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (temporal_unit.has_displayable_frame && !settings_.output_all_layers) {
+    assert(temporal_unit.output_frame_position >= 0);
+    // A displayable frame was already found in this temporal unit. This can
+    // happen if there are multiple spatial/temporal layers. Since
+    // |settings_.output_all_layers| is false, we will output only the last
+    // displayable frame.
+    if (temporal_unit.output_frame_position >
+        encoded_frame->position_in_temporal_unit) {
+      return kStatusOk;
+    }
+    // Replace any output frame that we may have seen before with the current
+    // frame.
+    assert(temporal_unit.output_layer_count == 1);
+    --temporal_unit.output_layer_count;
+  }
+  temporal_unit.has_displayable_frame = true;
+  temporal_unit.output_layers[temporal_unit.output_layer_count].frame =
+      std::move(film_grain_frame);
+  temporal_unit.output_layers[temporal_unit.output_layer_count]
+      .position_in_temporal_unit = encoded_frame->position_in_temporal_unit;
+  ++temporal_unit.output_layer_count;
+  temporal_unit.output_frame_position =
+      encoded_frame->position_in_temporal_unit;
+  return kStatusOk;
+}
+
+StatusCode DecoderImpl::DecodeTemporalUnit(const TemporalUnit& temporal_unit,
+                                           const DecoderBuffer** out_ptr) {
+  std::unique_ptr<ObuParser> obu(new (std::nothrow) ObuParser(
+      temporal_unit.data, temporal_unit.size, settings_.operating_point,
+      &buffer_pool_, &state_));
+  if (obu == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Failed to allocate OBU parser.");
+    return kStatusOutOfMemory;
+  }
+  if (has_sequence_header_) {
+    obu->set_sequence_header(sequence_header_);
+  }
+  StatusCode status;
+  std::unique_ptr<FrameScratchBuffer> frame_scratch_buffer =
+      frame_scratch_buffer_pool_.Get();
+  if (frame_scratch_buffer == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Error when getting FrameScratchBuffer.");
+    return kStatusOutOfMemory;
+  }
+  // |frame_scratch_buffer| will be released when this local variable goes out
+  // of scope (i.e.) on any return path in this function.
+  FrameScratchBufferReleaser frame_scratch_buffer_releaser(
+      &frame_scratch_buffer_pool_, &frame_scratch_buffer);
+
+  while (obu->HasData()) {
+    RefCountedBufferPtr current_frame;
+    status = obu->ParseOneFrame(&current_frame);
+    if (status != kStatusOk) {
+      LIBGAV1_DLOG(ERROR, "Failed to parse OBU.");
+      return status;
+    }
+    if (IsNewSequenceHeader(*obu)) {
+      const ObuSequenceHeader& sequence_header = obu->sequence_header();
+      const Libgav1ImageFormat image_format =
+          ComposeImageFormat(sequence_header.color_config.is_monochrome,
+                             sequence_header.color_config.subsampling_x,
+                             sequence_header.color_config.subsampling_y);
+      const int max_bottom_border = GetBottomBorderPixels(
+          /*do_cdef=*/true, /*do_restoration=*/true,
+          /*do_superres=*/true, sequence_header.color_config.subsampling_y);
+      if (!buffer_pool_.OnFrameBufferSizeChanged(
+              sequence_header.color_config.bitdepth, image_format,
+              sequence_header.max_frame_width, sequence_header.max_frame_height,
+              kBorderPixels, kBorderPixels, kBorderPixels, max_bottom_border)) {
+        LIBGAV1_DLOG(ERROR, "buffer_pool_.OnFrameBufferSizeChanged failed.");
+        return kStatusUnknownError;
+      }
     }
     if (!obu->frame_header().show_existing_frame) {
-      if (obu->tile_groups().empty()) {
+      if (obu->tile_buffers().empty()) {
         // This means that the last call to ParseOneFrame() did not actually
         // have any tile groups. This could happen in rare cases (for example,
         // if there is a Metadata OBU after the TileGroup OBU). We currently do
         // not have a reason to handle those cases, so we simply continue.
         continue;
       }
-      status = DecodeTiles(obu.get());
-      if (status != kLibgav1StatusOk) {
+      status = DecodeTiles(obu->sequence_header(), obu->frame_header(),
+                           obu->tile_buffers(), state_,
+                           frame_scratch_buffer.get(), current_frame.get());
+      if (status != kStatusOk) {
         return status;
       }
     }
-    state_.UpdateReferenceFrames(obu->frame_header().refresh_frame_flags);
+    state_.UpdateReferenceFrames(current_frame,
+                                 obu->frame_header().refresh_frame_flags);
     if (obu->frame_header().show_frame ||
         obu->frame_header().show_existing_frame) {
-      if (displayable_frame != nullptr) {
-        // This can happen if there are multiple spatial/temporal layers. We
-        // don't care about it for now, so simply return the last displayable
-        // frame.
-        // TODO(b/129153372): Add support for outputting multiple
-        // spatial/temporal layers.
-        LIBGAV1_DLOG(
-            WARNING,
-            "More than one displayable frame found. Using the last one.");
+      if (!output_frame_queue_.Empty() && !settings_.output_all_layers) {
+        // There is more than one displayable frame in the current operating
+        // point and |settings_.output_all_layers| is false. In this case, we
+        // simply return the last displayable frame as the output frame and
+        // ignore the rest.
+        assert(output_frame_queue_.Size() == 1);
+        output_frame_queue_.Pop();
       }
-      displayable_frame = std::move(state_.current_frame);
-      if (obu->sequence_header().film_grain_params_present &&
-          displayable_frame->film_grain_params().apply_grain &&
-          (settings_.post_filter_mask & 0x10) != 0) {
-        RefCountedBufferPtr film_grain_frame;
-        if (!obu->frame_header().show_existing_frame &&
-            obu->frame_header().refresh_frame_flags == 0) {
-          // If show_existing_frame is true, then the current frame is a
-          // previously saved reference frame. If refresh_frame_flags is
-          // nonzero, then the state_.UpdateReferenceFrames() call above has
-          // saved the current frame as a reference frame. Therefore, if both
-          // of these conditions are false, then the current frame is not
-          // saved as a reference frame. displayable_frame should hold the
-          // only reference to the current frame.
-          assert(displayable_frame.use_count() == 1);
-          // Add film grain noise in place.
-          film_grain_frame = displayable_frame;
-        } else {
-          film_grain_frame = buffer_pool_.GetFreeBuffer();
-          if (film_grain_frame == nullptr) {
-            LIBGAV1_DLOG(
-                ERROR, "Could not get film_grain_frame from the buffer pool.");
-            return kLibgav1StatusResourceExhausted;
-          }
-          if (!film_grain_frame->Realloc(
-                  displayable_frame->buffer()->bitdepth(),
-                  displayable_frame->buffer()->is_monochrome(),
-                  displayable_frame->upscaled_width(),
-                  displayable_frame->frame_height(),
-                  displayable_frame->buffer()->subsampling_x(),
-                  displayable_frame->buffer()->subsampling_y(),
-                  /*border=*/0,
-                  /*byte_alignment=*/0)) {
-            LIBGAV1_DLOG(ERROR, "film_grain_frame->Realloc() failed.");
-            return kLibgav1StatusOutOfMemory;
-          }
-          film_grain_frame->set_chroma_sample_position(
-              displayable_frame->chroma_sample_position());
-        }
-        const dsp::Dsp* const dsp =
-            dsp::GetDspTable(displayable_frame->buffer()->bitdepth());
-        if (!dsp->film_grain_synthesis(
-                displayable_frame->buffer()->data(kPlaneY),
-                displayable_frame->buffer()->stride(kPlaneY),
-                displayable_frame->buffer()->data(kPlaneU),
-                displayable_frame->buffer()->stride(kPlaneU),
-                displayable_frame->buffer()->data(kPlaneV),
-                displayable_frame->buffer()->stride(kPlaneV),
-                displayable_frame->film_grain_params(),
-                displayable_frame->buffer()->is_monochrome(),
-                obu->sequence_header().color_config.matrix_coefficients ==
-                    kMatrixCoefficientIdentity,
-                displayable_frame->upscaled_width(),
-                displayable_frame->frame_height(),
-                displayable_frame->buffer()->subsampling_x(),
-                displayable_frame->buffer()->subsampling_y(),
-                film_grain_frame->buffer()->data(kPlaneY),
-                film_grain_frame->buffer()->stride(kPlaneY),
-                film_grain_frame->buffer()->data(kPlaneU),
-                film_grain_frame->buffer()->stride(kPlaneU),
-                film_grain_frame->buffer()->data(kPlaneV),
-                film_grain_frame->buffer()->stride(kPlaneV))) {
-          LIBGAV1_DLOG(ERROR, "dsp->film_grain_synthesis() failed.");
-          return kLibgav1StatusOutOfMemory;
-        }
-        displayable_frame = std::move(film_grain_frame);
-      }
+      RefCountedBufferPtr film_grain_frame;
+      status = ApplyFilmGrain(
+          obu->sequence_header(), obu->frame_header(), current_frame,
+          &film_grain_frame,
+          frame_scratch_buffer->threading_strategy.film_grain_thread_pool());
+      if (status != kStatusOk) return status;
+      output_frame_queue_.Push(std::move(film_grain_frame));
     }
   }
-  if (displayable_frame == nullptr) {
-    // No displayable frame in the encoded frame. Not an error.
+  if (output_frame_queue_.Empty()) {
+    // No displayable frame in the temporal unit. Not an error.
     *out_ptr = nullptr;
-    return kLibgav1StatusOk;
+    return kStatusOk;
   }
-  status = CopyFrameToOutputBuffer(displayable_frame);
-  if (status != kLibgav1StatusOk) {
+  status = CopyFrameToOutputBuffer(output_frame_queue_.Front());
+  output_frame_queue_.Pop();
+  if (status != kStatusOk) {
     return status;
   }
-  buffer_.user_private_data = encoded_frame.user_private_data;
+  buffer_.user_private_data = temporal_unit.user_private_data;
   *out_ptr = &buffer_;
-  return kLibgav1StatusOk;
-}
-
-bool DecoderImpl::AllocateCurrentFrame(const ObuFrameHeader& frame_header) {
-  const ColorConfig& color_config = state_.sequence_header.color_config;
-  state_.current_frame->set_chroma_sample_position(
-      color_config.chroma_sample_position);
-  return state_.current_frame->Realloc(
-      color_config.bitdepth, color_config.is_monochrome,
-      frame_header.upscaled_width, frame_header.height,
-      color_config.subsampling_x, color_config.subsampling_y, kBorderPixels,
-      /*byte_alignment=*/0);
+  return kStatusOk;
 }
 
 StatusCode DecoderImpl::CopyFrameToOutputBuffer(
@@ -336,9 +1132,15 @@
       LIBGAV1_DLOG(ERROR,
                    "Invalid chroma subsampling values: cannot determine buffer "
                    "image format.");
-      return kLibgav1StatusInvalidArgument;
+      return kStatusInvalidArgument;
     }
   }
+  buffer_.color_range = sequence_header_.color_config.color_range;
+  buffer_.color_primary = sequence_header_.color_config.color_primary;
+  buffer_.transfer_characteristics =
+      sequence_header_.color_config.transfer_characteristics;
+  buffer_.matrix_coefficients =
+      sequence_header_.color_config.matrix_coefficients;
 
   buffer_.bitdepth = yuv_buffer->bitdepth();
   const int num_planes =
@@ -347,8 +1149,8 @@
   for (; plane < num_planes; ++plane) {
     buffer_.stride[plane] = yuv_buffer->stride(plane);
     buffer_.plane[plane] = yuv_buffer->data(plane);
-    buffer_.displayed_width[plane] = yuv_buffer->displayed_width(plane);
-    buffer_.displayed_height[plane] = yuv_buffer->displayed_height(plane);
+    buffer_.displayed_width[plane] = yuv_buffer->width(plane);
+    buffer_.displayed_height[plane] = yuv_buffer->height(plane);
   }
   for (; plane < kMaxPlanes; ++plane) {
     buffer_.stride[plane] = 0;
@@ -356,9 +1158,11 @@
     buffer_.displayed_width[plane] = 0;
     buffer_.displayed_height[plane] = 0;
   }
+  buffer_.spatial_id = frame->spatial_id();
+  buffer_.temporal_id = frame->temporal_id();
   buffer_.buffer_private_data = frame->buffer_private_data();
   output_frame_ = frame;
-  return kLibgav1StatusOk;
+  return kStatusOk;
 }
 
 void DecoderImpl::ReleaseOutputFrame() {
@@ -368,336 +1172,458 @@
   output_frame_ = nullptr;
 }
 
-StatusCode DecoderImpl::DecodeTiles(const ObuParser* obu) {
-  if (PostFilter::DoDeblock(obu->frame_header(), settings_.post_filter_mask) &&
-      !loop_filter_mask_.Reset(obu->frame_header().width,
-                               obu->frame_header().height)) {
-    LIBGAV1_DLOG(ERROR, "Failed to allocate memory for loop filter masks.");
-    return kLibgav1StatusOutOfMemory;
-  }
-  LoopRestorationInfo loop_restoration_info(
-      obu->frame_header().loop_restoration, obu->frame_header().upscaled_width,
-      obu->frame_header().height,
-      obu->sequence_header().color_config.subsampling_x,
-      obu->sequence_header().color_config.subsampling_y,
-      obu->sequence_header().color_config.is_monochrome);
-  if (!loop_restoration_info.Allocate()) {
+StatusCode DecoderImpl::DecodeTiles(
+    const ObuSequenceHeader& sequence_header,
+    const ObuFrameHeader& frame_header, const Vector<TileBuffer>& tile_buffers,
+    const DecoderState& state, FrameScratchBuffer* const frame_scratch_buffer,
+    RefCountedBuffer* const current_frame) {
+  frame_scratch_buffer->tile_scratch_buffer_pool.Reset(
+      sequence_header.color_config.bitdepth);
+  if (!frame_scratch_buffer->loop_restoration_info.Reset(
+          &frame_header.loop_restoration, frame_header.upscaled_width,
+          frame_header.height, sequence_header.color_config.subsampling_x,
+          sequence_header.color_config.subsampling_y,
+          sequence_header.color_config.is_monochrome)) {
     LIBGAV1_DLOG(ERROR,
                  "Failed to allocate memory for loop restoration info units.");
-    return kLibgav1StatusOutOfMemory;
+    return kStatusOutOfMemory;
   }
-  if (!AllocateCurrentFrame(obu->frame_header())) {
+  const bool do_cdef =
+      PostFilter::DoCdef(frame_header, settings_.post_filter_mask);
+  const int num_planes = sequence_header.color_config.is_monochrome
+                             ? kMaxPlanesMonochrome
+                             : kMaxPlanes;
+  const bool do_restoration = PostFilter::DoRestoration(
+      frame_header.loop_restoration, settings_.post_filter_mask, num_planes);
+  const bool do_superres =
+      PostFilter::DoSuperRes(frame_header, settings_.post_filter_mask);
+  // Use kBorderPixels for the left, right, and top borders. Only the bottom
+  // border may need to be bigger. SuperRes border is needed only if we are
+  // applying SuperRes in-place which is being done only in single threaded
+  // mode.
+  const int bottom_border = GetBottomBorderPixels(
+      do_cdef, do_restoration,
+      do_superres &&
+          frame_scratch_buffer->threading_strategy.post_filter_thread_pool() ==
+              nullptr,
+      sequence_header.color_config.subsampling_y);
+  current_frame->set_chroma_sample_position(
+      sequence_header.color_config.chroma_sample_position);
+  if (!current_frame->Realloc(sequence_header.color_config.bitdepth,
+                              sequence_header.color_config.is_monochrome,
+                              frame_header.upscaled_width, frame_header.height,
+                              sequence_header.color_config.subsampling_x,
+                              sequence_header.color_config.subsampling_y,
+                              /*left_border=*/kBorderPixels,
+                              /*right_border=*/kBorderPixels,
+                              /*top_border=*/kBorderPixels, bottom_border)) {
     LIBGAV1_DLOG(ERROR, "Failed to allocate memory for the decoder buffer.");
-    return kLibgav1StatusOutOfMemory;
+    return kStatusOutOfMemory;
   }
-  Array2D<int16_t> cdef_index;
-  if (obu->sequence_header().enable_cdef) {
-    if (!cdef_index.Reset(
-            DivideBy16(obu->frame_header().rows4x4 + kMaxBlockHeight4x4),
-            DivideBy16(obu->frame_header().columns4x4 + kMaxBlockWidth4x4))) {
+  if (sequence_header.enable_cdef) {
+    if (!frame_scratch_buffer->cdef_index.Reset(
+            DivideBy16(frame_header.rows4x4 + kMaxBlockHeight4x4),
+            DivideBy16(frame_header.columns4x4 + kMaxBlockWidth4x4),
+            /*zero_initialize=*/false)) {
       LIBGAV1_DLOG(ERROR, "Failed to allocate memory for cdef index.");
-      return kLibgav1StatusOutOfMemory;
+      return kStatusOutOfMemory;
     }
   }
-  if (!inter_transform_sizes_.Reset(
-          obu->frame_header().rows4x4 + kMaxBlockHeight4x4,
-          obu->frame_header().columns4x4 + kMaxBlockWidth4x4,
+  if (!frame_scratch_buffer->inter_transform_sizes.Reset(
+          frame_header.rows4x4 + kMaxBlockHeight4x4,
+          frame_header.columns4x4 + kMaxBlockWidth4x4,
           /*zero_initialize=*/false)) {
     LIBGAV1_DLOG(ERROR, "Failed to allocate memory for inter_transform_sizes.");
-    return kLibgav1StatusOutOfMemory;
+    return kStatusOutOfMemory;
   }
-  if (obu->frame_header().use_ref_frame_mvs &&
-      !state_.motion_field_mv.Reset(DivideBy2(obu->frame_header().rows4x4),
-                                    DivideBy2(obu->frame_header().columns4x4),
-                                    /*zero_initialize=*/false)) {
-    LIBGAV1_DLOG(ERROR,
-                 "Failed to allocate memory for temporal motion vectors.");
-    return kLibgav1StatusOutOfMemory;
+  if (frame_header.use_ref_frame_mvs) {
+    if (!frame_scratch_buffer->motion_field.mv.Reset(
+            DivideBy2(frame_header.rows4x4), DivideBy2(frame_header.columns4x4),
+            /*zero_initialize=*/false) ||
+        !frame_scratch_buffer->motion_field.reference_offset.Reset(
+            DivideBy2(frame_header.rows4x4), DivideBy2(frame_header.columns4x4),
+            /*zero_initialize=*/false)) {
+      LIBGAV1_DLOG(ERROR,
+                   "Failed to allocate memory for temporal motion vectors.");
+      return kStatusOutOfMemory;
+    }
+
+    // For each motion vector, only mv[0] needs to be initialized to
+    // kInvalidMvValue, mv[1] is not necessary to be initialized and can be
+    // set to an arbitrary value. For simplicity, mv[1] is set to 0.
+    // The following memory initialization of contiguous memory is very fast. It
+    // is not recommended to make the initialization multi-threaded, unless the
+    // memory which needs to be initialized in each thread is still contiguous.
+    MotionVector invalid_mv;
+    invalid_mv.mv[0] = kInvalidMvValue;
+    invalid_mv.mv[1] = 0;
+    MotionVector* const motion_field_mv =
+        &frame_scratch_buffer->motion_field.mv[0][0];
+    std::fill(motion_field_mv,
+              motion_field_mv + frame_scratch_buffer->motion_field.mv.size(),
+              invalid_mv);
   }
 
   // The addition of kMaxBlockHeight4x4 and kMaxBlockWidth4x4 is necessary so
   // that the block parameters cache can be filled in for the last row/column
   // without having to check for boundary conditions.
-  BlockParametersHolder block_parameters_holder(
-      obu->frame_header().rows4x4 + kMaxBlockHeight4x4,
-      obu->frame_header().columns4x4 + kMaxBlockWidth4x4,
-      obu->sequence_header().use_128x128_superblock);
-  if (!block_parameters_holder.Init()) {
-    return kLibgav1StatusOutOfMemory;
+  if (!frame_scratch_buffer->block_parameters_holder.Reset(
+          frame_header.rows4x4 + kMaxBlockHeight4x4,
+          frame_header.columns4x4 + kMaxBlockWidth4x4,
+          sequence_header.use_128x128_superblock)) {
+    return kStatusOutOfMemory;
   }
   const dsp::Dsp* const dsp =
-      dsp::GetDspTable(obu->sequence_header().color_config.bitdepth);
+      dsp::GetDspTable(sequence_header.color_config.bitdepth);
   if (dsp == nullptr) {
     LIBGAV1_DLOG(ERROR, "Failed to get the dsp table for bitdepth %d.",
-                 obu->sequence_header().color_config.bitdepth);
-    return kLibgav1StatusInternalError;
-  }
-  // If prev_segment_ids is a null pointer, it is treated as if it pointed to
-  // a segmentation map containing all 0s.
-  const SegmentationMap* prev_segment_ids = nullptr;
-  if (obu->frame_header().primary_reference_frame == kPrimaryReferenceNone) {
-    symbol_decoder_context_.Initialize(
-        obu->frame_header().quantizer.base_index);
-  } else {
-    const int index =
-        obu->frame_header()
-            .reference_frame_index[obu->frame_header().primary_reference_frame];
-    const RefCountedBuffer* prev_frame = state_.reference_frame[index].get();
-    symbol_decoder_context_ = prev_frame->FrameContext();
-    if (obu->frame_header().segmentation.enabled &&
-        prev_frame->columns4x4() == obu->frame_header().columns4x4 &&
-        prev_frame->rows4x4() == obu->frame_header().rows4x4) {
-      prev_segment_ids = prev_frame->segmentation_map();
-    }
+                 sequence_header.color_config.bitdepth);
+    return kStatusInternalError;
   }
 
-  const uint8_t tile_size_bytes = obu->frame_header().tile_info.tile_size_bytes;
-  const int tile_count = obu->tile_groups().back().end + 1;
+  const int tile_count = frame_header.tile_info.tile_count;
   assert(tile_count >= 1);
   Vector<std::unique_ptr<Tile>> tiles;
   if (!tiles.reserve(tile_count)) {
     LIBGAV1_DLOG(ERROR, "tiles.reserve(%d) failed.\n", tile_count);
-    return kLibgav1StatusOutOfMemory;
+    return kStatusOutOfMemory;
   }
-  if (!threading_strategy_.Reset(obu->frame_header(), settings_.threads)) {
-    return kLibgav1StatusOutOfMemory;
+  ThreadingStrategy& threading_strategy =
+      frame_scratch_buffer->threading_strategy;
+  if (!is_frame_parallel_ &&
+      !threading_strategy.Reset(frame_header, settings_.threads)) {
+    return kStatusOutOfMemory;
   }
 
-  if (threading_strategy_.row_thread_pool(0) != nullptr) {
-    if (residual_buffer_pool_ == nullptr) {
-      residual_buffer_pool_.reset(new (std::nothrow) ResidualBufferPool(
-          obu->sequence_header().use_128x128_superblock,
-          obu->sequence_header().color_config.subsampling_x,
-          obu->sequence_header().color_config.subsampling_y,
-          obu->sequence_header().color_config.bitdepth == 8 ? sizeof(int16_t)
-                                                            : sizeof(int32_t)));
-      if (residual_buffer_pool_ == nullptr) {
+  if (threading_strategy.row_thread_pool(0) != nullptr || is_frame_parallel_) {
+    if (frame_scratch_buffer->residual_buffer_pool == nullptr) {
+      frame_scratch_buffer->residual_buffer_pool.reset(
+          new (std::nothrow) ResidualBufferPool(
+              sequence_header.use_128x128_superblock,
+              sequence_header.color_config.subsampling_x,
+              sequence_header.color_config.subsampling_y,
+              sequence_header.color_config.bitdepth == 8 ? sizeof(int16_t)
+                                                         : sizeof(int32_t)));
+      if (frame_scratch_buffer->residual_buffer_pool == nullptr) {
         LIBGAV1_DLOG(ERROR, "Failed to allocate residual buffer.\n");
-        return kLibgav1StatusOutOfMemory;
+        return kStatusOutOfMemory;
       }
     } else {
-      residual_buffer_pool_->Reset(
-          obu->sequence_header().use_128x128_superblock,
-          obu->sequence_header().color_config.subsampling_x,
-          obu->sequence_header().color_config.subsampling_y,
-          obu->sequence_header().color_config.bitdepth == 8 ? sizeof(int16_t)
-                                                            : sizeof(int32_t));
+      frame_scratch_buffer->residual_buffer_pool->Reset(
+          sequence_header.use_128x128_superblock,
+          sequence_header.color_config.subsampling_x,
+          sequence_header.color_config.subsampling_y,
+          sequence_header.color_config.bitdepth == 8 ? sizeof(int16_t)
+                                                     : sizeof(int32_t));
     }
   }
 
-  const bool do_cdef =
-      PostFilter::DoCdef(obu->frame_header(), settings_.post_filter_mask);
-  const int num_planes = obu->sequence_header().color_config.is_monochrome
-                             ? kMaxPlanesMonochrome
-                             : kMaxPlanes;
-  const bool do_restoration =
-      PostFilter::DoRestoration(obu->frame_header().loop_restoration,
-                                settings_.post_filter_mask, num_planes);
-  if (threading_strategy_.post_filter_thread_pool() != nullptr &&
+  if (threading_strategy.post_filter_thread_pool() != nullptr &&
       (do_cdef || do_restoration)) {
     const int window_buffer_width = PostFilter::GetWindowBufferWidth(
-        threading_strategy_.post_filter_thread_pool(), obu->frame_header());
+        threading_strategy.post_filter_thread_pool(), frame_header);
     size_t threaded_window_buffer_size =
         window_buffer_width *
         PostFilter::GetWindowBufferHeight(
-            threading_strategy_.post_filter_thread_pool(),
-            obu->frame_header()) *
-        (obu->sequence_header().color_config.bitdepth == 8 ? sizeof(uint8_t)
-                                                           : sizeof(uint16_t));
-    if (do_cdef && !do_restoration) {
+            threading_strategy.post_filter_thread_pool(), frame_header) *
+        (sequence_header.color_config.bitdepth == 8 ? sizeof(uint8_t)
+                                                    : sizeof(uint16_t));
+    if (do_cdef) {
       // TODO(chengchen): for cdef U, V planes, if there's subsampling, we can
       // use smaller buffer.
       threaded_window_buffer_size *= num_planes;
     }
-    if (threaded_window_buffer_size_ < threaded_window_buffer_size) {
-      // threaded_window_buffer_ will be subdivided by PostFilter into windows
-      // of width 512 pixels. Each row in the window is filtered by a worker
-      // thread. To avoid false sharing, each 512-pixel row processed by one
-      // thread should not share a cache line with a row processed by another
-      // thread. So we align threaded_window_buffer_ to the cache line size.
-      // In addition, it is faster to memcpy from an aligned buffer.
-      //
-      // On Linux, the cache line size can be looked up with the command:
-      //   getconf LEVEL1_DCACHE_LINESIZE
-      //
-      // The cache line size should ideally be queried at run time. 64 is a
-      // common cache line size of x86 CPUs. Web searches showed the cache line
-      // size of ARM CPUs is 32 or 64 bytes. So aligning to 64-byte boundary
-      // will work for all CPUs that we care about, even though it is excessive
-      // for some ARM CPUs.
-      constexpr size_t kCacheLineSize = 64;
-      // To avoid false sharing, PostFilter's window width in bytes should also
-      // be a multiple of the cache line size. For simplicity, we check the
-      // window width in pixels.
-      assert(window_buffer_width % kCacheLineSize == 0);
-      threaded_window_buffer_ = MakeAlignedUniquePtr<uint8_t>(
-          kCacheLineSize, threaded_window_buffer_size);
-      if (threaded_window_buffer_ == nullptr) {
-        LIBGAV1_DLOG(ERROR,
-                     "Failed to allocate threaded loop restoration buffer.\n");
-        threaded_window_buffer_size_ = 0;
-        return kLibgav1StatusOutOfMemory;
-      }
-      threaded_window_buffer_size_ = threaded_window_buffer_size;
+    // To avoid false sharing, PostFilter's window width in bytes should be a
+    // multiple of the cache line size. For simplicity, we check the window
+    // width in pixels.
+    assert(window_buffer_width % kCacheLineSize == 0);
+    if (!frame_scratch_buffer->threaded_window_buffer.Resize(
+            threaded_window_buffer_size)) {
+      LIBGAV1_DLOG(ERROR,
+                   "Failed to resize threaded loop restoration buffer.\n");
+      return kStatusOutOfMemory;
     }
   }
 
-  PostFilter post_filter(
-      obu->frame_header(), obu->sequence_header(), &loop_filter_mask_,
-      cdef_index, &loop_restoration_info, &block_parameters_holder,
-      state_.current_frame->buffer(), dsp,
-      threading_strategy_.post_filter_thread_pool(),
-      threaded_window_buffer_.get(), settings_.post_filter_mask);
-  SymbolDecoderContext saved_symbol_decoder_context;
-  int tile_index = 0;
-  BlockingCounterWithStatus pending_tiles(tile_count);
-  for (const auto& tile_group : obu->tile_groups()) {
-    size_t bytes_left = tile_group.data_size;
-    size_t byte_offset = 0;
-    // The for loop in 5.11.1.
-    for (int tile_number = tile_group.start; tile_number <= tile_group.end;
-         ++tile_number) {
-      size_t tile_size = 0;
-      if (tile_number != tile_group.end) {
-        RawBitReader bit_reader(tile_group.data + byte_offset, bytes_left);
-        if (!bit_reader.ReadLittleEndian(tile_size_bytes, &tile_size)) {
-          LIBGAV1_DLOG(ERROR, "Could not read tile size for tile #%d",
-                       tile_number);
-          return kLibgav1StatusBitstreamError;
-        }
-        ++tile_size;
-        byte_offset += tile_size_bytes;
-        bytes_left -= tile_size_bytes;
-        if (tile_size > bytes_left) {
-          LIBGAV1_DLOG(ERROR, "Invalid tile size %zu for tile #%d", tile_size,
-                       tile_number);
-          return kLibgav1StatusBitstreamError;
-        }
-      } else {
-        tile_size = bytes_left;
-      }
-
-      std::unique_ptr<Tile> tile(new (std::nothrow) Tile(
-          tile_number, tile_group.data + byte_offset, tile_size,
-          obu->sequence_header(), obu->frame_header(),
-          state_.current_frame.get(), state_.reference_frame_sign_bias,
-          state_.reference_frame, &state_.motion_field_mv,
-          state_.reference_order_hint, state_.wedge_masks,
-          symbol_decoder_context_, &saved_symbol_decoder_context,
-          prev_segment_ids, &post_filter, &block_parameters_holder, &cdef_index,
-          &inter_transform_sizes_, dsp,
-          threading_strategy_.row_thread_pool(tile_index++),
-          residual_buffer_pool_.get(), &decoder_scratch_buffer_pool_,
-          &pending_tiles));
-      if (tile == nullptr) {
-        LIBGAV1_DLOG(ERROR, "Failed to allocate tile.");
-        return kLibgav1StatusOutOfMemory;
-      }
-      tiles.push_back_unchecked(std::move(tile));
-
-      byte_offset += tile_size;
-      bytes_left -= tile_size;
+  if (do_cdef && do_restoration) {
+    // We need to store 4 rows per 64x64 unit.
+    const int num_deblock_units = MultiplyBy4(Ceil(frame_header.rows4x4, 16));
+    // subsampling_y is set to zero irrespective of the actual frame's
+    // subsampling since we need to store exactly |num_deblock_units| rows of
+    // the deblocked pixels.
+    if (!frame_scratch_buffer->deblock_buffer.Realloc(
+            sequence_header.color_config.bitdepth,
+            sequence_header.color_config.is_monochrome,
+            frame_header.upscaled_width, num_deblock_units,
+            sequence_header.color_config.subsampling_x,
+            /*subsampling_y=*/0, kBorderPixels, kBorderPixels, kBorderPixels,
+            kBorderPixels, nullptr, nullptr, nullptr)) {
+      return kStatusOutOfMemory;
     }
   }
+
+  if (do_superres) {
+    const int num_threads =
+        1 + ((threading_strategy.post_filter_thread_pool() == nullptr)
+                 ? 0
+                 : threading_strategy.post_filter_thread_pool()->num_threads());
+    const size_t superres_line_buffer_size =
+        num_threads *
+        (MultiplyBy4(frame_header.columns4x4) +
+         MultiplyBy2(kSuperResHorizontalBorder) + kSuperResHorizontalPadding) *
+        (sequence_header.color_config.bitdepth == 8 ? sizeof(uint8_t)
+                                                    : sizeof(uint16_t));
+    if (!frame_scratch_buffer->superres_line_buffer.Resize(
+            superres_line_buffer_size)) {
+      LIBGAV1_DLOG(ERROR, "Failed to resize superres line buffer.\n");
+      return kStatusOutOfMemory;
+    }
+  }
+
+  PostFilter post_filter(frame_header, sequence_header, frame_scratch_buffer,
+                         current_frame->buffer(), dsp,
+                         settings_.post_filter_mask);
+
+  if (is_frame_parallel_) {
+    // We can parse the current frame if all the reference frames have been
+    // parsed.
+    for (int i = 0; i < kNumReferenceFrameTypes; ++i) {
+      if (!state.reference_valid[i] || state.reference_frame[i] == nullptr) {
+        continue;
+      }
+      if (!state.reference_frame[i]->WaitUntilParsed()) {
+        return kStatusUnknownError;
+      }
+    }
+  }
+
+  // If prev_segment_ids is a null pointer, it is treated as if it pointed to
+  // a segmentation map containing all 0s.
+  const SegmentationMap* prev_segment_ids = nullptr;
+  if (frame_header.primary_reference_frame == kPrimaryReferenceNone) {
+    frame_scratch_buffer->symbol_decoder_context.Initialize(
+        frame_header.quantizer.base_index);
+  } else {
+    const int index =
+        frame_header
+            .reference_frame_index[frame_header.primary_reference_frame];
+    assert(index != -1);
+    const RefCountedBuffer* prev_frame = state.reference_frame[index].get();
+    frame_scratch_buffer->symbol_decoder_context = prev_frame->FrameContext();
+    if (frame_header.segmentation.enabled &&
+        prev_frame->columns4x4() == frame_header.columns4x4 &&
+        prev_frame->rows4x4() == frame_header.rows4x4) {
+      prev_segment_ids = prev_frame->segmentation_map();
+    }
+  }
+
+  // The Tile class must make use of a separate buffer to store the unfiltered
+  // pixels for the intra prediction of the next superblock row. This is done
+  // only when one of the following conditions are true:
+  //   * is_frame_parallel_ is true.
+  //   * settings_.threads == 1.
+  // In the non-frame-parallel multi-threaded case, we do not run the post
+  // filters in the decode loop. So this buffer need not be used.
+  const bool use_intra_prediction_buffer =
+      is_frame_parallel_ || settings_.threads == 1;
+  if (use_intra_prediction_buffer) {
+    if (!frame_scratch_buffer->intra_prediction_buffers.Resize(
+            frame_header.tile_info.tile_rows)) {
+      LIBGAV1_DLOG(ERROR, "Failed to Resize intra_prediction_buffers.");
+      return kStatusOutOfMemory;
+    }
+    IntraPredictionBuffer* const intra_prediction_buffers =
+        frame_scratch_buffer->intra_prediction_buffers.get();
+    for (int plane = 0; plane < num_planes; ++plane) {
+      const int subsampling =
+          (plane == kPlaneY) ? 0 : sequence_header.color_config.subsampling_x;
+      const size_t intra_prediction_buffer_size =
+          ((MultiplyBy4(frame_header.columns4x4) >> subsampling) *
+           (sequence_header.color_config.bitdepth == 8 ? sizeof(uint8_t)
+                                                       : sizeof(uint16_t)));
+      for (int tile_row = 0; tile_row < frame_header.tile_info.tile_rows;
+           ++tile_row) {
+        if (!intra_prediction_buffers[tile_row][plane].Resize(
+                intra_prediction_buffer_size)) {
+          LIBGAV1_DLOG(ERROR,
+                       "Failed to allocate intra prediction buffer for tile "
+                       "row %d plane %d.\n",
+                       tile_row, plane);
+          return kStatusOutOfMemory;
+        }
+      }
+    }
+  }
+
+  SymbolDecoderContext saved_symbol_decoder_context;
+  BlockingCounterWithStatus pending_tiles(tile_count);
+  for (int tile_number = 0; tile_number < tile_count; ++tile_number) {
+    std::unique_ptr<Tile> tile = Tile::Create(
+        tile_number, tile_buffers[tile_number].data,
+        tile_buffers[tile_number].size, sequence_header, frame_header,
+        current_frame, state, frame_scratch_buffer, wedge_masks_,
+        &saved_symbol_decoder_context, prev_segment_ids, &post_filter, dsp,
+        threading_strategy.row_thread_pool(tile_number), &pending_tiles,
+        is_frame_parallel_, use_intra_prediction_buffer);
+    if (tile == nullptr) {
+      LIBGAV1_DLOG(ERROR, "Failed to create tile.");
+      return kStatusOutOfMemory;
+    }
+    tiles.push_back_unchecked(std::move(tile));
+  }
   assert(tiles.size() == static_cast<size_t>(tile_count));
-  bool tile_decoding_failed = false;
-  if (threading_strategy_.tile_thread_pool() == nullptr) {
-    for (const auto& tile_ptr : tiles) {
-      if (!tile_decoding_failed) {
-        if (!tile_ptr->Decode(/*is_main_thread=*/true)) {
-          LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile_ptr->number());
-          tile_decoding_failed = true;
-        }
-      } else {
-        pending_tiles.Decrement(false);
-      }
+  if (is_frame_parallel_) {
+    if (frame_scratch_buffer->threading_strategy.thread_pool() == nullptr) {
+      return DecodeTilesFrameParallel(
+          sequence_header, frame_header, tiles, saved_symbol_decoder_context,
+          prev_segment_ids, frame_scratch_buffer, &post_filter, current_frame);
     }
+    return DecodeTilesThreadedFrameParallel(
+        sequence_header, frame_header, tiles, saved_symbol_decoder_context,
+        prev_segment_ids, frame_scratch_buffer, &post_filter, current_frame);
+  }
+  StatusCode status;
+  if (settings_.threads == 1) {
+    status = DecodeTilesNonFrameParallel(sequence_header, frame_header, tiles,
+                                         frame_scratch_buffer, &post_filter);
   } else {
-    const int num_workers = threading_strategy_.tile_thread_count();
-    BlockingCounterWithStatus pending_workers(num_workers);
-    std::atomic<int> tile_counter(0);
-    // Submit tile decoding jobs to the thread pool.
-    for (int i = 0; i < num_workers; ++i) {
-      threading_strategy_.tile_thread_pool()->Schedule(
-          [&tiles, tile_count, &tile_counter, &pending_workers,
-           &pending_tiles]() {
-            bool failed = false;
-            int index;
-            while ((index = tile_counter.fetch_add(
-                        1, std::memory_order_relaxed)) < tile_count) {
-              if (!failed) {
-                const auto& tile_ptr = tiles[index];
-                if (!tile_ptr->Decode(/*is_main_thread=*/false)) {
-                  LIBGAV1_DLOG(ERROR, "Error decoding tile #%d",
-                               tile_ptr->number());
-                  failed = true;
-                }
-              } else {
-                pending_tiles.Decrement(false);
-              }
-            }
-            pending_workers.Decrement(!failed);
-          });
-    }
-    // Have the current thread partake in tile decoding.
-    int index;
-    while ((index = tile_counter.fetch_add(1, std::memory_order_relaxed)) <
-           tile_count) {
-      if (!tile_decoding_failed) {
-        const auto& tile_ptr = tiles[index];
-        if (!tile_ptr->Decode(/*is_main_thread=*/true)) {
-          LIBGAV1_DLOG(ERROR, "Error decoding tile #%d", tile_ptr->number());
-          tile_decoding_failed = true;
-        }
-      } else {
-        pending_tiles.Decrement(false);
-      }
-    }
-    // Wait until all the workers are done. This ensures that all the tiles have
-    // been parsed.
-    tile_decoding_failed |= !pending_workers.Wait();
+    status = DecodeTilesThreadedNonFrameParallel(tiles, frame_scratch_buffer,
+                                                 &post_filter, &pending_tiles);
   }
-  // Wait until all the tiles have been decoded.
-  tile_decoding_failed |= !pending_tiles.Wait();
-
-  // At this point, all the tiles have been parsed and decoded and the
-  // threadpool will be empty.
-  if (tile_decoding_failed) return kLibgav1StatusUnknownError;
-
-  if (obu->frame_header().enable_frame_end_update_cdf) {
-    symbol_decoder_context_ = saved_symbol_decoder_context;
+  if (status != kStatusOk) return status;
+  if (frame_header.enable_frame_end_update_cdf) {
+    frame_scratch_buffer->symbol_decoder_context = saved_symbol_decoder_context;
   }
-  state_.current_frame->SetFrameContext(symbol_decoder_context_);
-  if (post_filter.DoDeblock()) {
-    loop_filter_mask_.Build(obu->sequence_header(), obu->frame_header(),
-                            obu->tile_groups().front().start,
-                            obu->tile_groups().back().end,
-                            block_parameters_holder, inter_transform_sizes_);
-  }
-  if (!post_filter.ApplyFiltering()) {
-    LIBGAV1_DLOG(ERROR, "Error applying in-loop filtering.");
-    return kLibgav1StatusUnknownError;
-  }
-  SetCurrentFrameSegmentationMap(obu->frame_header(), prev_segment_ids);
-  return kLibgav1StatusOk;
+  current_frame->SetFrameContext(frame_scratch_buffer->symbol_decoder_context);
+  SetSegmentationMap(frame_header, prev_segment_ids, current_frame);
+  return kStatusOk;
 }
 
-void DecoderImpl::SetCurrentFrameSegmentationMap(
+StatusCode DecoderImpl::ApplyFilmGrain(
+    const ObuSequenceHeader& sequence_header,
     const ObuFrameHeader& frame_header,
-    const SegmentationMap* prev_segment_ids) {
-  if (!frame_header.segmentation.enabled) {
-    // All segment_id's are 0.
-    state_.current_frame->segmentation_map()->Clear();
-  } else if (!frame_header.segmentation.update_map) {
-    // Copy from prev_segment_ids.
-    if (prev_segment_ids == nullptr) {
-      // Treat a null prev_segment_ids pointer as if it pointed to a
-      // segmentation map containing all 0s.
-      state_.current_frame->segmentation_map()->Clear();
-    } else {
-      state_.current_frame->segmentation_map()->CopyFrom(*prev_segment_ids);
-    }
+    const RefCountedBufferPtr& displayable_frame,
+    RefCountedBufferPtr* film_grain_frame, ThreadPool* thread_pool) {
+  if (!sequence_header.film_grain_params_present ||
+      !displayable_frame->film_grain_params().apply_grain ||
+      (settings_.post_filter_mask & 0x10) == 0) {
+    *film_grain_frame = displayable_frame;
+    return kStatusOk;
   }
+  if (!frame_header.show_existing_frame &&
+      frame_header.refresh_frame_flags == 0) {
+    // If show_existing_frame is true, then the current frame is a previously
+    // saved reference frame. If refresh_frame_flags is nonzero, then the
+    // state_.UpdateReferenceFrames() call above has saved the current frame as
+    // a reference frame. Therefore, if both of these conditions are false, then
+    // the current frame is not saved as a reference frame. displayable_frame
+    // should hold the only reference to the current frame.
+    assert(displayable_frame.use_count() == 1);
+    // Add film grain noise in place.
+    *film_grain_frame = displayable_frame;
+  } else {
+    *film_grain_frame = buffer_pool_.GetFreeBuffer();
+    if (*film_grain_frame == nullptr) {
+      LIBGAV1_DLOG(ERROR,
+                   "Could not get film_grain_frame from the buffer pool.");
+      return kStatusResourceExhausted;
+    }
+    if (!(*film_grain_frame)
+             ->Realloc(displayable_frame->buffer()->bitdepth(),
+                       displayable_frame->buffer()->is_monochrome(),
+                       displayable_frame->upscaled_width(),
+                       displayable_frame->frame_height(),
+                       displayable_frame->buffer()->subsampling_x(),
+                       displayable_frame->buffer()->subsampling_y(),
+                       kBorderPixelsFilmGrain, kBorderPixelsFilmGrain,
+                       kBorderPixelsFilmGrain, kBorderPixelsFilmGrain)) {
+      LIBGAV1_DLOG(ERROR, "film_grain_frame->Realloc() failed.");
+      return kStatusOutOfMemory;
+    }
+    (*film_grain_frame)
+        ->set_chroma_sample_position(
+            displayable_frame->chroma_sample_position());
+    (*film_grain_frame)->set_spatial_id(displayable_frame->spatial_id());
+    (*film_grain_frame)->set_temporal_id(displayable_frame->temporal_id());
+  }
+  const bool color_matrix_is_identity =
+      sequence_header.color_config.matrix_coefficients ==
+      kMatrixCoefficientsIdentity;
+  assert(displayable_frame->buffer()->stride(kPlaneU) ==
+         displayable_frame->buffer()->stride(kPlaneV));
+  const int input_stride_uv = displayable_frame->buffer()->stride(kPlaneU);
+  assert((*film_grain_frame)->buffer()->stride(kPlaneU) ==
+         (*film_grain_frame)->buffer()->stride(kPlaneV));
+  const int output_stride_uv = (*film_grain_frame)->buffer()->stride(kPlaneU);
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  if (displayable_frame->buffer()->bitdepth() > 8) {
+    FilmGrain<10> film_grain(displayable_frame->film_grain_params(),
+                             displayable_frame->buffer()->is_monochrome(),
+                             color_matrix_is_identity,
+                             displayable_frame->buffer()->subsampling_x(),
+                             displayable_frame->buffer()->subsampling_y(),
+                             displayable_frame->upscaled_width(),
+                             displayable_frame->frame_height(), thread_pool);
+    if (!film_grain.AddNoise(
+            displayable_frame->buffer()->data(kPlaneY),
+            displayable_frame->buffer()->stride(kPlaneY),
+            displayable_frame->buffer()->data(kPlaneU),
+            displayable_frame->buffer()->data(kPlaneV), input_stride_uv,
+            (*film_grain_frame)->buffer()->data(kPlaneY),
+            (*film_grain_frame)->buffer()->stride(kPlaneY),
+            (*film_grain_frame)->buffer()->data(kPlaneU),
+            (*film_grain_frame)->buffer()->data(kPlaneV), output_stride_uv)) {
+      LIBGAV1_DLOG(ERROR, "film_grain.AddNoise() failed.");
+      return kStatusOutOfMemory;
+    }
+    return kStatusOk;
+  }
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+  FilmGrain<8> film_grain(displayable_frame->film_grain_params(),
+                          displayable_frame->buffer()->is_monochrome(),
+                          color_matrix_is_identity,
+                          displayable_frame->buffer()->subsampling_x(),
+                          displayable_frame->buffer()->subsampling_y(),
+                          displayable_frame->upscaled_width(),
+                          displayable_frame->frame_height(), thread_pool);
+  if (!film_grain.AddNoise(
+          displayable_frame->buffer()->data(kPlaneY),
+          displayable_frame->buffer()->stride(kPlaneY),
+          displayable_frame->buffer()->data(kPlaneU),
+          displayable_frame->buffer()->data(kPlaneV), input_stride_uv,
+          (*film_grain_frame)->buffer()->data(kPlaneY),
+          (*film_grain_frame)->buffer()->stride(kPlaneY),
+          (*film_grain_frame)->buffer()->data(kPlaneU),
+          (*film_grain_frame)->buffer()->data(kPlaneV), output_stride_uv)) {
+    LIBGAV1_DLOG(ERROR, "film_grain.AddNoise() failed.");
+    return kStatusOutOfMemory;
+  }
+  return kStatusOk;
+}
+
+bool DecoderImpl::IsNewSequenceHeader(const ObuParser& obu) {
+  if (std::find_if(obu.obu_headers().begin(), obu.obu_headers().end(),
+                   [](const ObuHeader& obu_header) {
+                     return obu_header.type == kObuSequenceHeader;
+                   }) == obu.obu_headers().end()) {
+    return false;
+  }
+  const ObuSequenceHeader sequence_header = obu.sequence_header();
+  const bool sequence_header_changed =
+      !has_sequence_header_ ||
+      sequence_header_.color_config.bitdepth !=
+          sequence_header.color_config.bitdepth ||
+      sequence_header_.color_config.is_monochrome !=
+          sequence_header.color_config.is_monochrome ||
+      sequence_header_.color_config.subsampling_x !=
+          sequence_header.color_config.subsampling_x ||
+      sequence_header_.color_config.subsampling_y !=
+          sequence_header.color_config.subsampling_y ||
+      sequence_header_.max_frame_width != sequence_header.max_frame_width ||
+      sequence_header_.max_frame_height != sequence_header.max_frame_height;
+  sequence_header_ = sequence_header;
+  has_sequence_header_ = true;
+  return sequence_header_changed;
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/decoder_impl.h b/libgav1/src/decoder_impl.h
index 18026f7..df1b091 100644
--- a/libgav1/src/decoder_impl.h
+++ b/libgav1/src/decoder_impl.h
@@ -18,23 +18,26 @@
 #define LIBGAV1_SRC_DECODER_IMPL_H_
 
 #include <array>
+#include <condition_variable>  // NOLINT (unapproved c++11 header)
 #include <cstddef>
 #include <cstdint>
 #include <memory>
+#include <mutex>  // NOLINT (unapproved c++11 header)
 
 #include "src/buffer_pool.h"
-#include "src/decoder_buffer.h"
-#include "src/decoder_settings.h"
+#include "src/decoder_state.h"
 #include "src/dsp/constants.h"
-#include "src/loop_filter_mask.h"
+#include "src/frame_scratch_buffer.h"
+#include "src/gav1/decoder_buffer.h"
+#include "src/gav1/decoder_settings.h"
+#include "src/gav1/status_code.h"
 #include "src/obu_parser.h"
 #include "src/residual_buffer_pool.h"
-#include "src/status_code.h"
 #include "src/symbol_decoder_context.h"
-#include "src/threading_strategy.h"
 #include "src/tile.h"
 #include "src/utils/array_2d.h"
 #include "src/utils/block_parameters_holder.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
 #include "src/utils/memory.h"
 #include "src/utils/queue.h"
@@ -43,69 +46,85 @@
 
 namespace libgav1 {
 
-struct EncodedFrame : public Allocable {
-  // The default constructor is invoked by the Queue<EncodedFrame>::Init()
+struct TemporalUnit;
+
+struct EncodedFrame {
+  EncodedFrame(ObuParser* const obu, const DecoderState& state,
+               const RefCountedBufferPtr& frame, int position_in_temporal_unit)
+      : sequence_header(obu->sequence_header()),
+        frame_header(obu->frame_header()),
+        state(state),
+        temporal_unit(nullptr),
+        frame(frame),
+        position_in_temporal_unit(position_in_temporal_unit) {
+    obu->MoveTileBuffer(&tile_buffers);
+    frame->MarkFrameAsStarted();
+  }
+
+  const ObuSequenceHeader sequence_header;
+  const ObuFrameHeader frame_header;
+  Vector<TileBuffer> tile_buffers;
+  DecoderState state;
+  TemporalUnit* temporal_unit;
+  RefCountedBufferPtr frame;
+  const int position_in_temporal_unit;
+};
+
+struct TemporalUnit : public Allocable {
+  // The default constructor is invoked by the Queue<TemporalUnit>::Init()
   // method. Queue<> does not use the default-constructed elements, so it is
   // safe for the default constructor to not initialize the members.
-  EncodedFrame() = default;
-  EncodedFrame(const uint8_t* data, size_t size, int64_t user_private_data)
-      : data(data), size(size), user_private_data(user_private_data) {}
+  TemporalUnit() = default;
+  TemporalUnit(const uint8_t* data, size_t size, int64_t user_private_data,
+               void* buffer_private_data)
+      : data(data),
+        size(size),
+        user_private_data(user_private_data),
+        buffer_private_data(buffer_private_data),
+        decoded(false),
+        status(kStatusOk),
+        has_displayable_frame(false),
+        output_frame_position(-1),
+        decoded_count(0),
+        output_layer_count(0),
+        released_input_buffer(false) {}
 
   const uint8_t* data;
   size_t size;
   int64_t user_private_data;
-};
+  void* buffer_private_data;
 
-struct DecoderState {
-  // Section 7.20. Updates frames in the reference_frame array with
-  // current_frame, based on the refresh_frame_flags bitmask.
-  void UpdateReferenceFrames(int refresh_frame_flags);
+  // The following members are used only in frame parallel mode.
+  bool decoded;
+  StatusCode status;
+  bool has_displayable_frame;
+  int output_frame_position;
 
-  // Clears all the reference frames.
-  void ClearReferenceFrames();
+  Vector<EncodedFrame> frames;
+  size_t decoded_count;
 
-  ObuSequenceHeader sequence_header = {};
-  // If true, sequence_header is valid.
-  bool has_sequence_header = false;
-  // reference_valid and reference_frame_id are used only if
-  // sequence_header_.frame_id_numbers_present is true.
-  // The reference_valid array is indexed by a reference picture slot number.
-  // A value (boolean) in the array signifies whether the corresponding
-  // reference picture slot is valid for use as a reference picture.
-  std::array<bool, kNumReferenceFrameTypes> reference_valid = {};
-  std::array<uint16_t, kNumReferenceFrameTypes> reference_frame_id = {};
-  // A valid value of current_frame_id is an unsigned integer of at most 16
-  // bits. -1 indicates current_frame_id is not initialized.
-  int current_frame_id = -1;
-  // The RefOrderHint array variable in the spec.
-  std::array<uint8_t, kNumReferenceFrameTypes> reference_order_hint = {};
-  // The OrderHint variable in the spec. Its value comes from either the
-  // order_hint syntax element in the uncompressed header (if
-  // show_existing_frame is false) or RefOrderHint[ frame_to_show_map_idx ]
-  // (if show_existing_frame is true and frame_type is KEY_FRAME). See Section
-  // 5.9.2 and Section 7.4.
-  //
-  // NOTE: When show_existing_frame is false, it is often more convenient to
-  // just use the order_hint field of the frame header as OrderHint. So this
-  // field is mainly used to update the reference_order_hint array in
-  // UpdateReferenceFrames().
-  uint8_t order_hint = 0;
-  // reference_frame_sign_bias[i] (a boolean) specifies the intended direction
-  // of the motion vector in time for each reference frame.
-  // * |false| indicates that the reference frame is a forwards reference (i.e.
-  //   the reference frame is expected to be output before the current frame);
-  // * |true| indicates that the reference frame is a backwards reference.
-  // Note: reference_frame_sign_bias[0] (for kReferenceFrameIntra) is not used.
-  std::array<bool, kNumReferenceFrameTypes> reference_frame_sign_bias = {};
-  std::array<RefCountedBufferPtr, kNumReferenceFrameTypes> reference_frame;
-  RefCountedBufferPtr current_frame;
-  // wedge_master_mask has to be initialized to zero.
-  std::array<uint8_t, 6 * kWedgeMaskMasterSize* kWedgeMaskMasterSize>
-      wedge_master_mask = {};
-  // TODO(chengchen): It is possible to reduce the buffer size. Because wedge
-  // mask sizes are 8x8, 8x16, ..., 32x32. This buffer size can fit 32x32.
-  std::array<uint8_t, kWedgeMaskSize> wedge_masks = {};
-  Array2D<TemporalMotionVector> motion_field_mv;
+  // The struct (and the counter) is used to support output of multiple layers
+  // within a single temporal unit. The decoding process will store the output
+  // frames in |output_layers| in the order they are finished decoding. At the
+  // end of the decoding process, this array will be sorted in reverse order of
+  // |position_in_temporal_unit|. DequeueFrame() will then return the frames in
+  // reverse order (so that the entire process can run with a single counter
+  // variable).
+  struct OutputLayer {
+    // Used by std::sort to sort |output_layers| in reverse order of
+    // |position_in_temporal_unit|.
+    bool operator<(const OutputLayer& rhs) const {
+      return position_in_temporal_unit > rhs.position_in_temporal_unit;
+    }
+
+    RefCountedBufferPtr frame;
+    int position_in_temporal_unit = 0;
+  } output_layers[kMaxLayers];
+  // Number of entries in |output_layers|.
+  int output_layer_count;
+  // Flag to ensure that we release the input buffer only once if there are
+  // multiple output layers.
+  bool released_input_buffer;
 };
 
 class DecoderImpl : public Allocable {
@@ -118,51 +137,121 @@
                            std::unique_ptr<DecoderImpl>* output);
   ~DecoderImpl();
   StatusCode EnqueueFrame(const uint8_t* data, size_t size,
-                          int64_t user_private_data);
+                          int64_t user_private_data, void* buffer_private_data);
   StatusCode DequeueFrame(const DecoderBuffer** out_ptr);
   static constexpr int GetMaxBitdepth() {
-#if LIBGAV1_MAX_BITDEPTH >= 10
-    return 10;
-#else
-    return 8;
-#endif
+    static_assert(LIBGAV1_MAX_BITDEPTH == 8 || LIBGAV1_MAX_BITDEPTH == 10,
+                  "LIBGAV1_MAX_BITDEPTH must be 8 or 10.");
+    return LIBGAV1_MAX_BITDEPTH;
   }
 
  private:
   explicit DecoderImpl(const DecoderSettings* settings);
   StatusCode Init();
-  bool AllocateCurrentFrame(const ObuFrameHeader& frame_header);
+  // Called when the first frame is enqueued. It does the OBU parsing for one
+  // temporal unit to retrieve the tile configuration and sets up the frame
+  // threading if frame parallel mode is allowed. It also initializes the
+  // |temporal_units_| queue based on the number of frame threads.
+  //
+  // The following are the limitations of the current implementation:
+  //  * It assumes that all frames in the video have the same tile
+  //    configuration. The frame parallel threading model will not be updated
+  //    based on tile configuration changes mid-stream.
+  //  * The above assumption holds true even when there is a new coded video
+  //    sequence (i.e.) a new sequence header.
+  StatusCode InitializeFrameThreadPoolAndTemporalUnitQueue(const uint8_t* data,
+                                                           size_t size);
+  // Used only in frame parallel mode. Signals failure and waits until the
+  // worker threads are aborted if |status| is a failure status. If |status| is
+  // equal to kStatusOk or kStatusTryAgain, this function does not do anything.
+  // Always returns the input parameter |status| as the return value.
+  //
+  // This function is called only from the application thread (from
+  // EnqueueFrame() and DequeueFrame()).
+  StatusCode SignalFailure(StatusCode status);
+
   void ReleaseOutputFrame();
-  // Populates buffer_ with values from |frame|. Adds a reference to |frame|
-  // in output_frame_.
+
+  // Decodes all the frames contained in the given temporal unit. Used only in
+  // non frame parallel mode.
+  StatusCode DecodeTemporalUnit(const TemporalUnit& temporal_unit,
+                                const DecoderBuffer** out_ptr);
+  // Used only in frame parallel mode. Does the OBU parsing for |data| and
+  // schedules the individual frames for decoding in the |frame_thread_pool_|.
+  StatusCode ParseAndSchedule(const uint8_t* data, size_t size,
+                              int64_t user_private_data,
+                              void* buffer_private_data);
+  // Decodes the |encoded_frame| and updates the
+  // |encoded_frame->temporal_unit|'s parameters if the decoded frame is a
+  // displayable frame. Used only in frame parallel mode.
+  StatusCode DecodeFrame(EncodedFrame* encoded_frame);
+
+  // Populates |buffer_| with values from |frame|. Adds a reference to |frame|
+  // in |output_frame_|.
   StatusCode CopyFrameToOutputBuffer(const RefCountedBufferPtr& frame);
-  StatusCode DecodeTiles(const ObuParser* obu);
-  // Sets the current frame's segmentation map for two cases. The third case
-  // is handled in Tile::DecodeBlock().
-  void SetCurrentFrameSegmentationMap(const ObuFrameHeader& frame_header,
-                                      const SegmentationMap* prev_segment_ids);
+  StatusCode DecodeTiles(const ObuSequenceHeader& sequence_header,
+                         const ObuFrameHeader& frame_header,
+                         const Vector<TileBuffer>& tile_buffers,
+                         const DecoderState& state,
+                         FrameScratchBuffer* frame_scratch_buffer,
+                         RefCountedBuffer* current_frame);
+  // Applies film grain synthesis to the |displayable_frame| and stores the film
+  // grain applied frame into |film_grain_frame|. Returns kStatusOk on success.
+  StatusCode ApplyFilmGrain(const ObuSequenceHeader& sequence_header,
+                            const ObuFrameHeader& frame_header,
+                            const RefCountedBufferPtr& displayable_frame,
+                            RefCountedBufferPtr* film_grain_frame,
+                            ThreadPool* thread_pool);
 
-  Queue<EncodedFrame> encoded_frames_;
+  bool IsNewSequenceHeader(const ObuParser& obu);
+
+  bool HasFailure() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    return failure_status_ != kStatusOk;
+  }
+
+  // Elements in this queue cannot be moved with std::move since the
+  // |EncodedFrame.temporal_unit| stores a pointer to elements in this queue.
+  Queue<TemporalUnit> temporal_units_;
   DecoderState state_;
-  ThreadingStrategy threading_strategy_;
-  SymbolDecoderContext symbol_decoder_context_;
 
-  // TODO(vigneshv): Only support one buffer for now. Eventually this has to be
-  // a vector or an array.
   DecoderBuffer buffer_ = {};
-  // output_frame_ holds a reference to the output frame on behalf of buffer_.
+  // |output_frame_| holds a reference to the output frame on behalf of
+  // |buffer_|.
   RefCountedBufferPtr output_frame_;
 
-  BufferPool buffer_pool_;
-  std::unique_ptr<ResidualBufferPool> residual_buffer_pool_;
-  AlignedUniquePtr<uint8_t> threaded_window_buffer_;
-  size_t threaded_window_buffer_size_ = 0;
-  Array2D<TransformSize> inter_transform_sizes_;
-  DecoderScratchBufferPool decoder_scratch_buffer_pool_;
+  // Queue of output frames that are to be returned in the DequeueFrame() calls.
+  // If |settings_.output_all_layers| is false, this queue will never contain
+  // more than 1 element. This queue is used only when |is_frame_parallel_| is
+  // false.
+  Queue<RefCountedBufferPtr> output_frame_queue_;
 
-  LoopFilterMask loop_filter_mask_;
+  BufferPool buffer_pool_;
+  WedgeMaskArray wedge_masks_;
+  FrameScratchBufferPool frame_scratch_buffer_pool_;
+
+  // Used to synchronize the accesses into |temporal_units_| in order to update
+  // the "decoded" state of an temporal unit.
+  std::mutex mutex_;
+  std::condition_variable decoded_condvar_;
+  bool is_frame_parallel_;
+  std::unique_ptr<ThreadPool> frame_thread_pool_;
+
+  // In frame parallel mode, there are two primary points of failure:
+  //  1) ParseAndSchedule()
+  //  2) DecodeTiles()
+  // Both of these functions have to respond to the other one failing by
+  // aborting whatever they are doing. This variable is used to accomplish that.
+  // If |failure_status_| is not kStatusOk, then the two functions will try to
+  // abort as early as they can.
+  StatusCode failure_status_ = kStatusOk LIBGAV1_GUARDED_BY(mutex_);
+
+  ObuSequenceHeader sequence_header_ = {};
+  // If true, sequence_header is valid.
+  bool has_sequence_header_ = false;
 
   const DecoderSettings& settings_;
+  bool seen_first_frame_ = false;
 };
 
 }  // namespace libgav1
diff --git a/libgav1/src/decoder_scratch_buffer.h b/libgav1/src/decoder_scratch_buffer.h
deleted file mode 100644
index 54ee1b7..0000000
--- a/libgav1/src/decoder_scratch_buffer.h
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Copyright 2019 The libgav1 Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBGAV1_SRC_DECODER_SCRATCH_BUFFER_H_
-#define LIBGAV1_SRC_DECODER_SCRATCH_BUFFER_H_
-
-#include <cstdint>
-#include <mutex>  // NOLINT (unapproved c++11 header)
-
-#include "src/dsp/constants.h"
-#include "src/utils/compiler_attributes.h"
-#include "src/utils/constants.h"
-#include "src/utils/memory.h"
-#include "src/utils/stack.h"
-
-namespace libgav1 {
-
-// Buffer to facilitate decoding a superblock.
-struct DecoderScratchBuffer : public Allocable {
-  static constexpr int kBlockDecodedStride = 34;
-
- private:
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  static constexpr int kPixelSize = 2;
-#else
-  static constexpr int kPixelSize = 1;
-#endif
-
- public:
-  // The following prediction modes need a prediction mask:
-  // kCompoundPredictionTypeDiffWeighted, kCompoundPredictionTypeWedge,
-  // kCompoundPredictionTypeIntra. They are mutually exclusive. This buffer is
-  // used to store the prediction mask during the inter prediction process. The
-  // mask only needs to be created for the Y plane and is used for the U & V
-  // planes.
-  alignas(kMaxAlignment) uint8_t
-      prediction_mask[kMaxSuperBlockSizeSquareInPixels];
-
-  // For each instance of the DecoderScratchBuffer, only one of the following
-  // buffers will be used at any given time, so it is ok to share them in a
-  // union.
-  union {
-    // Union usage note: This is used only by functions in the "inter"
-    // prediction path.
-    //
-    // Buffers used for inter prediction process.
-    alignas(kMaxAlignment) uint16_t
-        prediction_buffer[2][kMaxSuperBlockSizeSquareInPixels];
-
-    struct {
-      // Union usage note: This is used only by functions in the "intra"
-      // prediction path.
-      //
-      // Buffer used for storing subsampled luma samples needed for CFL
-      // prediction. This buffer is used to avoid repetition of the subsampling
-      // for the V plane when it is already done for the U plane.
-      int16_t cfl_luma_buffer[kCflLumaBufferStride][kCflLumaBufferStride];
-
-      // Union usage note: This is used only by the
-      // Tile::ReadTransformCoefficients() function (and the helper functions
-      // that it calls). This cannot be shared with |cfl_luma_buffer| since
-      // |cfl_luma_buffer| has to live across the 3 plane loop in
-      // Tile::TransformBlock.
-      //
-      // Buffer used by Tile::ReadTransformCoefficients() to store the quantized
-      // coefficients until the dequantization process is performed.
-      int32_t quantized_buffer[kQuantizedCoefficientBufferSize];
-    };
-  };
-
-  // Buffer used for convolve. The maximum size required for this buffer is:
-  //  maximum block height (with scaling) = 2 * 128 = 256.
-  //  maximum block stride (with scaling and border aligned to 16) =
-  //     (2 * 128 + 7 + 9) * pixel_size = 272 * pixel_size.
-  alignas(kMaxAlignment) uint8_t
-      convolve_block_buffer[256 * 272 * DecoderScratchBuffer::kPixelSize];
-
-  // Flag indicating whether the data in |cfl_luma_buffer| is valid.
-  bool cfl_luma_buffer_valid;
-
-  // Equivalent to BlockDecoded array in the spec. This stores the decoded
-  // state of every 4x4 block in a superblock. It has 1 row/column border on
-  // all 4 sides (hence the 34x34 dimension instead of 32x32). Note that the
-  // spec uses "-1" as an index to access the left and top borders. In the
-  // code, we treat the index (1, 1) as equivalent to the spec's (0, 0). So
-  // all accesses into this array will be offset by +1 when compared with the
-  // spec.
-  bool block_decoded[kMaxPlanes][kBlockDecodedStride][kBlockDecodedStride];
-};
-
-class DecoderScratchBufferPool {
- public:
-  std::unique_ptr<DecoderScratchBuffer> Get() {
-    std::lock_guard<std::mutex> lock(mutex_);
-    if (buffers_.Empty()) {
-      std::unique_ptr<DecoderScratchBuffer> scratch_buffer(
-          new (std::nothrow) DecoderScratchBuffer);
-      return scratch_buffer;
-    }
-    return buffers_.Pop();
-  }
-
-  void Release(std::unique_ptr<DecoderScratchBuffer> scratch_buffer) {
-    std::lock_guard<std::mutex> lock(mutex_);
-    buffers_.Push(std::move(scratch_buffer));
-  }
-
- private:
-  std::mutex mutex_;
-  // We will never need more than kMaxThreads scratch buffers since that is the
-  // maximum amount of work that will be done at any given time.
-  Stack<std::unique_ptr<DecoderScratchBuffer>, kMaxThreads> buffers_
-      LIBGAV1_GUARDED_BY(mutex_);
-};
-
-}  // namespace libgav1
-
-#endif  // LIBGAV1_SRC_DECODER_SCRATCH_BUFFER_H_
diff --git a/libgav1/src/decoder_settings.cc b/libgav1/src/decoder_settings.cc
new file mode 100644
index 0000000..9399073
--- /dev/null
+++ b/libgav1/src/decoder_settings.cc
@@ -0,0 +1,33 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/gav1/decoder_settings.h"
+
+extern "C" {
+
+void Libgav1DecoderSettingsInitDefault(Libgav1DecoderSettings* settings) {
+  settings->threads = 1;
+  settings->frame_parallel = 0;    // false
+  settings->blocking_dequeue = 0;  // false
+  settings->on_frame_buffer_size_changed = nullptr;
+  settings->get_frame_buffer = nullptr;
+  settings->release_frame_buffer = nullptr;
+  settings->release_input_buffer = nullptr;
+  settings->callback_private_data = nullptr;
+  settings->output_all_layers = 0;  // false
+  settings->operating_point = 0;
+  settings->post_filter_mask = 0x1f;
+}
+
+}  // extern "C"
diff --git a/libgav1/src/decoder_settings.h b/libgav1/src/decoder_settings.h
deleted file mode 100644
index 6c6f21d..0000000
--- a/libgav1/src/decoder_settings.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Copyright 2019 The libgav1 Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBGAV1_SRC_DECODER_SETTINGS_H_
-#define LIBGAV1_SRC_DECODER_SETTINGS_H_
-
-#include <cstdint>
-
-#include "src/frame_buffer.h"
-
-// All the declarations in this file are part of the public ABI.
-
-namespace libgav1 {
-
-// Applications must populate this structure before creating a decoder instance.
-struct DecoderSettings {
-  // Number of threads to use when decoding. Must be greater than 0. The
-  // library will create at most |threads|-1 new threads, the calling thread is
-  // considered part of the library's thread count. Defaults to 1 (no new
-  // threads will be created).
-  int threads = 1;
-  // Do frame parallel decoding.
-  bool frame_parallel = false;
-  // Get frame buffer callback.
-  GetFrameBufferCallback get = nullptr;
-  // Release frame buffer callback.
-  ReleaseFrameBufferCallback release = nullptr;
-  // Passed as the private_data argument to the callbacks.
-  void* callback_private_data = nullptr;
-  // Mask indicating the post processing filters that need to be applied to the
-  // reconstructed frame. From LSB:
-  //   Bit 0: Loop filter (deblocking filter).
-  //   Bit 1: Cdef.
-  //   Bit 2: Superres.
-  //   Bit 3: Loop restoration.
-  //   Bit 4: Film grain synthesis.
-  //   All the bits other than the last 5 are ignored.
-  uint8_t post_filter_mask = 0x1f;
-};
-
-}  // namespace libgav1
-#endif  // LIBGAV1_SRC_DECODER_SETTINGS_H_
diff --git a/libgav1/src/decoder_state.h b/libgav1/src/decoder_state.h
new file mode 100644
index 0000000..897c99f
--- /dev/null
+++ b/libgav1/src/decoder_state.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DECODER_STATE_H_
+#define LIBGAV1_SRC_DECODER_STATE_H_
+
+#include <array>
+#include <cstdint>
+
+#include "src/buffer_pool.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+
+struct DecoderState {
+  // Section 7.20. Updates frames in the reference_frame array with
+  // |current_frame|, based on the |refresh_frame_flags| bitmask.
+  void UpdateReferenceFrames(const RefCountedBufferPtr& current_frame,
+                             int refresh_frame_flags) {
+    for (int ref_index = 0, mask = refresh_frame_flags; mask != 0;
+         ++ref_index, mask >>= 1) {
+      if ((mask & 1) != 0) {
+        reference_valid[ref_index] = true;
+        reference_frame_id[ref_index] = current_frame_id;
+        reference_frame[ref_index] = current_frame;
+        reference_order_hint[ref_index] = order_hint;
+      }
+    }
+  }
+
+  // Clears all the reference frames.
+  void ClearReferenceFrames() {
+    reference_valid = {};
+    reference_frame_id = {};
+    reference_order_hint = {};
+    for (int ref_index = 0; ref_index < kNumReferenceFrameTypes; ++ref_index) {
+      reference_frame[ref_index] = nullptr;
+    }
+  }
+
+  // reference_valid and reference_frame_id are used only if
+  // sequence_header_.frame_id_numbers_present is true.
+  // The reference_valid array is indexed by a reference picture slot number.
+  // A value (boolean) in the array signifies whether the corresponding
+  // reference picture slot is valid for use as a reference picture.
+  std::array<bool, kNumReferenceFrameTypes> reference_valid = {};
+  std::array<uint16_t, kNumReferenceFrameTypes> reference_frame_id = {};
+  // A valid value of current_frame_id is an unsigned integer of at most 16
+  // bits. -1 indicates current_frame_id is not initialized.
+  int current_frame_id = -1;
+  // The RefOrderHint array variable in the spec.
+  std::array<uint8_t, kNumReferenceFrameTypes> reference_order_hint = {};
+  // The OrderHint variable in the spec. Its value comes from either the
+  // order_hint syntax element in the uncompressed header (if
+  // show_existing_frame is false) or RefOrderHint[ frame_to_show_map_idx ]
+  // (if show_existing_frame is true and frame_type is KEY_FRAME). See Section
+  // 5.9.2 and Section 7.4.
+  //
+  // NOTE: When show_existing_frame is false, it is often more convenient to
+  // just use the order_hint field of the frame header as OrderHint. So this
+  // field is mainly used to update the reference_order_hint array in
+  // UpdateReferenceFrames().
+  uint8_t order_hint = 0;
+  // reference_frame_sign_bias[i] (a boolean) specifies the intended direction
+  // of the motion vector in time for each reference frame.
+  // * |false| indicates that the reference frame is a forwards reference (i.e.
+  //   the reference frame is expected to be output before the current frame);
+  // * |true| indicates that the reference frame is a backwards reference.
+  // Note: reference_frame_sign_bias[0] (for kReferenceFrameIntra) is not used.
+  std::array<bool, kNumReferenceFrameTypes> reference_frame_sign_bias = {};
+  std::array<RefCountedBufferPtr, kNumReferenceFrameTypes> reference_frame;
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_DECODER_STATE_H_
diff --git a/libgav1/src/dsp/arm/average_blend_neon.cc b/libgav1/src/dsp/arm/average_blend_neon.cc
index 94fad54..d946d70 100644
--- a/libgav1/src/dsp/arm/average_blend_neon.cc
+++ b/libgav1/src/dsp/arm/average_blend_neon.cc
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "src/dsp/average_blend.h"
-#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -24,83 +24,61 @@
 #include <cstdint>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"
 
 namespace libgav1 {
 namespace dsp {
 namespace {
 
-constexpr int kBitdepth8 = 8;
-constexpr int kInterPostRoundBit = 4;
-// An offset to cancel offsets used in compound predictor generation that
-// make intermediate computations non negative.
-const int16x8_t kCompoundRoundOffset =
-    vdupq_n_s16((2 << (kBitdepth8 + 4)) + (2 << (kBitdepth8 + 3)));
+constexpr int kInterPostRoundBit =
+    kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
 
-inline void AverageBlend4Row(const uint16_t* prediction_0,
-                             const uint16_t* prediction_1, uint8_t* dest) {
-  const int16x4_t pred0 = vreinterpret_s16_u16(vld1_u16(prediction_0));
-  const int16x4_t pred1 = vreinterpret_s16_u16(vld1_u16(prediction_1));
-  int16x4_t res = vadd_s16(pred0, pred1);
-  res = vsub_s16(res, vget_low_s16(kCompoundRoundOffset));
-  StoreLo4(dest,
-           vqrshrun_n_s16(vcombine_s16(res, res), kInterPostRoundBit + 1));
+inline uint8x8_t AverageBlend8Row(const int16_t* prediction_0,
+                                  const int16_t* prediction_1) {
+  const int16x8_t pred0 = vld1q_s16(prediction_0);
+  const int16x8_t pred1 = vld1q_s16(prediction_1);
+  const int16x8_t res = vaddq_s16(pred0, pred1);
+  return vqrshrun_n_s16(res, kInterPostRoundBit + 1);
 }
 
-inline void AverageBlend8Row(const uint16_t* prediction_0,
-                             const uint16_t* prediction_1, uint8_t* dest) {
-  const int16x8_t pred0 = vreinterpretq_s16_u16(vld1q_u16(prediction_0));
-  const int16x8_t pred1 = vreinterpretq_s16_u16(vld1q_u16(prediction_1));
-  int16x8_t res = vaddq_s16(pred0, pred1);
-  res = vsubq_s16(res, kCompoundRoundOffset);
-  vst1_u8(dest, vqrshrun_n_s16(res, kInterPostRoundBit + 1));
-}
-
-inline void AverageBlendLargeRow(const uint16_t* prediction_0,
-                                 const uint16_t* prediction_1, const int width,
+inline void AverageBlendLargeRow(const int16_t* prediction_0,
+                                 const int16_t* prediction_1, const int width,
                                  uint8_t* dest) {
   int x = 0;
   do {
-    const int16x8_t pred_00 =
-        vreinterpretq_s16_u16(vld1q_u16(&prediction_0[x]));
-    const int16x8_t pred_01 =
-        vreinterpretq_s16_u16(vld1q_u16(&prediction_1[x]));
-    int16x8_t res0 = vaddq_s16(pred_00, pred_01);
-    res0 = vsubq_s16(res0, kCompoundRoundOffset);
+    const int16x8_t pred_00 = vld1q_s16(&prediction_0[x]);
+    const int16x8_t pred_01 = vld1q_s16(&prediction_1[x]);
+    const int16x8_t res0 = vaddq_s16(pred_00, pred_01);
     const uint8x8_t res_out0 = vqrshrun_n_s16(res0, kInterPostRoundBit + 1);
-    const int16x8_t pred_10 =
-        vreinterpretq_s16_u16(vld1q_u16(&prediction_0[x + 8]));
-    const int16x8_t pred_11 =
-        vreinterpretq_s16_u16(vld1q_u16(&prediction_1[x + 8]));
-    int16x8_t res1 = vaddq_s16(pred_10, pred_11);
-    res1 = vsubq_s16(res1, kCompoundRoundOffset);
+    const int16x8_t pred_10 = vld1q_s16(&prediction_0[x + 8]);
+    const int16x8_t pred_11 = vld1q_s16(&prediction_1[x + 8]);
+    const int16x8_t res1 = vaddq_s16(pred_10, pred_11);
     const uint8x8_t res_out1 = vqrshrun_n_s16(res1, kInterPostRoundBit + 1);
     vst1q_u8(dest + x, vcombine_u8(res_out0, res_out1));
     x += 16;
   } while (x < width);
 }
 
-void AverageBlend_NEON(const uint16_t* prediction_0,
-                       const ptrdiff_t prediction_stride_0,
-                       const uint16_t* prediction_1,
-                       const ptrdiff_t prediction_stride_1, const int width,
-                       const int height, void* const dest,
+void AverageBlend_NEON(const void* prediction_0, const void* prediction_1,
+                       const int width, const int height, void* const dest,
                        const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y = height;
 
   if (width == 4) {
     do {
-      AverageBlend4Row(prediction_0, prediction_1, dst);
-      dst += dest_stride;
-      prediction_0 += prediction_stride_0;
-      prediction_1 += prediction_stride_1;
+      const uint8x8_t result = AverageBlend8Row(pred_0, pred_1);
+      pred_0 += 8;
+      pred_1 += 8;
 
-      AverageBlend4Row(prediction_0, prediction_1, dst);
+      StoreLo4(dst, result);
       dst += dest_stride;
-      prediction_0 += prediction_stride_0;
-      prediction_1 += prediction_stride_1;
-
+      StoreHi4(dst, result);
+      dst += dest_stride;
       y -= 2;
     } while (y != 0);
     return;
@@ -108,15 +86,15 @@
 
   if (width == 8) {
     do {
-      AverageBlend8Row(prediction_0, prediction_1, dst);
+      vst1_u8(dst, AverageBlend8Row(pred_0, pred_1));
       dst += dest_stride;
-      prediction_0 += prediction_stride_0;
-      prediction_1 += prediction_stride_1;
+      pred_0 += 8;
+      pred_1 += 8;
 
-      AverageBlend8Row(prediction_0, prediction_1, dst);
+      vst1_u8(dst, AverageBlend8Row(pred_0, pred_1));
       dst += dest_stride;
-      prediction_0 += prediction_stride_0;
-      prediction_1 += prediction_stride_1;
+      pred_0 += 8;
+      pred_1 += 8;
 
       y -= 2;
     } while (y != 0);
@@ -124,22 +102,22 @@
   }
 
   do {
-    AverageBlendLargeRow(prediction_0, prediction_1, width, dst);
+    AverageBlendLargeRow(pred_0, pred_1, width, dst);
     dst += dest_stride;
-    prediction_0 += prediction_stride_0;
-    prediction_1 += prediction_stride_1;
+    pred_0 += width;
+    pred_1 += width;
 
-    AverageBlendLargeRow(prediction_0, prediction_1, width, dst);
+    AverageBlendLargeRow(pred_0, pred_1, width, dst);
     dst += dest_stride;
-    prediction_0 += prediction_stride_0;
-    prediction_1 += prediction_stride_1;
+    pred_0 += width;
+    pred_1 += width;
 
     y -= 2;
   } while (y != 0);
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->average_blend = AverageBlend_NEON;
 }
@@ -151,7 +129,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 
 namespace libgav1 {
 namespace dsp {
diff --git a/libgav1/src/dsp/arm/average_blend_neon.h b/libgav1/src/dsp/arm/average_blend_neon.h
index 569da64..d13bcd6 100644
--- a/libgav1/src/dsp/arm/average_blend_neon.h
+++ b/libgav1/src/dsp/arm/average_blend_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -30,7 +30,7 @@
 }  // namespace libgav1
 
 #if LIBGAV1_ENABLE_NEON
-#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_AVERAGE_BLEND_NEON_H_
diff --git a/libgav1/src/dsp/arm/cdef_neon.cc b/libgav1/src/dsp/arm/cdef_neon.cc
new file mode 100644
index 0000000..968b0ff
--- /dev/null
+++ b/libgav1/src/dsp/arm/cdef_neon.cc
@@ -0,0 +1,697 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/cdef.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+#include "src/dsp/cdef.inc"
+
+// ----------------------------------------------------------------------------
+// Refer to CdefDirection_C().
+//
+// int32_t partial[8][15] = {};
+// for (int i = 0; i < 8; ++i) {
+//   for (int j = 0; j < 8; ++j) {
+//     const int x = 1;
+//     partial[0][i + j] += x;
+//     partial[1][i + j / 2] += x;
+//     partial[2][i] += x;
+//     partial[3][3 + i - j / 2] += x;
+//     partial[4][7 + i - j] += x;
+//     partial[5][3 - i / 2 + j] += x;
+//     partial[6][j] += x;
+//     partial[7][i / 2 + j] += x;
+//   }
+// }
+//
+// Using the code above, generate the position count for partial[8][15].
+//
+// partial[0]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
+// partial[1]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+// partial[2]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
+// partial[3]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+// partial[4]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
+// partial[5]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+// partial[6]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
+// partial[7]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+//
+// The SIMD code shifts the input horizontally, then adds vertically to get the
+// correct partial value for the given position.
+// ----------------------------------------------------------------------------
+
+// ----------------------------------------------------------------------------
+// partial[0][i + j] += x;
+//
+// 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00
+// 00 10 11 12 13 14 15 16  17 00 00 00 00 00 00
+// 00 00 20 21 22 23 24 25  26 27 00 00 00 00 00
+// 00 00 00 30 31 32 33 34  35 36 37 00 00 00 00
+// 00 00 00 00 40 41 42 43  44 45 46 47 00 00 00
+// 00 00 00 00 00 50 51 52  53 54 55 56 57 00 00
+// 00 00 00 00 00 00 60 61  62 63 64 65 66 67 00
+// 00 00 00 00 00 00 00 70  71 72 73 74 75 76 77
+//
+// partial[4] is the same except the source is reversed.
+LIBGAV1_ALWAYS_INLINE void AddPartial_D0_D4(uint8x8_t* v_src,
+                                            uint16x8_t* partial_lo,
+                                            uint16x8_t* partial_hi) {
+  const uint8x8_t v_zero = vdup_n_u8(0);
+  // 00 01 02 03 04 05 06 07
+  // 00 10 11 12 13 14 15 16
+  *partial_lo = vaddl_u8(v_src[0], vext_u8(v_zero, v_src[1], 7));
+
+  // 00 00 20 21 22 23 24 25
+  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[2], 6));
+  // 17 00 00 00 00 00 00 00
+  // 26 27 00 00 00 00 00 00
+  *partial_hi =
+      vaddl_u8(vext_u8(v_src[1], v_zero, 7), vext_u8(v_src[2], v_zero, 6));
+
+  // 00 00 00 30 31 32 33 34
+  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[3], 5));
+  // 35 36 37 00 00 00 00 00
+  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[3], v_zero, 5));
+
+  // 00 00 00 00 40 41 42 43
+  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[4], 4));
+  // 44 45 46 47 00 00 00 00
+  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[4], v_zero, 4));
+
+  // 00 00 00 00 00 50 51 52
+  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[5], 3));
+  // 53 54 55 56 57 00 00 00
+  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[5], v_zero, 3));
+
+  // 00 00 00 00 00 00 60 61
+  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[6], 2));
+  // 62 63 64 65 66 67 00 00
+  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[6], v_zero, 2));
+
+  // 00 00 00 00 00 00 00 70
+  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[7], 1));
+  // 71 72 73 74 75 76 77 00
+  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[7], v_zero, 1));
+}
+
+// ----------------------------------------------------------------------------
+// partial[1][i + j / 2] += x;
+//
+// A0 = src[0] + src[1], A1 = src[2] + src[3], ...
+//
+// A0 A1 A2 A3 00 00 00 00  00 00 00 00 00 00 00
+// 00 B0 B1 B2 B3 00 00 00  00 00 00 00 00 00 00
+// 00 00 C0 C1 C2 C3 00 00  00 00 00 00 00 00 00
+// 00 00 00 D0 D1 D2 D3 00  00 00 00 00 00 00 00
+// 00 00 00 00 E0 E1 E2 E3  00 00 00 00 00 00 00
+// 00 00 00 00 00 F0 F1 F2  F3 00 00 00 00 00 00
+// 00 00 00 00 00 00 G0 G1  G2 G3 00 00 00 00 00
+// 00 00 00 00 00 00 00 H0  H1 H2 H3 00 00 00 00
+//
+// partial[3] is the same except the source is reversed.
+LIBGAV1_ALWAYS_INLINE void AddPartial_D1_D3(uint8x8_t* v_src,
+                                            uint16x8_t* partial_lo,
+                                            uint16x8_t* partial_hi) {
+  uint8x16_t v_d1_temp[8];
+  const uint8x8_t v_zero = vdup_n_u8(0);
+  const uint8x16_t v_zero_16 = vdupq_n_u8(0);
+
+  for (int i = 0; i < 8; ++i) {
+    v_d1_temp[i] = vcombine_u8(v_src[i], v_zero);
+  }
+
+  *partial_lo = *partial_hi = vdupq_n_u16(0);
+  // A0 A1 A2 A3 00 00 00 00
+  *partial_lo = vpadalq_u8(*partial_lo, v_d1_temp[0]);
+
+  // 00 B0 B1 B2 B3 00 00 00
+  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[1], 14));
+
+  // 00 00 C0 C1 C2 C3 00 00
+  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[2], 12));
+  // 00 00 00 D0 D1 D2 D3 00
+  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[3], 10));
+  // 00 00 00 00 E0 E1 E2 E3
+  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[4], 8));
+
+  // 00 00 00 00 00 F0 F1 F2
+  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[5], 6));
+  // F3 00 00 00 00 00 00 00
+  *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[5], v_zero_16, 6));
+
+  // 00 00 00 00 00 00 G0 G1
+  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[6], 4));
+  // G2 G3 00 00 00 00 00 00
+  *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[6], v_zero_16, 4));
+
+  // 00 00 00 00 00 00 00 H0
+  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[7], 2));
+  // H1 H2 H3 00 00 00 00 00
+  *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[7], v_zero_16, 2));
+}
+
+// ----------------------------------------------------------------------------
+// partial[7][i / 2 + j] += x;
+//
+// 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00
+// 10 11 12 13 14 15 16 17  00 00 00 00 00 00 00
+// 00 20 21 22 23 24 25 26  27 00 00 00 00 00 00
+// 00 30 31 32 33 34 35 36  37 00 00 00 00 00 00
+// 00 00 40 41 42 43 44 45  46 47 00 00 00 00 00
+// 00 00 50 51 52 53 54 55  56 57 00 00 00 00 00
+// 00 00 00 60 61 62 63 64  65 66 67 00 00 00 00
+// 00 00 00 70 71 72 73 74  75 76 77 00 00 00 00
+//
+// partial[5] is the same except the source is reversed.
+LIBGAV1_ALWAYS_INLINE void AddPartial_D5_D7(uint8x8_t* v_src,
+                                            uint16x8_t* partial_lo,
+                                            uint16x8_t* partial_hi) {
+  const uint16x8_t v_zero = vdupq_n_u16(0);
+  uint16x8_t v_pair_add[4];
+  // Add vertical source pairs.
+  v_pair_add[0] = vaddl_u8(v_src[0], v_src[1]);
+  v_pair_add[1] = vaddl_u8(v_src[2], v_src[3]);
+  v_pair_add[2] = vaddl_u8(v_src[4], v_src[5]);
+  v_pair_add[3] = vaddl_u8(v_src[6], v_src[7]);
+
+  // 00 01 02 03 04 05 06 07
+  // 10 11 12 13 14 15 16 17
+  *partial_lo = v_pair_add[0];
+  // 00 00 00 00 00 00 00 00
+  // 00 00 00 00 00 00 00 00
+  *partial_hi = vdupq_n_u16(0);
+
+  // 00 20 21 22 23 24 25 26
+  // 00 30 31 32 33 34 35 36
+  *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[1], 7));
+  // 27 00 00 00 00 00 00 00
+  // 37 00 00 00 00 00 00 00
+  *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[1], v_zero, 7));
+
+  // 00 00 40 41 42 43 44 45
+  // 00 00 50 51 52 53 54 55
+  *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[2], 6));
+  // 46 47 00 00 00 00 00 00
+  // 56 57 00 00 00 00 00 00
+  *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[2], v_zero, 6));
+
+  // 00 00 00 60 61 62 63 64
+  // 00 00 00 70 71 72 73 74
+  *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[3], 5));
+  // 65 66 67 00 00 00 00 00
+  // 75 76 77 00 00 00 00 00
+  *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[3], v_zero, 5));
+}
+
+LIBGAV1_ALWAYS_INLINE void AddPartial(const void* const source,
+                                      ptrdiff_t stride, uint16x8_t* partial_lo,
+                                      uint16x8_t* partial_hi) {
+  const auto* src = static_cast<const uint8_t*>(source);
+
+  // 8x8 input
+  // 00 01 02 03 04 05 06 07
+  // 10 11 12 13 14 15 16 17
+  // 20 21 22 23 24 25 26 27
+  // 30 31 32 33 34 35 36 37
+  // 40 41 42 43 44 45 46 47
+  // 50 51 52 53 54 55 56 57
+  // 60 61 62 63 64 65 66 67
+  // 70 71 72 73 74 75 76 77
+  uint8x8_t v_src[8];
+  for (int i = 0; i < 8; ++i) {
+    v_src[i] = vld1_u8(src);
+    src += stride;
+  }
+
+  // partial for direction 2
+  // --------------------------------------------------------------------------
+  // partial[2][i] += x;
+  // 00 10 20 30 40 50 60 70  00 00 00 00 00 00 00 00
+  // 01 11 21 33 41 51 61 71  00 00 00 00 00 00 00 00
+  // 02 12 22 33 42 52 62 72  00 00 00 00 00 00 00 00
+  // 03 13 23 33 43 53 63 73  00 00 00 00 00 00 00 00
+  // 04 14 24 34 44 54 64 74  00 00 00 00 00 00 00 00
+  // 05 15 25 35 45 55 65 75  00 00 00 00 00 00 00 00
+  // 06 16 26 36 46 56 66 76  00 00 00 00 00 00 00 00
+  // 07 17 27 37 47 57 67 77  00 00 00 00 00 00 00 00
+  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[0]), partial_lo[2], 0);
+  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[1]), partial_lo[2], 1);
+  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[2]), partial_lo[2], 2);
+  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[3]), partial_lo[2], 3);
+  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[4]), partial_lo[2], 4);
+  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[5]), partial_lo[2], 5);
+  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[6]), partial_lo[2], 6);
+  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[7]), partial_lo[2], 7);
+
+  // partial for direction 6
+  // --------------------------------------------------------------------------
+  // partial[6][j] += x;
+  // 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00 00
+  // 10 11 12 13 14 15 16 17  00 00 00 00 00 00 00 00
+  // 20 21 22 23 24 25 26 27  00 00 00 00 00 00 00 00
+  // 30 31 32 33 34 35 36 37  00 00 00 00 00 00 00 00
+  // 40 41 42 43 44 45 46 47  00 00 00 00 00 00 00 00
+  // 50 51 52 53 54 55 56 57  00 00 00 00 00 00 00 00
+  // 60 61 62 63 64 65 66 67  00 00 00 00 00 00 00 00
+  // 70 71 72 73 74 75 76 77  00 00 00 00 00 00 00 00
+  const uint8x8_t v_zero = vdup_n_u8(0);
+  partial_lo[6] = vaddl_u8(v_zero, v_src[0]);
+  for (int i = 1; i < 8; ++i) {
+    partial_lo[6] = vaddw_u8(partial_lo[6], v_src[i]);
+  }
+
+  // partial for direction 0
+  AddPartial_D0_D4(v_src, &partial_lo[0], &partial_hi[0]);
+
+  // partial for direction 1
+  AddPartial_D1_D3(v_src, &partial_lo[1], &partial_hi[1]);
+
+  // partial for direction 7
+  AddPartial_D5_D7(v_src, &partial_lo[7], &partial_hi[7]);
+
+  uint8x8_t v_src_reverse[8];
+  for (int i = 0; i < 8; ++i) {
+    v_src_reverse[i] = vrev64_u8(v_src[i]);
+  }
+
+  // partial for direction 4
+  AddPartial_D0_D4(v_src_reverse, &partial_lo[4], &partial_hi[4]);
+
+  // partial for direction 3
+  AddPartial_D1_D3(v_src_reverse, &partial_lo[3], &partial_hi[3]);
+
+  // partial for direction 5
+  AddPartial_D5_D7(v_src_reverse, &partial_lo[5], &partial_hi[5]);
+}
+
+uint32x4_t Square(uint16x4_t a) { return vmull_u16(a, a); }
+
+uint32x4_t SquareAccumulate(uint32x4_t a, uint16x4_t b) {
+  return vmlal_u16(a, b, b);
+}
+
+// |cost[0]| and |cost[4]| square the input and sum with the corresponding
+// element from the other end of the vector:
+// |kCdefDivisionTable[]| element:
+// cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) *
+//             kCdefDivisionTable[i + 1];
+// cost[0] += Square(partial[0][7]) * kCdefDivisionTable[8];
+// Because everything is being summed into a single value the distributive
+// property allows us to mirror the division table and accumulate once.
+uint32_t Cost0Or4(const uint16x8_t a, const uint16x8_t b,
+                  const uint32x4_t division_table[4]) {
+  uint32x4_t c = vmulq_u32(Square(vget_low_u16(a)), division_table[0]);
+  c = vmlaq_u32(c, Square(vget_high_u16(a)), division_table[1]);
+  c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[2]);
+  c = vmlaq_u32(c, Square(vget_high_u16(b)), division_table[3]);
+  return SumVector(c);
+}
+
+// |cost[2]| and |cost[6]| square the input and accumulate:
+// cost[2] += Square(partial[2][i])
+uint32_t SquareAccumulate(const uint16x8_t a) {
+  uint32x4_t c = Square(vget_low_u16(a));
+  c = SquareAccumulate(c, vget_high_u16(a));
+  c = vmulq_n_u32(c, kCdefDivisionTable[7]);
+  return SumVector(c);
+}
+
+uint32_t CostOdd(const uint16x8_t a, const uint16x8_t b, const uint32x4_t mask,
+                 const uint32x4_t division_table[2]) {
+  // Remove elements 0-2.
+  uint32x4_t c = vandq_u32(mask, Square(vget_low_u16(a)));
+  c = vaddq_u32(c, Square(vget_high_u16(a)));
+  c = vmulq_n_u32(c, kCdefDivisionTable[7]);
+
+  c = vmlaq_u32(c, Square(vget_low_u16(a)), division_table[0]);
+  c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[1]);
+  return SumVector(c);
+}
+
+void CdefDirection_NEON(const void* const source, ptrdiff_t stride,
+                        int* const direction, int* const variance) {
+  assert(direction != nullptr);
+  assert(variance != nullptr);
+  const auto* src = static_cast<const uint8_t*>(source);
+  uint32_t cost[8];
+  uint16x8_t partial_lo[8], partial_hi[8];
+
+  AddPartial(src, stride, partial_lo, partial_hi);
+
+  cost[2] = SquareAccumulate(partial_lo[2]);
+  cost[6] = SquareAccumulate(partial_lo[6]);
+
+  const uint32x4_t division_table[4] = {
+      vld1q_u32(kCdefDivisionTable), vld1q_u32(kCdefDivisionTable + 4),
+      vld1q_u32(kCdefDivisionTable + 8), vld1q_u32(kCdefDivisionTable + 12)};
+
+  cost[0] = Cost0Or4(partial_lo[0], partial_hi[0], division_table);
+  cost[4] = Cost0Or4(partial_lo[4], partial_hi[4], division_table);
+
+  const uint32x4_t division_table_odd[2] = {
+      vld1q_u32(kCdefDivisionTableOdd), vld1q_u32(kCdefDivisionTableOdd + 4)};
+
+  const uint32x4_t element_3_mask = {0, 0, 0, static_cast<uint32_t>(-1)};
+
+  cost[1] =
+      CostOdd(partial_lo[1], partial_hi[1], element_3_mask, division_table_odd);
+  cost[3] =
+      CostOdd(partial_lo[3], partial_hi[3], element_3_mask, division_table_odd);
+  cost[5] =
+      CostOdd(partial_lo[5], partial_hi[5], element_3_mask, division_table_odd);
+  cost[7] =
+      CostOdd(partial_lo[7], partial_hi[7], element_3_mask, division_table_odd);
+
+  uint32_t best_cost = 0;
+  *direction = 0;
+  for (int i = 0; i < 8; ++i) {
+    if (cost[i] > best_cost) {
+      best_cost = cost[i];
+      *direction = i;
+    }
+  }
+  *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10;
+}
+
+// -------------------------------------------------------------------------
+// CdefFilter
+
+// Load 4 vectors based on the given |direction|.
+void LoadDirection(const uint16_t* const src, const ptrdiff_t stride,
+                   uint16x8_t* output, const int direction) {
+  // Each |direction| describes a different set of source values. Expand this
+  // set by negating each set. For |direction| == 0 this gives a diagonal line
+  // from top right to bottom left. The first value is y, the second x. Negative
+  // y values move up.
+  //    a       b         c       d
+  // {-1, 1}, {1, -1}, {-2, 2}, {2, -2}
+  //         c
+  //       a
+  //     0
+  //   b
+  // d
+  const int y_0 = kCdefDirections[direction][0][0];
+  const int x_0 = kCdefDirections[direction][0][1];
+  const int y_1 = kCdefDirections[direction][1][0];
+  const int x_1 = kCdefDirections[direction][1][1];
+  output[0] = vld1q_u16(src + y_0 * stride + x_0);
+  output[1] = vld1q_u16(src - y_0 * stride - x_0);
+  output[2] = vld1q_u16(src + y_1 * stride + x_1);
+  output[3] = vld1q_u16(src - y_1 * stride - x_1);
+}
+
+// Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to
+// do 2 rows at a time.
+void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride,
+                    uint16x8_t* output, const int direction) {
+  const int y_0 = kCdefDirections[direction][0][0];
+  const int x_0 = kCdefDirections[direction][0][1];
+  const int y_1 = kCdefDirections[direction][1][0];
+  const int x_1 = kCdefDirections[direction][1][1];
+  output[0] = vcombine_u16(vld1_u16(src + y_0 * stride + x_0),
+                           vld1_u16(src + y_0 * stride + stride + x_0));
+  output[1] = vcombine_u16(vld1_u16(src - y_0 * stride - x_0),
+                           vld1_u16(src - y_0 * stride + stride - x_0));
+  output[2] = vcombine_u16(vld1_u16(src + y_1 * stride + x_1),
+                           vld1_u16(src + y_1 * stride + stride + x_1));
+  output[3] = vcombine_u16(vld1_u16(src - y_1 * stride - x_1),
+                           vld1_u16(src - y_1 * stride + stride - x_1));
+}
+
+int16x8_t Constrain(const uint16x8_t pixel, const uint16x8_t reference,
+                    const uint16x8_t threshold, const int16x8_t damping) {
+  // If reference > pixel, the difference will be negative, so covert to 0 or
+  // -1.
+  const uint16x8_t sign = vcgtq_u16(reference, pixel);
+  const uint16x8_t abs_diff = vabdq_u16(pixel, reference);
+  const uint16x8_t shifted_diff = vshlq_u16(abs_diff, damping);
+  // For bitdepth == 8, the threshold range is [0, 15] and the damping range is
+  // [3, 6]. If pixel == kCdefLargeValue(0x4000), shifted_diff will always be
+  // larger than threshold. Subtract using saturation will return 0 when pixel
+  // == kCdefLargeValue.
+  static_assert(kCdefLargeValue == 0x4000, "Invalid kCdefLargeValue");
+  const uint16x8_t thresh_minus_shifted_diff =
+      vqsubq_u16(threshold, shifted_diff);
+  const uint16x8_t clamp_abs_diff =
+      vminq_u16(thresh_minus_shifted_diff, abs_diff);
+  // Restore the sign.
+  return vreinterpretq_s16_u16(
+      vsubq_u16(veorq_u16(clamp_abs_diff, sign), sign));
+}
+
+template <int width, bool enable_primary = true, bool enable_secondary = true>
+void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride,
+                     const int height, const int primary_strength,
+                     const int secondary_strength, const int damping,
+                     const int direction, void* dest,
+                     const ptrdiff_t dst_stride) {
+  static_assert(width == 8 || width == 4, "");
+  static_assert(enable_primary || enable_secondary, "");
+  constexpr bool clipping_required = enable_primary && enable_secondary;
+  auto* dst = static_cast<uint8_t*>(dest);
+  const uint16x8_t cdef_large_value_mask =
+      vdupq_n_u16(static_cast<uint16_t>(~kCdefLargeValue));
+  const uint16x8_t primary_threshold = vdupq_n_u16(primary_strength);
+  const uint16x8_t secondary_threshold = vdupq_n_u16(secondary_strength);
+
+  int16x8_t primary_damping_shift, secondary_damping_shift;
+
+  // FloorLog2() requires input to be > 0.
+  // 8-bit damping range: Y: [3, 6], UV: [2, 5].
+  if (enable_primary) {
+    // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary
+    // for UV filtering.
+    primary_damping_shift =
+        vdupq_n_s16(-std::max(0, damping - FloorLog2(primary_strength)));
+  }
+  if (enable_secondary) {
+    // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is
+    // necessary.
+    assert(damping - FloorLog2(secondary_strength) >= 0);
+    secondary_damping_shift =
+        vdupq_n_s16(-(damping - FloorLog2(secondary_strength)));
+  }
+
+  const int primary_tap_0 = kCdefPrimaryTaps[primary_strength & 1][0];
+  const int primary_tap_1 = kCdefPrimaryTaps[primary_strength & 1][1];
+
+  int y = height;
+  do {
+    uint16x8_t pixel;
+    if (width == 8) {
+      pixel = vld1q_u16(src);
+    } else {
+      pixel = vcombine_u16(vld1_u16(src), vld1_u16(src + src_stride));
+    }
+
+    uint16x8_t min = pixel;
+    uint16x8_t max = pixel;
+    int16x8_t sum;
+
+    if (enable_primary) {
+      // Primary |direction|.
+      uint16x8_t primary_val[4];
+      if (width == 8) {
+        LoadDirection(src, src_stride, primary_val, direction);
+      } else {
+        LoadDirection4(src, src_stride, primary_val, direction);
+      }
+
+      if (clipping_required) {
+        min = vminq_u16(min, primary_val[0]);
+        min = vminq_u16(min, primary_val[1]);
+        min = vminq_u16(min, primary_val[2]);
+        min = vminq_u16(min, primary_val[3]);
+
+        // The source is 16 bits, however, we only really care about the lower
+        // 8 bits.  The upper 8 bits contain the "large" flag.  After the final
+        // primary max has been calculated, zero out the upper 8 bits.  Use this
+        // to find the "16 bit" max.
+        const uint8x16_t max_p01 =
+            vmaxq_u8(vreinterpretq_u8_u16(primary_val[0]),
+                     vreinterpretq_u8_u16(primary_val[1]));
+        const uint8x16_t max_p23 =
+            vmaxq_u8(vreinterpretq_u8_u16(primary_val[2]),
+                     vreinterpretq_u8_u16(primary_val[3]));
+        const uint16x8_t max_p =
+            vreinterpretq_u16_u8(vmaxq_u8(max_p01, max_p23));
+        max = vmaxq_u16(max, vandq_u16(max_p, cdef_large_value_mask));
+      }
+
+      sum = Constrain(primary_val[0], pixel, primary_threshold,
+                      primary_damping_shift);
+      sum = vmulq_n_s16(sum, primary_tap_0);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(primary_val[1], pixel, primary_threshold,
+                                  primary_damping_shift),
+                        primary_tap_0);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(primary_val[2], pixel, primary_threshold,
+                                  primary_damping_shift),
+                        primary_tap_1);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(primary_val[3], pixel, primary_threshold,
+                                  primary_damping_shift),
+                        primary_tap_1);
+    } else {
+      sum = vdupq_n_s16(0);
+    }
+
+    if (enable_secondary) {
+      // Secondary |direction| values (+/- 2). Clamp |direction|.
+      uint16x8_t secondary_val[8];
+      if (width == 8) {
+        LoadDirection(src, src_stride, secondary_val, direction + 2);
+        LoadDirection(src, src_stride, secondary_val + 4, direction - 2);
+      } else {
+        LoadDirection4(src, src_stride, secondary_val, direction + 2);
+        LoadDirection4(src, src_stride, secondary_val + 4, direction - 2);
+      }
+
+      if (clipping_required) {
+        min = vminq_u16(min, secondary_val[0]);
+        min = vminq_u16(min, secondary_val[1]);
+        min = vminq_u16(min, secondary_val[2]);
+        min = vminq_u16(min, secondary_val[3]);
+        min = vminq_u16(min, secondary_val[4]);
+        min = vminq_u16(min, secondary_val[5]);
+        min = vminq_u16(min, secondary_val[6]);
+        min = vminq_u16(min, secondary_val[7]);
+
+        const uint8x16_t max_s01 =
+            vmaxq_u8(vreinterpretq_u8_u16(secondary_val[0]),
+                     vreinterpretq_u8_u16(secondary_val[1]));
+        const uint8x16_t max_s23 =
+            vmaxq_u8(vreinterpretq_u8_u16(secondary_val[2]),
+                     vreinterpretq_u8_u16(secondary_val[3]));
+        const uint8x16_t max_s45 =
+            vmaxq_u8(vreinterpretq_u8_u16(secondary_val[4]),
+                     vreinterpretq_u8_u16(secondary_val[5]));
+        const uint8x16_t max_s67 =
+            vmaxq_u8(vreinterpretq_u8_u16(secondary_val[6]),
+                     vreinterpretq_u8_u16(secondary_val[7]));
+        const uint16x8_t max_s = vreinterpretq_u16_u8(
+            vmaxq_u8(vmaxq_u8(max_s01, max_s23), vmaxq_u8(max_s45, max_s67)));
+        max = vmaxq_u16(max, vandq_u16(max_s, cdef_large_value_mask));
+      }
+
+      sum = vmlaq_n_s16(sum,
+                        Constrain(secondary_val[0], pixel, secondary_threshold,
+                                  secondary_damping_shift),
+                        kCdefSecondaryTap0);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(secondary_val[1], pixel, secondary_threshold,
+                                  secondary_damping_shift),
+                        kCdefSecondaryTap0);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(secondary_val[2], pixel, secondary_threshold,
+                                  secondary_damping_shift),
+                        kCdefSecondaryTap1);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(secondary_val[3], pixel, secondary_threshold,
+                                  secondary_damping_shift),
+                        kCdefSecondaryTap1);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(secondary_val[4], pixel, secondary_threshold,
+                                  secondary_damping_shift),
+                        kCdefSecondaryTap0);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(secondary_val[5], pixel, secondary_threshold,
+                                  secondary_damping_shift),
+                        kCdefSecondaryTap0);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(secondary_val[6], pixel, secondary_threshold,
+                                  secondary_damping_shift),
+                        kCdefSecondaryTap1);
+      sum = vmlaq_n_s16(sum,
+                        Constrain(secondary_val[7], pixel, secondary_threshold,
+                                  secondary_damping_shift),
+                        kCdefSecondaryTap1);
+    }
+    // Clip3(pixel + ((8 + sum - (sum < 0)) >> 4), min, max))
+    const int16x8_t sum_lt_0 = vshrq_n_s16(sum, 15);
+    sum = vaddq_s16(sum, sum_lt_0);
+    int16x8_t result = vrsraq_n_s16(vreinterpretq_s16_u16(pixel), sum, 4);
+    if (clipping_required) {
+      result = vminq_s16(result, vreinterpretq_s16_u16(max));
+      result = vmaxq_s16(result, vreinterpretq_s16_u16(min));
+    }
+
+    const uint8x8_t dst_pixel = vqmovun_s16(result);
+    if (width == 8) {
+      src += src_stride;
+      vst1_u8(dst, dst_pixel);
+      dst += dst_stride;
+      --y;
+    } else {
+      src += src_stride << 1;
+      StoreLo4(dst, dst_pixel);
+      dst += dst_stride;
+      StoreHi4(dst, dst_pixel);
+      dst += dst_stride;
+      y -= 2;
+    }
+  } while (y != 0);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->cdef_direction = CdefDirection_NEON;
+  dsp->cdef_filters[0][0] = CdefFilter_NEON<4>;
+  dsp->cdef_filters[0][1] =
+      CdefFilter_NEON<4, /*enable_primary=*/true, /*enable_secondary=*/false>;
+  dsp->cdef_filters[0][2] = CdefFilter_NEON<4, /*enable_primary=*/false>;
+  dsp->cdef_filters[1][0] = CdefFilter_NEON<8>;
+  dsp->cdef_filters[1][1] =
+      CdefFilter_NEON<8, /*enable_primary=*/true, /*enable_secondary=*/false>;
+  dsp->cdef_filters[1][2] = CdefFilter_NEON<8, /*enable_primary=*/false>;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void CdefInit_NEON() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+#else  // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void CdefInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/cdef_neon.h b/libgav1/src/dsp/arm/cdef_neon.h
new file mode 100644
index 0000000..53d5f86
--- /dev/null
+++ b/libgav1/src/dsp/arm/cdef_neon.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::cdef_direction and Dsp::cdef_filters. This function is not
+// thread-safe.
+void CdefInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_CdefDirection LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_CdefFilters LIBGAV1_CPU_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_
diff --git a/libgav1/src/dsp/arm/common_neon.h b/libgav1/src/dsp/arm/common_neon.h
index e0667f9..e8367ab 100644
--- a/libgav1/src/dsp/arm/common_neon.h
+++ b/libgav1/src/dsp/arm/common_neon.h
@@ -17,7 +17,7 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_COMMON_NEON_H_
 
-#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -29,6 +29,8 @@
 #if 0
 #include <cstdio>
 
+#include "absl/strings/str_cat.h"
+
 constexpr bool kEnablePrintRegs = true;
 
 union DebugRegister {
@@ -82,6 +84,16 @@
   }
 }
 
+inline void PrintReg(const int32x4x2_t val, const std::string& name) {
+  DebugRegisterQ r;
+  vst1q_u32(r.u32, val.val[0]);
+  const std::string name0 = absl::StrCat(name, ".val[0]").c_str();
+  PrintVectQ(r, name0.c_str(), 32);
+  vst1q_u32(r.u32, val.val[1]);
+  const std::string name1 = absl::StrCat(name, ".val[1]").c_str();
+  PrintVectQ(r, name1.c_str(), 32);
+}
+
 inline void PrintReg(const uint32x4_t val, const char* name) {
   DebugRegisterQ r;
   vst1q_u32(r.u32, val);
@@ -180,49 +192,89 @@
 //------------------------------------------------------------------------------
 // Load functions.
 
-// Load 4 uint8_t values into the low half of a uint8x8_t register.
-inline uint8x8_t LoadLo4(const uint8_t* const buf, uint8x8_t val) {
-  uint32_t temp;
-  memcpy(&temp, buf, 4);
-  return vreinterpret_u8_u32(vld1_lane_u32(&temp, vreinterpret_u32_u8(val), 0));
+// Load 2 uint8_t values into lanes 0 and 1. Zeros the register before loading
+// the values. Use caution when using this in loops because it will re-zero the
+// register before loading on every iteration.
+inline uint8x8_t Load2(const void* const buf) {
+  const uint16x4_t zero = vdup_n_u16(0);
+  uint16_t temp;
+  memcpy(&temp, buf, 2);
+  return vreinterpret_u8_u16(vld1_lane_u16(&temp, zero, 0));
 }
 
-// Load 4 uint8_t values into the high half of a uint8x8_t register.
-inline uint8x8_t LoadHi4(const uint8_t* const buf, uint8x8_t val) {
+// Load 2 uint8_t values into |lane| * 2 and |lane| * 2 + 1.
+template <int lane>
+inline uint8x8_t Load2(const void* const buf, uint8x8_t val) {
+  uint16_t temp;
+  memcpy(&temp, buf, 2);
+  return vreinterpret_u8_u16(
+      vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane));
+}
+
+// Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the
+// register before loading the values. Use caution when using this in loops
+// because it will re-zero the register before loading on every iteration.
+inline uint8x8_t Load4(const void* const buf) {
+  const uint32x2_t zero = vdup_n_u32(0);
   uint32_t temp;
   memcpy(&temp, buf, 4);
-  return vreinterpret_u8_u32(vld1_lane_u32(&temp, vreinterpret_u32_u8(val), 1));
+  return vreinterpret_u8_u32(vld1_lane_u32(&temp, zero, 0));
+}
+
+// Load 4 uint8_t values into 4 lanes staring with |lane| * 4.
+template <int lane>
+inline uint8x8_t Load4(const void* const buf, uint8x8_t val) {
+  uint32_t temp;
+  memcpy(&temp, buf, 4);
+  return vreinterpret_u8_u32(
+      vld1_lane_u32(&temp, vreinterpret_u32_u8(val), lane));
 }
 
 //------------------------------------------------------------------------------
 // Store functions.
 
 // Propagate type information to the compiler. Without this the compiler may
-// assume the required alignment of uint32_t (4 bytes) and add alignment hints
-// to the memory access.
-inline void Uint32ToMem(uint8_t* const buf, uint32_t val) {
-  memcpy(buf, &val, 4);
+// assume the required alignment of the type (4 bytes in the case of uint32_t)
+// and add alignment hints to the memory access.
+template <typename T>
+inline void ValueToMem(void* const buf, T val) {
+  memcpy(buf, &val, sizeof(val));
 }
 
-inline void Uint32ToMem(uint16_t* const buf, uint32_t val) {
-  memcpy(buf, &val, 4);
+// Store 4 int8_t values from the low half of an int8x8_t register.
+inline void StoreLo4(void* const buf, const int8x8_t val) {
+  ValueToMem<int32_t>(buf, vget_lane_s32(vreinterpret_s32_s8(val), 0));
 }
 
 // Store 4 uint8_t values from the low half of a uint8x8_t register.
-inline void StoreLo4(uint8_t* const buf, const uint8x8_t val) {
-  Uint32ToMem(buf, vget_lane_u32(vreinterpret_u32_u8(val), 0));
+inline void StoreLo4(void* const buf, const uint8x8_t val) {
+  ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 0));
 }
 
 // Store 4 uint8_t values from the high half of a uint8x8_t register.
-inline void StoreHi4(uint8_t* const buf, const uint8x8_t val) {
-  Uint32ToMem(buf, vget_lane_u32(vreinterpret_u32_u8(val), 1));
+inline void StoreHi4(void* const buf, const uint8x8_t val) {
+  ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u8(val), 1));
+}
+
+// Store 2 uint8_t values from |lane| * 2 and |lane| * 2 + 1 of a uint8x8_t
+// register.
+template <int lane>
+inline void Store2(void* const buf, const uint8x8_t val) {
+  ValueToMem<uint16_t>(buf, vget_lane_u16(vreinterpret_u16_u8(val), lane));
 }
 
 // Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x8_t
 // register.
 template <int lane>
-inline void Store2(uint16_t* const buf, const uint16x8_t val) {
-  Uint32ToMem(buf, vgetq_lane_u32(vreinterpretq_u32_u16(val), lane));
+inline void Store2(void* const buf, const uint16x8_t val) {
+  ValueToMem<uint32_t>(buf, vgetq_lane_u32(vreinterpretq_u32_u16(val), lane));
+}
+
+// Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x4_t
+// register.
+template <int lane>
+inline void Store2(uint16_t* const buf, const uint16x4_t val) {
+  ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane));
 }
 
 //------------------------------------------------------------------------------
@@ -230,6 +282,11 @@
 
 // vshXX_n_XX() requires an immediate.
 template <int shift>
+inline uint8x8_t LeftShift(const uint8x8_t vector) {
+  return vreinterpret_u8_u64(vshl_n_u64(vreinterpret_u64_u8(vector), shift));
+}
+
+template <int shift>
 inline uint8x8_t RightShift(const uint8x8_t vector) {
   return vreinterpret_u8_u64(vshr_n_u64(vreinterpret_u64_u8(vector), shift));
 }
@@ -249,6 +306,16 @@
 #endif
 }
 
+// Shim vqtbl1_s8 for armv7.
+inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+  return vqtbl1_s8(a, index);
+#else
+  const int8x8x2_t b = {vget_low_s8(a), vget_high_s8(a)};
+  return vtbl2_s8(b, vreinterpret_s8_u8(index));
+#endif
+}
+
 //------------------------------------------------------------------------------
 // Interleave.
 
@@ -307,6 +374,30 @@
 }
 
 //------------------------------------------------------------------------------
+// Sum.
+
+inline uint16_t SumVector(const uint8x8_t a) {
+#if defined(__aarch64__)
+  return vaddlv_u8(a);
+#else
+  const uint16x4_t c = vpaddl_u8(a);
+  const uint32x2_t d = vpaddl_u16(c);
+  const uint64x1_t e = vpaddl_u32(d);
+  return static_cast<uint16_t>(vget_lane_u64(e, 0));
+#endif  // defined(__aarch64__)
+}
+
+inline uint32_t SumVector(const uint32x4_t a) {
+#if defined(__aarch64__)
+  return vaddvq_u32(a);
+#else
+  const uint64x2_t b = vpaddlq_u32(a);
+  const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b));
+  return static_cast<uint32_t>(vget_lane_u64(c, 0));
+#endif
+}
+
+//------------------------------------------------------------------------------
 // Transpose.
 
 // Transpose 32 bit elements such that:
@@ -497,76 +588,24 @@
 }
 
 // Input:
-// a0: 00 01 02 03 04 05 06 07
-// a1: 10 11 12 13 14 15 16 17
-// a2: 20 21 22 23 24 25 26 27
-// a3: 30 31 32 33 34 35 36 37
-// a4: 40 41 42 43 44 45 46 47
-// a5: 50 51 52 53 54 55 56 57
-// a6: 60 61 62 63 64 65 66 67
-// a7: 70 71 72 73 74 75 76 77
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// a[4]: 40 41 42 43 44 45 46 47
+// a[5]: 50 51 52 53 54 55 56 57
+// a[6]: 60 61 62 63 64 65 66 67
+// a[7]: 70 71 72 73 74 75 76 77
 
 // Output:
-// a0: 00 10 20 30 40 50 60 70
-// a1: 01 11 21 31 41 51 61 71
-// a2: 02 12 22 32 42 52 62 72
-// a3: 03 13 23 33 43 53 63 73
-// a4: 04 14 24 34 44 54 64 74
-// a5: 05 15 25 35 45 55 65 75
-// a6: 06 16 26 36 46 56 66 76
-// a7: 07 17 27 37 47 57 67 77
-inline void Transpose8x8(int16x8_t* a0, int16x8_t* a1, int16x8_t* a2,
-                         int16x8_t* a3, int16x8_t* a4, int16x8_t* a5,
-                         int16x8_t* a6, int16x8_t* a7) {
-  const int16x8x2_t b0 = vtrnq_s16(*a0, *a1);
-  const int16x8x2_t b1 = vtrnq_s16(*a2, *a3);
-  const int16x8x2_t b2 = vtrnq_s16(*a4, *a5);
-  const int16x8x2_t b3 = vtrnq_s16(*a6, *a7);
-
-  const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
-                                   vreinterpretq_s32_s16(b1.val[0]));
-  const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
-                                   vreinterpretq_s32_s16(b1.val[1]));
-  const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]),
-                                   vreinterpretq_s32_s16(b3.val[0]));
-  const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
-                                   vreinterpretq_s32_s16(b3.val[1]));
-
-  const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]);
-  const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]);
-  const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]);
-  const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]);
-
-  *a0 = d0.val[0];
-  *a1 = d1.val[0];
-  *a2 = d2.val[0];
-  *a3 = d3.val[0];
-  *a4 = d0.val[1];
-  *a5 = d1.val[1];
-  *a6 = d2.val[1];
-  *a7 = d3.val[1];
-}
-
-// Input:
-// a0: 00 01 02 03 04 05 06 07
-// a1: 10 11 12 13 14 15 16 17
-// a2: 20 21 22 23 24 25 26 27
-// a3: 30 31 32 33 34 35 36 37
-// a4: 40 41 42 43 44 45 46 47
-// a5: 50 51 52 53 54 55 56 57
-// a6: 60 61 62 63 64 65 66 67
-// a7: 70 71 72 73 74 75 76 77
-
-// Output:
-// a0: 00 10 20 30 40 50 60 70
-// a1: 01 11 21 31 41 51 61 71
-// a2: 02 12 22 32 42 52 62 72
-// a3: 03 13 23 33 43 53 63 73
-// a4: 04 14 24 34 44 54 64 74
-// a5: 05 15 25 35 45 55 65 75
-// a6: 06 16 26 36 46 56 66 76
-// a7: 07 17 27 37 47 57 67 77
-// TODO(johannkoenig): Switch users of the above transpose to this one.
+// a[0]: 00 10 20 30 40 50 60 70
+// a[1]: 01 11 21 31 41 51 61 71
+// a[2]: 02 12 22 32 42 52 62 72
+// a[3]: 03 13 23 33 43 53 63 73
+// a[4]: 04 14 24 34 44 54 64 74
+// a[5]: 05 15 25 35 45 55 65 75
+// a[6]: 06 16 26 36 46 56 66 76
+// a[7]: 07 17 27 37 47 57 67 77
 inline void Transpose8x8(int16x8_t a[8]) {
   const int16x8x2_t b0 = vtrnq_s16(a[0], a[1]);
   const int16x8x2_t b1 = vtrnq_s16(a[2], a[3]);
@@ -628,125 +667,8 @@
   a[7] = d3.val[1];
 }
 
-// Input:
-// i0: 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f
-// i1: 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f
-// i2: 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f
-// i3: 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f
-// i4: 40 41 42 43 44 45 46 47 48 49 4a 4b 4c 4d 4e 4f
-// i5: 50 51 52 53 54 55 56 57 58 59 5a 5b 5c 5d 5e 5f
-// i6: 60 61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f
-// i7: 70 71 72 73 74 75 76 77 78 79 7a 7b 7c 7d 7e 7f
-
-// Output:
-// o00: 00 10 20 30 40 50 60 70
-// o01: 01 11 21 31 41 51 61 71
-// o02: 02 12 22 32 42 52 62 72
-// o03: 03 13 23 33 43 53 63 73
-// o04: 04 14 24 34 44 54 64 74
-// o05: 05 15 25 35 45 55 65 75
-// o06: 06 16 26 36 46 56 66 76
-// o07: 07 17 27 37 47 57 67 77
-// o08: 08 18 28 38 48 58 68 78
-// o09: 09 19 29 39 49 59 69 79
-// o0a: 0a 1a 2a 3a 4a 5a 6a 7a
-// o0b: 0b 1b 2b 3b 4b 5b 6b 7b
-// o0c: 0c 1c 2c 3c 4c 5c 6c 7c
-// o0d: 0d 1d 2d 3d 4d 5d 6d 7d
-// o0e: 0e 1e 2e 3e 4e 5e 6e 7e
-// o0f: 0f 1f 2f 3f 4f 5f 6f 7f
-inline void Transpose16x8(const uint8x16_t i0, const uint8x16_t i1,
-                          const uint8x16_t i2, const uint8x16_t i3,
-                          const uint8x16_t i4, const uint8x16_t i5,
-                          const uint8x16_t i6, const uint8x16_t i7,
-                          uint8x8_t* o00, uint8x8_t* o01, uint8x8_t* o02,
-                          uint8x8_t* o03, uint8x8_t* o04, uint8x8_t* o05,
-                          uint8x8_t* o06, uint8x8_t* o07, uint8x8_t* o08,
-                          uint8x8_t* o09, uint8x8_t* o10, uint8x8_t* o11,
-                          uint8x8_t* o12, uint8x8_t* o13, uint8x8_t* o14,
-                          uint8x8_t* o15) {
-  const uint8x16x2_t b0 = vtrnq_u8(i0, i1);
-  const uint8x16x2_t b1 = vtrnq_u8(i2, i3);
-  const uint8x16x2_t b2 = vtrnq_u8(i4, i5);
-  const uint8x16x2_t b3 = vtrnq_u8(i6, i7);
-
-  const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
-                                    vreinterpretq_u16_u8(b1.val[0]));
-  const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
-                                    vreinterpretq_u16_u8(b1.val[1]));
-  const uint16x8x2_t c2 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[0]),
-                                    vreinterpretq_u16_u8(b3.val[0]));
-  const uint16x8x2_t c3 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[1]),
-                                    vreinterpretq_u16_u8(b3.val[1]));
-
-  const uint32x4x2_t d0 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[0]),
-                                    vreinterpretq_u32_u16(c2.val[0]));
-  const uint32x4x2_t d1 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[1]),
-                                    vreinterpretq_u32_u16(c2.val[1]));
-  const uint32x4x2_t d2 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[0]),
-                                    vreinterpretq_u32_u16(c3.val[0]));
-  const uint32x4x2_t d3 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[1]),
-                                    vreinterpretq_u32_u16(c3.val[1]));
-
-  *o00 = vget_low_u8(vreinterpretq_u8_u32(d0.val[0]));
-  *o01 = vget_low_u8(vreinterpretq_u8_u32(d2.val[0]));
-  *o02 = vget_low_u8(vreinterpretq_u8_u32(d1.val[0]));
-  *o03 = vget_low_u8(vreinterpretq_u8_u32(d3.val[0]));
-  *o04 = vget_low_u8(vreinterpretq_u8_u32(d0.val[1]));
-  *o05 = vget_low_u8(vreinterpretq_u8_u32(d2.val[1]));
-  *o06 = vget_low_u8(vreinterpretq_u8_u32(d1.val[1]));
-  *o07 = vget_low_u8(vreinterpretq_u8_u32(d3.val[1]));
-  *o08 = vget_high_u8(vreinterpretq_u8_u32(d0.val[0]));
-  *o09 = vget_high_u8(vreinterpretq_u8_u32(d2.val[0]));
-  *o10 = vget_high_u8(vreinterpretq_u8_u32(d1.val[0]));
-  *o11 = vget_high_u8(vreinterpretq_u8_u32(d3.val[0]));
-  *o12 = vget_high_u8(vreinterpretq_u8_u32(d0.val[1]));
-  *o13 = vget_high_u8(vreinterpretq_u8_u32(d2.val[1]));
-  *o14 = vget_high_u8(vreinterpretq_u8_u32(d1.val[1]));
-  *o15 = vget_high_u8(vreinterpretq_u8_u32(d3.val[1]));
-}
-
-// TODO(johannkoenig): Replace usage of the above transpose with this one.
-inline void Transpose16x8(const uint8x16_t input[8], uint8x8_t output[16]) {
-  const uint8x16x2_t b0 = vtrnq_u8(input[0], input[1]);
-  const uint8x16x2_t b1 = vtrnq_u8(input[2], input[3]);
-  const uint8x16x2_t b2 = vtrnq_u8(input[4], input[5]);
-  const uint8x16x2_t b3 = vtrnq_u8(input[6], input[7]);
-
-  const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]),
-                                    vreinterpretq_u16_u8(b1.val[0]));
-  const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]),
-                                    vreinterpretq_u16_u8(b1.val[1]));
-  const uint16x8x2_t c2 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[0]),
-                                    vreinterpretq_u16_u8(b3.val[0]));
-  const uint16x8x2_t c3 = vtrnq_u16(vreinterpretq_u16_u8(b2.val[1]),
-                                    vreinterpretq_u16_u8(b3.val[1]));
-
-  const uint32x4x2_t d0 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[0]),
-                                    vreinterpretq_u32_u16(c2.val[0]));
-  const uint32x4x2_t d1 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[1]),
-                                    vreinterpretq_u32_u16(c2.val[1]));
-  const uint32x4x2_t d2 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[0]),
-                                    vreinterpretq_u32_u16(c3.val[0]));
-  const uint32x4x2_t d3 = vtrnq_u32(vreinterpretq_u32_u16(c1.val[1]),
-                                    vreinterpretq_u32_u16(c3.val[1]));
-
-  output[0] = vget_low_u8(vreinterpretq_u8_u32(d0.val[0]));
-  output[1] = vget_low_u8(vreinterpretq_u8_u32(d2.val[0]));
-  output[2] = vget_low_u8(vreinterpretq_u8_u32(d1.val[0]));
-  output[3] = vget_low_u8(vreinterpretq_u8_u32(d3.val[0]));
-  output[4] = vget_low_u8(vreinterpretq_u8_u32(d0.val[1]));
-  output[5] = vget_low_u8(vreinterpretq_u8_u32(d2.val[1]));
-  output[6] = vget_low_u8(vreinterpretq_u8_u32(d1.val[1]));
-  output[7] = vget_low_u8(vreinterpretq_u8_u32(d3.val[1]));
-  output[8] = vget_high_u8(vreinterpretq_u8_u32(d0.val[0]));
-  output[9] = vget_high_u8(vreinterpretq_u8_u32(d2.val[0]));
-  output[10] = vget_high_u8(vreinterpretq_u8_u32(d1.val[0]));
-  output[11] = vget_high_u8(vreinterpretq_u8_u32(d3.val[0]));
-  output[12] = vget_high_u8(vreinterpretq_u8_u32(d0.val[1]));
-  output[13] = vget_high_u8(vreinterpretq_u8_u32(d2.val[1]));
-  output[14] = vget_high_u8(vreinterpretq_u8_u32(d1.val[1]));
-  output[15] = vget_high_u8(vreinterpretq_u8_u32(d3.val[1]));
+inline int16x8_t ZeroExtend(const uint8x8_t in) {
+  return vreinterpretq_s16_u16(vmovl_u8(in));
 }
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/arm/convolve_neon.cc b/libgav1/src/dsp/arm/convolve_neon.cc
index 5f7eef7..2c2557f 100644
--- a/libgav1/src/dsp/arm/convolve_neon.cc
+++ b/libgav1/src/dsp/arm/convolve_neon.cc
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "src/dsp/convolve.h"
-#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -25,325 +25,231 @@
 #include <cstdint>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 
 namespace libgav1 {
 namespace dsp {
 namespace low_bitdepth {
 namespace {
 
-constexpr int kBitdepth8 = 8;
 constexpr int kIntermediateStride = kMaxSuperBlockSizeInPixels;
-constexpr int kSubPixelMask = (1 << kSubPixelBits) - 1;
 constexpr int kHorizontalOffset = 3;
-constexpr int kVerticalOffset = 3;
-constexpr int kInterRoundBitsVertical = 11;
+constexpr int kFilterIndexShift = 6;
 
-int GetFilterIndex(const int filter_index, const int length) {
-  if (length <= 4) {
-    if (filter_index == kInterpolationFilterEightTap ||
-        filter_index == kInterpolationFilterEightTapSharp) {
-      return 4;
-    }
-    if (filter_index == kInterpolationFilterEightTapSmooth) {
-      return 5;
-    }
+// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and
+// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final
+// sum from outranging int16_t.
+template <int filter_index, bool negative_outside_taps = false>
+int16x8_t SumOnePassTaps(const uint8x8_t* const src,
+                         const uint8x8_t* const taps) {
+  uint16x8_t sum;
+  if (filter_index == 0) {
+    // 6 taps. + - + + - +
+    sum = vmull_u8(src[0], taps[0]);
+    // Unsigned overflow will result in a valid int16_t value.
+    sum = vmlsl_u8(sum, src[1], taps[1]);
+    sum = vmlal_u8(sum, src[2], taps[2]);
+    sum = vmlal_u8(sum, src[3], taps[3]);
+    sum = vmlsl_u8(sum, src[4], taps[4]);
+    sum = vmlal_u8(sum, src[5], taps[5]);
+  } else if (filter_index == 1 && negative_outside_taps) {
+    // 6 taps. - + + + + -
+    // Set a base we can subtract from.
+    sum = vmull_u8(src[1], taps[1]);
+    sum = vmlsl_u8(sum, src[0], taps[0]);
+    sum = vmlal_u8(sum, src[2], taps[2]);
+    sum = vmlal_u8(sum, src[3], taps[3]);
+    sum = vmlal_u8(sum, src[4], taps[4]);
+    sum = vmlsl_u8(sum, src[5], taps[5]);
+  } else if (filter_index == 1) {
+    // 6 taps. All are positive.
+    sum = vmull_u8(src[0], taps[0]);
+    sum = vmlal_u8(sum, src[1], taps[1]);
+    sum = vmlal_u8(sum, src[2], taps[2]);
+    sum = vmlal_u8(sum, src[3], taps[3]);
+    sum = vmlal_u8(sum, src[4], taps[4]);
+    sum = vmlal_u8(sum, src[5], taps[5]);
+  } else if (filter_index == 2) {
+    // 8 taps. - + - + + - + -
+    sum = vmull_u8(src[1], taps[1]);
+    sum = vmlsl_u8(sum, src[0], taps[0]);
+    sum = vmlsl_u8(sum, src[2], taps[2]);
+    sum = vmlal_u8(sum, src[3], taps[3]);
+    sum = vmlal_u8(sum, src[4], taps[4]);
+    sum = vmlsl_u8(sum, src[5], taps[5]);
+    sum = vmlal_u8(sum, src[6], taps[6]);
+    sum = vmlsl_u8(sum, src[7], taps[7]);
+  } else if (filter_index == 3) {
+    // 2 taps. All are positive.
+    sum = vmull_u8(src[0], taps[0]);
+    sum = vmlal_u8(sum, src[1], taps[1]);
+  } else if (filter_index == 4) {
+    // 4 taps. - + + -
+    sum = vmull_u8(src[1], taps[1]);
+    sum = vmlsl_u8(sum, src[0], taps[0]);
+    sum = vmlal_u8(sum, src[2], taps[2]);
+    sum = vmlsl_u8(sum, src[3], taps[3]);
+  } else if (filter_index == 5) {
+    // 4 taps. All are positive.
+    sum = vmull_u8(src[0], taps[0]);
+    sum = vmlal_u8(sum, src[1], taps[1]);
+    sum = vmlal_u8(sum, src[2], taps[2]);
+    sum = vmlal_u8(sum, src[3], taps[3]);
   }
-  return filter_index;
+  return vreinterpretq_s16_u16(sum);
 }
 
-inline int16x8_t ZeroExtend(const uint8x8_t in) {
-  return vreinterpretq_s16_u16(vmovl_u8(in));
-}
-
-inline void Load8x8(const uint8_t* s, const ptrdiff_t p, int16x8_t* dst) {
-  dst[0] = ZeroExtend(vld1_u8(s));
-  s += p;
-  dst[1] = ZeroExtend(vld1_u8(s));
-  s += p;
-  dst[2] = ZeroExtend(vld1_u8(s));
-  s += p;
-  dst[3] = ZeroExtend(vld1_u8(s));
-  s += p;
-  dst[4] = ZeroExtend(vld1_u8(s));
-  s += p;
-  dst[5] = ZeroExtend(vld1_u8(s));
-  s += p;
-  dst[6] = ZeroExtend(vld1_u8(s));
-  s += p;
-  dst[7] = ZeroExtend(vld1_u8(s));
-}
-
-// Multiply every entry in |src[]| by the corresponding lane in |taps| and sum.
-// The sum of the entries in |taps| is always 128. In some situations negative
-// values are used. This creates a situation where the positive taps sum to more
-// than 128. An example is:
-// {-4, 10, -24, 100, 60, -20, 8, -2}
-// The negative taps never sum to < -128
-// The center taps are always positive. The remaining positive taps never sum
-// to > 128.
-// Summing these naively can overflow int16_t. This can be avoided by adding the
-// center taps last and saturating the result.
-// We do not need to expand to int32_t because later in the function the value
-// is shifted by |kFilterBits| (7) and saturated to uint8_t. This means any
-// value over 255 << 7 (32576 because of rounding) is clamped.
-template <int num_taps>
-int16x8_t SumTaps(const int16x8_t* const src, const int16x8_t taps) {
+template <int filter_index, bool negative_outside_taps>
+int16x8_t SumHorizontalTaps(const uint8_t* const src,
+                            const uint8x8_t* const v_tap) {
+  uint8x8_t v_src[8];
+  const uint8x16_t src_long = vld1q_u8(src);
   int16x8_t sum;
-  if (num_taps == 8) {
-    const int16x4_t taps_lo = vget_low_s16(taps);
-    const int16x4_t taps_hi = vget_high_s16(taps);
-    sum = vmulq_lane_s16(src[0], taps_lo, 0);
-    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 1);
-    sum = vmlaq_lane_s16(sum, src[2], taps_lo, 2);
-    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 1);
-    sum = vmlaq_lane_s16(sum, src[6], taps_hi, 2);
-    sum = vmlaq_lane_s16(sum, src[7], taps_hi, 3);
 
-    // Center taps.
-    sum = vqaddq_s16(sum, vmulq_lane_s16(src[3], taps_lo, 3));
-    sum = vqaddq_s16(sum, vmulq_lane_s16(src[4], taps_hi, 0));
-  } else if (num_taps == 6) {
-    const int16x4_t taps_lo = vget_low_s16(taps);
-    const int16x4_t taps_hi = vget_high_s16(taps);
-    sum = vmulq_lane_s16(src[0], taps_lo, 1);
-    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 2);
-    sum = vmlaq_lane_s16(sum, src[4], taps_hi, 1);
-    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 2);
-
-    // Center taps.
-    sum = vqaddq_s16(sum, vmulq_lane_s16(src[2], taps_lo, 3));
-    sum = vqaddq_s16(sum, vmulq_lane_s16(src[3], taps_hi, 0));
-  } else if (num_taps == 4) {
-    const int16x4_t taps_lo = vget_low_s16(taps);
-    sum = vmulq_lane_s16(src[0], taps_lo, 0);
-    sum = vmlaq_lane_s16(sum, src[3], taps_lo, 3);
-
-    // Center taps.
-    sum = vqaddq_s16(sum, vmulq_lane_s16(src[1], taps_lo, 1));
-    sum = vqaddq_s16(sum, vmulq_lane_s16(src[2], taps_lo, 2));
-  } else {
-    assert(num_taps == 2);
-    // All the taps are positive so there is no concern regarding saturation.
-    const int16x4_t taps_lo = vget_low_s16(taps);
-    sum = vmulq_lane_s16(src[0], taps_lo, 1);
-    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 2);
+  if (filter_index < 2) {
+    v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 1));
+    v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 2));
+    v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 3));
+    v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 4));
+    v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 5));
+    v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 6));
+    sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 1);
+  } else if (filter_index == 2) {
+    v_src[0] = vget_low_u8(src_long);
+    v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
+    v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2));
+    v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3));
+    v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 4));
+    v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 5));
+    v_src[6] = vget_low_u8(vextq_u8(src_long, src_long, 6));
+    v_src[7] = vget_low_u8(vextq_u8(src_long, src_long, 7));
+    sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap);
+  } else if (filter_index == 3) {
+    v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 3));
+    v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 4));
+    sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 3);
+  } else if (filter_index > 3) {
+    v_src[0] = vget_low_u8(vextq_u8(src_long, src_long, 2));
+    v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 3));
+    v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 4));
+    v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 5));
+    sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src, v_tap + 2);
   }
-
   return sum;
 }
 
-// Add an offset to ensure the sum is positive and it fits within uint16_t.
-template <int num_taps>
-uint16x8_t SumTaps8To16(const int16x8_t* const src, const int16x8_t taps) {
-  // The worst case sum of negative taps is -56. The worst case sum of positive
-  // taps is 184. With the single pass versions of the Convolve we could safely
-  // saturate to int16_t because it outranged the final shift and narrow to
-  // uint8_t. For the 2D Convolve the intermediate values are 16 bits so we
-  // don't have that option.
-  // 184 * 255 = 46920 which is greater than int16_t can hold, but not uint16_t.
-  // The minimum value we need to handle is -56 * 255 = -14280.
-  // By offsetting the sum with 1 << 14 = 16384 we ensure that the sum is never
-  // negative and that 46920 + 16384 = 63304 fits comfortably in uint16_t. This
-  // allows us to use 16 bit registers instead of 32 bit registers.
-  // When considering the bit operations it is safe to ignore signedness. Due to
-  // the magic of 2's complement and well defined rollover rules the bit
-  // representations are equivalent.
-  const int16x4_t taps_lo = vget_low_s16(taps);
-  const int16x4_t taps_hi = vget_high_s16(taps);
-  // |offset| == 1 << (bitdepth + kFilterBits - 1);
-  int16x8_t sum = vdupq_n_s16(1 << 14);
-  if (num_taps == 8) {
-    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
-    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 1);
-    sum = vmlaq_lane_s16(sum, src[2], taps_lo, 2);
-    sum = vmlaq_lane_s16(sum, src[3], taps_lo, 3);
-    sum = vmlaq_lane_s16(sum, src[4], taps_hi, 0);
-    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 1);
-    sum = vmlaq_lane_s16(sum, src[6], taps_hi, 2);
-    sum = vmlaq_lane_s16(sum, src[7], taps_hi, 3);
-  } else if (num_taps == 6) {
-    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 1);
-    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 2);
-    sum = vmlaq_lane_s16(sum, src[2], taps_lo, 3);
-    sum = vmlaq_lane_s16(sum, src[3], taps_hi, 0);
-    sum = vmlaq_lane_s16(sum, src[4], taps_hi, 1);
-    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 2);
-  } else if (num_taps == 4) {
-    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 2);
-    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 3);
-    sum = vmlaq_lane_s16(sum, src[2], taps_hi, 0);
-    sum = vmlaq_lane_s16(sum, src[3], taps_hi, 1);
-  } else if (num_taps == 2) {
-    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 3);
-    sum = vmlaq_lane_s16(sum, src[1], taps_hi, 0);
-  }
+template <int filter_index, bool negative_outside_taps>
+uint8x8_t SimpleHorizontalTaps(const uint8_t* const src,
+                               const uint8x8_t* const v_tap) {
+  int16x8_t sum =
+      SumHorizontalTaps<filter_index, negative_outside_taps>(src, v_tap);
 
-  // This is guaranteed to be positive. Convert it for the final shift.
-  return vreinterpretq_u16_s16(sum);
+  // Normally the Horizontal pass does the downshift in two passes:
+  // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
+  // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
+  // requires adding the rounding offset from the skipped shift.
+  constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);
+
+  sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit));
+  return vqrshrun_n_s16(sum, kFilterBits - 1);
 }
 
-template <int num_taps, int filter_index, bool negative_outside_taps = true>
-uint16x8_t SumCompoundHorizontalTaps(const uint8_t* const src,
-                                     const uint8x8_t* const v_tap) {
-  // Start with an offset to guarantee the sum is non negative.
-  uint16x8_t v_sum = vdupq_n_u16(1 << 14);
-  uint8x16_t v_src[8];
-  v_src[0] = vld1q_u8(&src[0]);
-  if (num_taps == 8) {
-    v_src[1] = vextq_u8(v_src[0], v_src[0], 1);
-    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
-    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
-    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
-    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
-    v_src[6] = vextq_u8(v_src[0], v_src[0], 6);
-    v_src[7] = vextq_u8(v_src[0], v_src[0], 7);
+template <int filter_index, bool negative_outside_taps>
+uint16x8_t HorizontalTaps8To16(const uint8_t* const src,
+                               const uint8x8_t* const v_tap) {
+  const int16x8_t sum =
+      SumHorizontalTaps<filter_index, negative_outside_taps>(src, v_tap);
 
-    // tap signs : - + - + + - + -
-    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[0]), v_tap[0]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
-    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
-    v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[7]), v_tap[7]);
-  } else if (num_taps == 6) {
-    v_src[1] = vextq_u8(v_src[0], v_src[0], 1);
-    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
-    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
-    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
-    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
-    v_src[6] = vextq_u8(v_src[0], v_src[0], 6);
-    if (filter_index == 0) {
-      // tap signs : + - + + - +
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
-      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
-    } else {
-      if (negative_outside_taps) {
-        // tap signs : - + + + + -
-        v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-        v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
-      } else {
-        // tap signs : + + + + + +
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[1]), v_tap[1]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-        v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[6]), v_tap[6]);
-      }
-    }
-  } else if (num_taps == 4) {
-    v_src[2] = vextq_u8(v_src[0], v_src[0], 2);
-    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
-    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
-    v_src[5] = vextq_u8(v_src[0], v_src[0], 5);
-    if (filter_index == 4) {
-      // tap signs : - + + -
-      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-      v_sum = vmlsl_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-    } else {
-      // tap signs : + + + +
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[2]), v_tap[2]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-      v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[5]), v_tap[5]);
-    }
-  } else {
-    assert(num_taps == 2);
-    v_src[3] = vextq_u8(v_src[0], v_src[0], 3);
-    v_src[4] = vextq_u8(v_src[0], v_src[0], 4);
-    // tap signs : + +
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[3]), v_tap[3]);
-    v_sum = vmlal_u8(v_sum, vget_low_u8(v_src[4]), v_tap[4]);
-  }
-
-  return v_sum;
+  return vreinterpretq_u16_s16(
+      vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1));
 }
 
-template <int num_taps, int filter_index>
-uint16x8_t SumHorizontalTaps2xH(const uint8_t* src, const ptrdiff_t src_stride,
-                                const uint8x8_t* const v_tap) {
-  constexpr int positive_offset_bits = kBitdepth8 + kFilterBits - 1;
-  uint16x8_t sum = vdupq_n_u16(1 << positive_offset_bits);
-  uint8x8_t input0 = vld1_u8(src);
+template <int filter_index>
+int16x8_t SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride,
+                               const uint8x8_t* const v_tap) {
+  uint16x8_t sum;
+  const uint8x8_t input0 = vld1_u8(src);
   src += src_stride;
-  uint8x8_t input1 = vld1_u8(src);
+  const uint8x8_t input1 = vld1_u8(src);
   uint8x8x2_t input = vzip_u8(input0, input1);
 
-  if (num_taps == 2) {
+  if (filter_index == 3) {
     // tap signs : + +
-    sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
+    sum = vmull_u8(vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
     sum = vmlal_u8(sum, input.val[1], v_tap[4]);
   } else if (filter_index == 4) {
     // tap signs : - + + -
+    sum = vmull_u8(vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
     sum = vmlsl_u8(sum, RightShift<4 * 8>(input.val[0]), v_tap[2]);
-    sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
     sum = vmlal_u8(sum, input.val[1], v_tap[4]);
     sum = vmlsl_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]);
   } else {
     // tap signs : + + + +
-    sum = vmlal_u8(sum, RightShift<4 * 8>(input.val[0]), v_tap[2]);
+    sum = vmull_u8(RightShift<4 * 8>(input.val[0]), v_tap[2]);
     sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[3]);
     sum = vmlal_u8(sum, input.val[1], v_tap[4]);
     sum = vmlal_u8(sum, RightShift<2 * 8>(input.val[1]), v_tap[5]);
   }
 
-  return vrshrq_n_u16(sum, kInterRoundBitsHorizontal);
+  return vreinterpretq_s16_u16(sum);
 }
 
-// TODO(johannkoenig): Rename this function. It works for more than just
-// compound convolutions.
+template <int filter_index>
+uint8x8_t SimpleHorizontalTaps2x2(const uint8_t* src,
+                                  const ptrdiff_t src_stride,
+                                  const uint8x8_t* const v_tap) {
+  int16x8_t sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
+
+  // Normally the Horizontal pass does the downshift in two passes:
+  // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
+  // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
+  // requires adding the rounding offset from the skipped shift.
+  constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);
+
+  sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit));
+  return vqrshrun_n_s16(sum, kFilterBits - 1);
+}
+
+template <int filter_index>
+uint16x8_t HorizontalTaps8To16_2x2(const uint8_t* src,
+                                   const ptrdiff_t src_stride,
+                                   const uint8x8_t* const v_tap) {
+  const int16x8_t sum =
+      SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
+
+  return vreinterpretq_u16_s16(
+      vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1));
+}
+
 template <int num_taps, int step, int filter_index,
           bool negative_outside_taps = true, bool is_2d = false,
-          bool is_8bit = false>
-void ConvolveCompoundHorizontalBlock(const uint8_t* src,
-                                     const ptrdiff_t src_stride,
-                                     void* const dest,
-                                     const ptrdiff_t pred_stride,
-                                     const int width, const int height,
-                                     const uint8x8_t* const v_tap) {
-  const uint16x8_t v_compound_round_offset = vdupq_n_u16(1 << (kBitdepth8 + 4));
-  const int16x8_t v_inter_round_bits_0 =
-      vdupq_n_s16(-kInterRoundBitsHorizontal);
-
+          bool is_compound = false>
+void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride,
+                      void* const dest, const ptrdiff_t pred_stride,
+                      const int width, const int height,
+                      const uint8x8_t* const v_tap) {
   auto* dest8 = static_cast<uint8_t*>(dest);
   auto* dest16 = static_cast<uint16_t*>(dest);
 
-  if (width > 4) {
+  // 4 tap filters are never used when width > 4.
+  if (num_taps != 4 && width > 4) {
     int y = 0;
     do {
       int x = 0;
       do {
-        uint16x8_t v_sum =
-            SumCompoundHorizontalTaps<num_taps, filter_index,
-                                      negative_outside_taps>(&src[x], v_tap);
-        if (is_8bit) {
-          // Split shifts the way they are in C. They can be combined but that
-          // makes removing the 1 << 14 offset much more difficult.
-          v_sum = vrshrq_n_u16(v_sum, kInterRoundBitsHorizontal);
-          int16x8_t v_sum_signed = vreinterpretq_s16_u16(vsubq_u16(
-              v_sum, vdupq_n_u16(1 << (14 - kInterRoundBitsHorizontal))));
-          uint8x8_t result = vqrshrun_n_s16(
-              v_sum_signed, kFilterBits - kInterRoundBitsHorizontal);
-          vst1_u8(&dest8[x], result);
-        } else {
-          v_sum = vrshlq_u16(v_sum, v_inter_round_bits_0);
-          if (!is_2d) {
-            v_sum = vaddq_u16(v_sum, v_compound_round_offset);
-          }
+        if (is_2d || is_compound) {
+          const uint16x8_t v_sum =
+              HorizontalTaps8To16<filter_index, negative_outside_taps>(&src[x],
+                                                                       v_tap);
           vst1q_u16(&dest16[x], v_sum);
+        } else {
+          const uint8x8_t result =
+              SimpleHorizontalTaps<filter_index, negative_outside_taps>(&src[x],
+                                                                        v_tap);
+          vst1_u8(&dest8[x], result);
         }
         x += step;
       } while (x < width);
@@ -352,135 +258,142 @@
       dest16 += pred_stride;
     } while (++y < height);
     return;
-  } else if (width == 4) {
-    int y = 0;
-    do {
-      uint16x8_t v_sum =
-          SumCompoundHorizontalTaps<num_taps, filter_index,
-                                    negative_outside_taps>(&src[0], v_tap);
-      if (is_8bit) {
-        v_sum = vrshrq_n_u16(v_sum, kInterRoundBitsHorizontal);
-        int16x8_t v_sum_signed = vreinterpretq_s16_u16(vsubq_u16(
-            v_sum, vdupq_n_u16(1 << (14 - kInterRoundBitsHorizontal))));
-        uint8x8_t result = vqrshrun_n_s16(
-            v_sum_signed, kFilterBits - kInterRoundBitsHorizontal);
-        StoreLo4(&dest8[0], result);
-      } else {
-        v_sum = vrshlq_u16(v_sum, v_inter_round_bits_0);
-        if (!is_2d) {
-          v_sum = vaddq_u16(v_sum, v_compound_round_offset);
-        }
-        vst1_u16(&dest16[0], vget_low_u16(v_sum));
-      }
-      src += src_stride;
-      dest8 += pred_stride;
-      dest16 += pred_stride;
-    } while (++y < height);
-    return;
   }
 
   // Horizontal passes only needs to account for |num_taps| 2 and 4 when
-  // |width| == 2.
-  assert(width == 2);
+  // |width| <= 4.
+  assert(width <= 4);
   assert(num_taps <= 4);
-
-  constexpr int positive_offset_bits = kBitdepth8 + kFilterBits - 1;
-  // Leave off + 1 << (kBitdepth8 + 3).
-  constexpr int compound_round_offset = 1 << (kBitdepth8 + 4);
-
   if (num_taps <= 4) {
-    int y = 0;
-    do {
-      // TODO(johannkoenig): Re-order the values for storing.
-      uint16x8_t sum =
-          SumHorizontalTaps2xH<num_taps, filter_index>(src, src_stride, v_tap);
+    if (width == 4) {
+      int y = 0;
+      do {
+        if (is_2d || is_compound) {
+          const uint16x8_t v_sum =
+              HorizontalTaps8To16<filter_index, negative_outside_taps>(src,
+                                                                       v_tap);
+          vst1_u16(dest16, vget_low_u16(v_sum));
+        } else {
+          const uint8x8_t result =
+              SimpleHorizontalTaps<filter_index, negative_outside_taps>(src,
+                                                                        v_tap);
+          StoreLo4(&dest8[0], result);
+        }
+        src += src_stride;
+        dest8 += pred_stride;
+        dest16 += pred_stride;
+      } while (++y < height);
+      return;
+    }
 
+    if (!is_compound) {
+      int y = 0;
+      do {
+        if (is_2d) {
+          const uint16x8_t sum =
+              HorizontalTaps8To16_2x2<filter_index>(src, src_stride, v_tap);
+          dest16[0] = vgetq_lane_u16(sum, 0);
+          dest16[1] = vgetq_lane_u16(sum, 2);
+          dest16 += pred_stride;
+          dest16[0] = vgetq_lane_u16(sum, 1);
+          dest16[1] = vgetq_lane_u16(sum, 3);
+          dest16 += pred_stride;
+        } else {
+          const uint8x8_t sum =
+              SimpleHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
+
+          dest8[0] = vget_lane_u8(sum, 0);
+          dest8[1] = vget_lane_u8(sum, 2);
+          dest8 += pred_stride;
+
+          dest8[0] = vget_lane_u8(sum, 1);
+          dest8[1] = vget_lane_u8(sum, 3);
+          dest8 += pred_stride;
+        }
+
+        src += src_stride << 1;
+        y += 2;
+      } while (y < height - 1);
+
+      // The 2d filters have an odd |height| because the horizontal pass
+      // generates context for the vertical pass.
       if (is_2d) {
-        dest16[0] = vgetq_lane_u16(sum, 0);
-        dest16[1] = vgetq_lane_u16(sum, 2);
-        dest16 += pred_stride;
-        dest16[0] = vgetq_lane_u16(sum, 1);
-        dest16[1] = vgetq_lane_u16(sum, 3);
-        dest16 += pred_stride;
-      } else if (!is_8bit) {
-        // None of the test vectors hit this path but the unit tests do.
-        sum = vaddq_u16(sum, vdupq_n_u16(compound_round_offset));
-
-        dest16[0] = vgetq_lane_u16(sum, 0);
-        dest16[1] = vgetq_lane_u16(sum, 2);
-        dest16 += pred_stride;
-        dest16[0] = vgetq_lane_u16(sum, 1);
-        dest16[1] = vgetq_lane_u16(sum, 3);
-        dest16 += pred_stride;
-      } else {
-        // Split shifts the way they are in C. They can be combined but that
-        // makes removing the 1 << 14 offset much more difficult.
-        int16x8_t sum_signed = vreinterpretq_s16_u16(vsubq_u16(
-            sum, vdupq_n_u16(
-                     1 << (positive_offset_bits - kInterRoundBitsHorizontal))));
-        uint8x8_t result =
-            vqrshrun_n_s16(sum_signed, kFilterBits - kInterRoundBitsHorizontal);
-
-        // Could de-interleave and vst1_lane_u16().
-        dest8[0] = vget_lane_u8(result, 0);
-        dest8[1] = vget_lane_u8(result, 2);
-        dest8 += pred_stride;
-
-        dest8[0] = vget_lane_u8(result, 1);
-        dest8[1] = vget_lane_u8(result, 3);
-        dest8 += pred_stride;
+        assert(height % 2 == 1);
+        uint16x8_t sum;
+        const uint8x8_t input = vld1_u8(src);
+        if (filter_index == 3) {  // |num_taps| == 2
+          sum = vmull_u8(RightShift<3 * 8>(input), v_tap[3]);
+          sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
+        } else if (filter_index == 4) {
+          sum = vmull_u8(RightShift<3 * 8>(input), v_tap[3]);
+          sum = vmlsl_u8(sum, RightShift<2 * 8>(input), v_tap[2]);
+          sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
+          sum = vmlsl_u8(sum, RightShift<5 * 8>(input), v_tap[5]);
+        } else {
+          assert(filter_index == 5);
+          sum = vmull_u8(RightShift<2 * 8>(input), v_tap[2]);
+          sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]);
+          sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
+          sum = vmlal_u8(sum, RightShift<5 * 8>(input), v_tap[5]);
+        }
+        // |sum| contains an int16_t value.
+        sum = vreinterpretq_u16_s16(vrshrq_n_s16(
+            vreinterpretq_s16_u16(sum), kInterRoundBitsHorizontal - 1));
+        Store2<0>(dest16, sum);
       }
-
-      src += src_stride << 1;
-      y += 2;
-    } while (y < height - 1);
-
-    // The 2d filters have an odd |height| because the horizontal pass generates
-    // context for the vertical pass.
-    if (is_2d) {
-      assert(height % 2 == 1);
-      uint16x8_t sum = vdupq_n_u16(1 << positive_offset_bits);
-      uint8x8_t input = vld1_u8(src);
-      if (filter_index == 3) {  // |num_taps| == 2
-        sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]);
-        sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
-      } else if (filter_index == 4) {
-        sum = vmlsl_u8(sum, RightShift<2 * 8>(input), v_tap[2]);
-        sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]);
-        sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
-        sum = vmlsl_u8(sum, RightShift<5 * 8>(input), v_tap[5]);
-      } else {
-        assert(filter_index == 5);
-        sum = vmlal_u8(sum, RightShift<2 * 8>(input), v_tap[2]);
-        sum = vmlal_u8(sum, RightShift<3 * 8>(input), v_tap[3]);
-        sum = vmlal_u8(sum, RightShift<4 * 8>(input), v_tap[4]);
-        sum = vmlal_u8(sum, RightShift<5 * 8>(input), v_tap[5]);
-        sum = vrshrq_n_u16(sum, kInterRoundBitsHorizontal);
-      }
-      Store2<0>(dest16, sum);
     }
   }
 }
 
 // Process 16 bit inputs and output 32 bits.
-template <int num_taps>
-uint32x4x2_t Sum2DVerticalTaps(const int16x8_t* const src,
-                               const int16x8_t taps) {
-  // In order to get the rollover correct with the lengthening instruction we
-  // need to treat these as signed so that they sign extend properly.
+template <int num_taps, bool is_compound>
+inline int16x4_t Sum2DVerticalTaps4(const int16x4_t* const src,
+                                    const int16x8_t taps) {
   const int16x4_t taps_lo = vget_low_s16(taps);
   const int16x4_t taps_hi = vget_high_s16(taps);
-  // An offset to guarantee the sum is non negative. Captures 56 * -4590 =
-  // 257040 (worst case negative value from horizontal pass). It should be
-  // possible to use 1 << 18 (262144) instead of 1 << 19 but there probably
-  // isn't any benefit.
-  // |offset_bits| = bitdepth + 2 * kFilterBits - kInterRoundBitsHorizontal
-  // == 19.
-  int32x4_t sum_lo = vdupq_n_s32(1 << 19);
-  int32x4_t sum_hi = sum_lo;
+  int32x4_t sum;
   if (num_taps == 8) {
-    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[0]), taps_lo, 0);
-    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[0]), taps_lo, 0);
+    sum = vmull_lane_s16(src[0], taps_lo, 0);
+    sum = vmlal_lane_s16(sum, src[1], taps_lo, 1);
+    sum = vmlal_lane_s16(sum, src[2], taps_lo, 2);
+    sum = vmlal_lane_s16(sum, src[3], taps_lo, 3);
+    sum = vmlal_lane_s16(sum, src[4], taps_hi, 0);
+    sum = vmlal_lane_s16(sum, src[5], taps_hi, 1);
+    sum = vmlal_lane_s16(sum, src[6], taps_hi, 2);
+    sum = vmlal_lane_s16(sum, src[7], taps_hi, 3);
+  } else if (num_taps == 6) {
+    sum = vmull_lane_s16(src[0], taps_lo, 1);
+    sum = vmlal_lane_s16(sum, src[1], taps_lo, 2);
+    sum = vmlal_lane_s16(sum, src[2], taps_lo, 3);
+    sum = vmlal_lane_s16(sum, src[3], taps_hi, 0);
+    sum = vmlal_lane_s16(sum, src[4], taps_hi, 1);
+    sum = vmlal_lane_s16(sum, src[5], taps_hi, 2);
+  } else if (num_taps == 4) {
+    sum = vmull_lane_s16(src[0], taps_lo, 2);
+    sum = vmlal_lane_s16(sum, src[1], taps_lo, 3);
+    sum = vmlal_lane_s16(sum, src[2], taps_hi, 0);
+    sum = vmlal_lane_s16(sum, src[3], taps_hi, 1);
+  } else if (num_taps == 2) {
+    sum = vmull_lane_s16(src[0], taps_lo, 3);
+    sum = vmlal_lane_s16(sum, src[1], taps_hi, 0);
+  }
+
+  if (is_compound) {
+    return vqrshrn_n_s32(sum, kInterRoundBitsCompoundVertical - 1);
+  }
+
+  return vqrshrn_n_s32(sum, kInterRoundBitsVertical - 1);
+}
+
+template <int num_taps, bool is_compound>
+int16x8_t SimpleSum2DVerticalTaps(const int16x8_t* const src,
+                                  const int16x8_t taps) {
+  const int16x4_t taps_lo = vget_low_s16(taps);
+  const int16x4_t taps_hi = vget_high_s16(taps);
+  int32x4_t sum_lo, sum_hi;
+  if (num_taps == 8) {
+    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 0);
+    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 0);
     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 1);
     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 1);
     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 2);
@@ -497,8 +410,8 @@
     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[7]), taps_hi, 3);
     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[7]), taps_hi, 3);
   } else if (num_taps == 6) {
-    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[0]), taps_lo, 1);
-    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[0]), taps_lo, 1);
+    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 1);
+    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 1);
     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 2);
     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 2);
     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 3);
@@ -511,8 +424,8 @@
     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 2);
     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 2);
   } else if (num_taps == 4) {
-    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[0]), taps_lo, 2);
-    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[0]), taps_lo, 2);
+    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 2);
+    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 2);
     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 3);
     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 3);
 
@@ -521,384 +434,273 @@
     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 1);
     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 1);
   } else if (num_taps == 2) {
-    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[0]), taps_lo, 3);
-    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[0]), taps_lo, 3);
+    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 3);
+    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 3);
 
     sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_hi, 0);
     sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_hi, 0);
   }
 
-  // This is guaranteed to be positive. Convert it for the final shift.
-  const uint32x4x2_t return_val = {vreinterpretq_u32_s32(sum_lo),
-                                   vreinterpretq_u32_s32(sum_hi)};
-  return return_val;
-}
-
-// Process 16 bit inputs and output 32 bits.
-template <int num_taps>
-uint32x4_t Sum2DVerticalTaps(const int16x4_t* const src, const int16x8_t taps) {
-  // In order to get the rollover correct with the lengthening instruction we
-  // need to treat these as signed so that they sign extend properly.
-  const int16x4_t taps_lo = vget_low_s16(taps);
-  const int16x4_t taps_hi = vget_high_s16(taps);
-  // An offset to guarantee the sum is non negative. Captures 56 * -4590 =
-  // 257040 (worst case negative value from horizontal pass). It should be
-  // possible to use 1 << 18 (262144) instead of 1 << 19 but there probably
-  // isn't any benefit.
-  // |offset_bits| = bitdepth + 2 * kFilterBits - kInterRoundBitsHorizontal
-  // == 19.
-  int32x4_t sum = vdupq_n_s32(1 << 19);
-  if (num_taps == 8) {
-    sum = vmlal_lane_s16(sum, src[0], taps_lo, 0);
-    sum = vmlal_lane_s16(sum, src[1], taps_lo, 1);
-    sum = vmlal_lane_s16(sum, src[2], taps_lo, 2);
-    sum = vmlal_lane_s16(sum, src[3], taps_lo, 3);
-
-    sum = vmlal_lane_s16(sum, src[4], taps_hi, 0);
-    sum = vmlal_lane_s16(sum, src[5], taps_hi, 1);
-    sum = vmlal_lane_s16(sum, src[6], taps_hi, 2);
-    sum = vmlal_lane_s16(sum, src[7], taps_hi, 3);
-  } else if (num_taps == 6) {
-    sum = vmlal_lane_s16(sum, src[0], taps_lo, 1);
-    sum = vmlal_lane_s16(sum, src[1], taps_lo, 2);
-    sum = vmlal_lane_s16(sum, src[2], taps_lo, 3);
-
-    sum = vmlal_lane_s16(sum, src[3], taps_hi, 0);
-    sum = vmlal_lane_s16(sum, src[4], taps_hi, 1);
-    sum = vmlal_lane_s16(sum, src[5], taps_hi, 2);
-  } else if (num_taps == 4) {
-    sum = vmlal_lane_s16(sum, src[0], taps_lo, 2);
-    sum = vmlal_lane_s16(sum, src[1], taps_lo, 3);
-
-    sum = vmlal_lane_s16(sum, src[2], taps_hi, 0);
-    sum = vmlal_lane_s16(sum, src[3], taps_hi, 1);
-  } else if (num_taps == 2) {
-    sum = vmlal_lane_s16(sum, src[0], taps_lo, 3);
-
-    sum = vmlal_lane_s16(sum, src[1], taps_hi, 0);
+  if (is_compound) {
+    return vcombine_s16(
+        vqrshrn_n_s32(sum_lo, kInterRoundBitsCompoundVertical - 1),
+        vqrshrn_n_s32(sum_hi, kInterRoundBitsCompoundVertical - 1));
   }
 
-  // This is guaranteed to be positive. Convert it for the final shift.
-  return vreinterpretq_u32_s32(sum);
+  return vcombine_s16(vqrshrn_n_s32(sum_lo, kInterRoundBitsVertical - 1),
+                      vqrshrn_n_s32(sum_hi, kInterRoundBitsVertical - 1));
 }
 
 template <int num_taps, bool is_compound = false>
-void Filter2DVertical(const uint16_t* src, const ptrdiff_t src_stride,
-                      void* const dst, const ptrdiff_t dst_stride,
-                      const int width, const int height, const int16x8_t taps,
-                      const int inter_round_bits_vertical) {
+void Filter2DVertical(const uint16_t* src, void* const dst,
+                      const ptrdiff_t dst_stride, const int width,
+                      const int height, const int16x8_t taps) {
+  assert(width >= 8);
   constexpr int next_row = num_taps - 1;
-  const int32x4_t v_inter_round_bits_vertical =
-      vdupq_n_s32(-inter_round_bits_vertical);
+  // The Horizontal pass uses |width| as |stride| for the intermediate buffer.
+  const ptrdiff_t src_stride = width;
 
   auto* dst8 = static_cast<uint8_t*>(dst);
   auto* dst16 = static_cast<uint16_t*>(dst);
 
-  if (width > 4) {
-    int x = 0;
+  int x = 0;
+  do {
+    int16x8_t srcs[8];
+    const uint16_t* src_x = src + x;
+    srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+    src_x += src_stride;
+    if (num_taps >= 4) {
+      srcs[1] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+      src_x += src_stride;
+      srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+      src_x += src_stride;
+      if (num_taps >= 6) {
+        srcs[3] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+        src_x += src_stride;
+        srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+        src_x += src_stride;
+        if (num_taps == 8) {
+          srcs[5] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+          src_x += src_stride;
+          srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+          src_x += src_stride;
+        }
+      }
+    }
+
+    int y = 0;
     do {
-      int16x8_t srcs[8];
-      srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src + x));
+      srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src_x));
+      src_x += src_stride;
+
+      const int16x8_t sum =
+          SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
+      if (is_compound) {
+        vst1q_u16(dst16 + x + y * dst_stride, vreinterpretq_u16_s16(sum));
+      } else {
+        vst1_u8(dst8 + x + y * dst_stride, vqmovun_s16(sum));
+      }
+
+      srcs[0] = srcs[1];
       if (num_taps >= 4) {
-        srcs[1] = vreinterpretq_s16_u16(vld1q_u16(src + x + src_stride));
-        srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src + x + 2 * src_stride));
+        srcs[1] = srcs[2];
+        srcs[2] = srcs[3];
         if (num_taps >= 6) {
-          srcs[3] = vreinterpretq_s16_u16(vld1q_u16(src + x + 3 * src_stride));
-          srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src + x + 4 * src_stride));
+          srcs[3] = srcs[4];
+          srcs[4] = srcs[5];
           if (num_taps == 8) {
-            srcs[5] =
-                vreinterpretq_s16_u16(vld1q_u16(src + x + 5 * src_stride));
-            srcs[6] =
-                vreinterpretq_s16_u16(vld1q_u16(src + x + 6 * src_stride));
+            srcs[5] = srcs[6];
+            srcs[6] = srcs[7];
           }
         }
       }
+    } while (++y < height);
+    x += 8;
+  } while (x < width);
+}
 
-      int y = 0;
-      do {
-        srcs[next_row] = vreinterpretq_s16_u16(
-            vld1q_u16(src + x + (y + next_row) * src_stride));
+// Take advantage of |src_stride| == |width| to process two rows at a time.
+template <int num_taps, bool is_compound = false>
+void Filter2DVertical4xH(const uint16_t* src, void* const dst,
+                         const ptrdiff_t dst_stride, const int height,
+                         const int16x8_t taps) {
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
 
-        const uint32x4x2_t sums = Sum2DVerticalTaps<num_taps>(srcs, taps);
-        if (is_compound) {
-          const uint16x8_t results = vcombine_u16(
-              vmovn_u32(vqrshlq_u32(sums.val[0], v_inter_round_bits_vertical)),
-              vmovn_u32(vqrshlq_u32(sums.val[1], v_inter_round_bits_vertical)));
-          vst1q_u16(dst16 + x + y * dst_stride, results);
-        } else {
-          const uint16x8_t first_shift =
-              vcombine_u16(vqrshrn_n_u32(sums.val[0], kInterRoundBitsVertical),
-                           vqrshrn_n_u32(sums.val[1], kInterRoundBitsVertical));
-          // |single_round_offset| == (1 << bitdepth) + (1 << (bitdepth - 1)) ==
-          // 384
-          const uint8x8_t results =
-              vqmovn_u16(vqsubq_u16(first_shift, vdupq_n_u16(384)));
-
-          vst1_u8(dst8 + x + y * dst_stride, results);
-        }
-
-        srcs[0] = srcs[1];
-        if (num_taps >= 4) {
-          srcs[1] = srcs[2];
-          srcs[2] = srcs[3];
-          if (num_taps >= 6) {
-            srcs[3] = srcs[4];
-            srcs[4] = srcs[5];
-            if (num_taps == 8) {
-              srcs[5] = srcs[6];
-              srcs[6] = srcs[7];
-            }
-          }
-        }
-      } while (++y < height);
-      x += 8;
-    } while (x < width);
-    return;
-  }
-
-  assert(width == 4);
-  int16x4_t srcs[8];
-  srcs[0] = vreinterpret_s16_u16(vld1_u16(src));
-  src += src_stride;
+  int16x8_t srcs[9];
+  srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src));
+  src += 8;
   if (num_taps >= 4) {
-    srcs[1] = vreinterpret_s16_u16(vld1_u16(src));
-    src += src_stride;
-    srcs[2] = vreinterpret_s16_u16(vld1_u16(src));
-    src += src_stride;
+    srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src));
+    src += 8;
+    srcs[1] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[2]));
     if (num_taps >= 6) {
-      srcs[3] = vreinterpret_s16_u16(vld1_u16(src));
-      src += src_stride;
-      srcs[4] = vreinterpret_s16_u16(vld1_u16(src));
-      src += src_stride;
+      srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src));
+      src += 8;
+      srcs[3] = vcombine_s16(vget_high_s16(srcs[2]), vget_low_s16(srcs[4]));
       if (num_taps == 8) {
-        srcs[5] = vreinterpret_s16_u16(vld1_u16(src));
-        src += src_stride;
-        srcs[6] = vreinterpret_s16_u16(vld1_u16(src));
-        src += src_stride;
+        srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src));
+        src += 8;
+        srcs[5] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[6]));
       }
     }
   }
 
   int y = 0;
   do {
-    srcs[next_row] = vreinterpret_s16_u16(vld1_u16(src));
-    src += src_stride;
+    srcs[num_taps] = vreinterpretq_s16_u16(vld1q_u16(src));
+    src += 8;
+    srcs[num_taps - 1] = vcombine_s16(vget_high_s16(srcs[num_taps - 2]),
+                                      vget_low_s16(srcs[num_taps]));
 
-    const uint32x4_t sums = Sum2DVerticalTaps<num_taps>(srcs, taps);
+    const int16x8_t sum =
+        SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
     if (is_compound) {
-      const uint16x4_t results =
-          vmovn_u32(vqrshlq_u32(sums, v_inter_round_bits_vertical));
-      vst1_u16(dst16, results);
-      dst16 += dst_stride;
+      const uint16x8_t results = vreinterpretq_u16_s16(sum);
+      vst1q_u16(dst16, results);
+      dst16 += 4 << 1;
     } else {
-      const uint16x4_t first_shift =
-          vqrshrn_n_u32(sums, kInterRoundBitsVertical);
-      // |single_round_offset| == (1 << bitdepth) + (1 << (bitdepth - 1)) ==
-      // 384
-      const uint8x8_t results = vqmovn_u16(
-          vcombine_u16(vqsub_u16(first_shift, vdup_n_u16(384)), vdup_n_u16(0)));
+      const uint8x8_t results = vqmovun_s16(sum);
 
       StoreLo4(dst8, results);
       dst8 += dst_stride;
+      StoreHi4(dst8, results);
+      dst8 += dst_stride;
     }
 
-    srcs[0] = srcs[1];
+    srcs[0] = srcs[2];
     if (num_taps >= 4) {
-      srcs[1] = srcs[2];
-      srcs[2] = srcs[3];
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
       if (num_taps >= 6) {
-        srcs[3] = srcs[4];
-        srcs[4] = srcs[5];
+        srcs[3] = srcs[5];
+        srcs[4] = srcs[6];
         if (num_taps == 8) {
-          srcs[5] = srcs[6];
-          srcs[6] = srcs[7];
+          srcs[5] = srcs[7];
+          srcs[6] = srcs[8];
         }
       }
     }
-  } while (++y < height);
+    y += 2;
+  } while (y < height);
 }
 
-template <bool is_2d = false, bool is_8bit = false>
-void HorizontalPass(const uint8_t* const src, const ptrdiff_t src_stride,
-                    void* const dst, const ptrdiff_t dst_stride,
-                    const int width, const int height, const int subpixel,
-                    const int filter_index) {
+// Take advantage of |src_stride| == |width| to process four rows at a time.
+template <int num_taps>
+void Filter2DVertical2xH(const uint16_t* src, void* const dst,
+                         const ptrdiff_t dst_stride, const int height,
+                         const int16x8_t taps) {
+  constexpr int next_row = (num_taps < 6) ? 4 : 8;
+
+  auto* dst8 = static_cast<uint8_t*>(dst);
+
+  int16x8_t srcs[9];
+  srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src));
+  src += 8;
+  if (num_taps >= 6) {
+    srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src));
+    src += 8;
+    srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
+    if (num_taps == 8) {
+      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
+      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
+    }
+  }
+
+  int y = 0;
+  do {
+    srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src));
+    src += 8;
+    if (num_taps == 2) {
+      srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
+    } else if (num_taps == 4) {
+      srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
+      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
+      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
+    } else if (num_taps == 6) {
+      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
+      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
+      srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
+    } else if (num_taps == 8) {
+      srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
+      srcs[6] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[8]));
+      srcs[7] = vextq_s16(srcs[4], srcs[8], 6);
+    }
+
+    const int16x8_t sum =
+        SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps);
+    const uint8x8_t results = vqmovun_s16(sum);
+
+    Store2<0>(dst8, results);
+    dst8 += dst_stride;
+    Store2<1>(dst8, results);
+    // When |height| <= 4 the taps are restricted to 2 and 4 tap variants.
+    // Therefore we don't need to check this condition when |height| > 4.
+    if (num_taps <= 4 && height == 2) return;
+    dst8 += dst_stride;
+    Store2<2>(dst8, results);
+    dst8 += dst_stride;
+    Store2<3>(dst8, results);
+    dst8 += dst_stride;
+
+    srcs[0] = srcs[4];
+    if (num_taps == 6) {
+      srcs[1] = srcs[5];
+      srcs[4] = srcs[8];
+    } else if (num_taps == 8) {
+      srcs[1] = srcs[5];
+      srcs[2] = srcs[6];
+      srcs[3] = srcs[7];
+      srcs[4] = srcs[8];
+    }
+
+    y += 4;
+  } while (y < height);
+}
+
+template <bool is_2d = false, bool is_compound = false>
+LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
+    const uint8_t* const src, const ptrdiff_t src_stride, void* const dst,
+    const ptrdiff_t dst_stride, const int width, const int height,
+    const int subpixel, const int filter_index) {
   // Duplicate the absolute value for each tap.  Negative taps are corrected
   // by using the vmlsl_u8 instruction.  Positive taps use vmlal_u8.
   uint8x8_t v_tap[kSubPixelTaps];
   const int filter_id = (subpixel >> 6) & kSubPixelMask;
+  assert(filter_id != 0);
+
   for (int k = 0; k < kSubPixelTaps; ++k) {
-    v_tap[k] = vreinterpret_u8_s8(
-        vabs_s8(vdup_n_s8(kSubPixelFilters[filter_index][filter_id][k])));
+    v_tap[k] = vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][filter_id][k]);
   }
 
   if (filter_index == 2) {  // 8 tap.
-    ConvolveCompoundHorizontalBlock<8, 8, 2, true, is_2d, is_8bit>(
+    FilterHorizontal<8, 8, 2, true, is_2d, is_compound>(
         src, src_stride, dst, dst_stride, width, height, v_tap);
   } else if (filter_index == 1) {  // 6 tap.
     // Check if outside taps are positive.
     if ((filter_id == 1) | (filter_id == 15)) {
-      ConvolveCompoundHorizontalBlock<6, 8, 1, false, is_2d, is_8bit>(
+      FilterHorizontal<6, 8, 1, false, is_2d, is_compound>(
           src, src_stride, dst, dst_stride, width, height, v_tap);
     } else {
-      ConvolveCompoundHorizontalBlock<6, 8, 1, true, is_2d, is_8bit>(
+      FilterHorizontal<6, 8, 1, true, is_2d, is_compound>(
           src, src_stride, dst, dst_stride, width, height, v_tap);
     }
   } else if (filter_index == 0) {  // 6 tap.
-    ConvolveCompoundHorizontalBlock<6, 8, 0, true, is_2d, is_8bit>(
+    FilterHorizontal<6, 8, 0, true, is_2d, is_compound>(
         src, src_stride, dst, dst_stride, width, height, v_tap);
   } else if (filter_index == 4) {  // 4 tap.
-    ConvolveCompoundHorizontalBlock<4, 8, 4, true, is_2d, is_8bit>(
+    FilterHorizontal<4, 8, 4, true, is_2d, is_compound>(
         src, src_stride, dst, dst_stride, width, height, v_tap);
   } else if (filter_index == 5) {  // 4 tap.
-    ConvolveCompoundHorizontalBlock<4, 8, 5, true, is_2d, is_8bit>(
+    FilterHorizontal<4, 8, 5, true, is_2d, is_compound>(
         src, src_stride, dst, dst_stride, width, height, v_tap);
   } else {  // 2 tap.
-    ConvolveCompoundHorizontalBlock<2, 8, 3, true, is_2d, is_8bit>(
+    FilterHorizontal<2, 8, 3, true, is_2d, is_compound>(
         src, src_stride, dst, dst_stride, width, height, v_tap);
   }
 }
 
-// There are three forms of this function:
-// 2D: input 8bit, output 16bit. |is_compound| has no effect.
-// 1D Horizontal: input 8bit, output 8bit.
-// 1D Compound Horizontal: input 8bit, output 16bit. Different rounding from 2D.
-// |width| is guaranteed to be 2 because all other cases are handled in neon.
-template <bool is_2d = true, bool is_compound = false>
-void HorizontalPass2xH(const uint8_t* src, const ptrdiff_t src_stride,
-                       void* const dst, const ptrdiff_t dst_stride,
-                       const int height, const int filter_index, const int taps,
-                       const int subpixel) {
-  // Even though |is_compound| has no effect when |is_2d| is true we block this
-  // combination in case the compiler gets confused.
-  static_assert(!is_2d || !is_compound, "|is_compound| is ignored.");
-  // Since this only handles |width| == 2, we only need to be concerned with
-  // 2 or 4 tap filters.
-  assert(taps == 2 || taps == 4);
-  auto* dst8 = static_cast<uint8_t*>(dst);
-  auto* dst16 = static_cast<uint16_t*>(dst);
-
-  const int compound_round_offset =
-      (1 << (kBitdepth8 + 4)) + (1 << (kBitdepth8 + 3));
-
-  const int filter_id = (subpixel >> 6) & kSubPixelMask;
-  const int taps_start = (kSubPixelTaps - taps) / 2;
-  int y = 0;
-  do {
-    int x = 0;
-    do {
-      int sum;
-      if (is_2d) {
-        // An offset to guarantee the sum is non negative.
-        sum = 1 << (kBitdepth8 + kFilterBits - 1);
-      } else if (is_compound) {
-        sum = 0;
-      } else {
-        // 1D non-Compound. The C uses a two stage shift with rounding. Here the
-        // shifts are combined and the rounding bit from the first stage is
-        // added in.
-        // (sum + 4 >> 3) + 8) >> 4 == (sum + 64 + 4) >> 7
-        sum = 4;
-      }
-      for (int k = 0; k < taps; ++k) {
-        const int tap = k + taps_start;
-        sum += kSubPixelFilters[filter_index][filter_id][tap] * src[x + k];
-      }
-      if (is_2d) {
-        dst16[x] = static_cast<int16_t>(
-            RightShiftWithRounding(sum, kInterRoundBitsHorizontal));
-      } else if (is_compound) {
-        sum = RightShiftWithRounding(sum, kInterRoundBitsHorizontal);
-        dst16[x] = sum + compound_round_offset;
-      } else {
-        // 1D non-Compound.
-        dst8[x] = static_cast<uint8_t>(
-            Clip3(RightShiftWithRounding(sum, kFilterBits), 0, 255));
-      }
-    } while (++x < 2);
-
-    src += src_stride;
-    dst8 += dst_stride;
-    dst16 += dst_stride;
-  } while (++y < height);
-}
-
-// This will always need to handle all |filter_index| values. Even with |width|
-// restricted to 2 the value of |height| can go up to at least 16.
-template <bool is_2d = true, bool is_compound = false>
-void VerticalPass2xH(const void* const src, const ptrdiff_t src_stride,
-                     void* const dst, const ptrdiff_t dst_stride,
-                     const int height, const int inter_round_bits_vertical,
-                     const int filter_index, const int taps,
-                     const int subpixel) {
-  const auto* src8 = static_cast<const uint8_t*>(src);
-  const auto* src16 = static_cast<const uint16_t*>(src);
-  auto* dst8 = static_cast<uint8_t*>(dst);
-  auto* dst16 = static_cast<uint16_t*>(dst);
-  const int filter_id = (subpixel >> 6) & kSubPixelMask;
-  const int taps_start = (kSubPixelTaps - taps) / 2;
-  constexpr int max_pixel_value = (1 << kBitdepth8) - 1;
-
-  int y = 0;
-  do {
-    int x = 0;
-    do {
-      int sum;
-      if (is_2d) {
-        sum = 1 << (kBitdepth8 + 2 * kFilterBits - kInterRoundBitsHorizontal);
-      } else if (is_compound) {
-        // TODO(johannkoenig): Keeping the sum positive is valuable for neon but
-        // may not actually help the C implementation. Investigate removing
-        // this.
-        // Use this offset to cancel out 1 << (kBitdepth8 + 3) >> 3 from
-        // |compound_round_offset|.
-        sum = (1 << (kBitdepth8 + 3)) << 3;
-      } else {
-        sum = 0;
-      }
-
-      for (int k = 0; k < taps; ++k) {
-        const int tap = k + taps_start;
-        if (is_2d) {
-          sum += kSubPixelFilters[filter_index][filter_id][tap] *
-                 src16[x + k * src_stride];
-        } else {
-          sum += kSubPixelFilters[filter_index][filter_id][tap] *
-                 src8[x + k * src_stride];
-        }
-      }
-
-      if (is_2d) {
-        if (is_compound) {
-          dst16[x] = static_cast<uint16_t>(
-              RightShiftWithRounding(sum, inter_round_bits_vertical));
-        } else {
-          constexpr int single_round_offset =
-              (1 << kBitdepth8) + (1 << (kBitdepth8 - 1));
-          dst8[x] = static_cast<uint8_t>(
-              Clip3(RightShiftWithRounding(sum, kInterRoundBitsVertical) -
-                        single_round_offset,
-                    0, max_pixel_value));
-        }
-      } else if (is_compound) {
-        // Leave off + 1 << (kBitdepth8 + 3).
-        constexpr int compound_round_offset = 1 << (kBitdepth8 + 4);
-        dst16[x] = RightShiftWithRounding(sum, 3) + compound_round_offset;
-      } else {
-        // 1D non-compound.
-        dst8[x] = static_cast<uint8_t>(Clip3(
-            RightShiftWithRounding(sum, kFilterBits), 0, max_pixel_value));
-      }
-    } while (++x < 2);
-
-    src8 += src_stride;
-    src16 += src_stride;
-    dst8 += dst_stride;
-    dst16 += dst_stride;
-  } while (++y < height);
-}
-
-int NumTapsInFilter(const int filter_index) {
+int GetNumTapsInFilter(const int filter_index) {
   if (filter_index < 2) {
     // Despite the names these only use 6 taps.
     // kInterpolationFilterEightTap
@@ -930,255 +732,135 @@
 void Convolve2D_NEON(const void* const reference,
                      const ptrdiff_t reference_stride,
                      const int horizontal_filter_index,
-                     const int vertical_filter_index,
-                     const int /*inter_round_bits_vertical*/,
-                     const int subpixel_x, const int subpixel_y,
-                     const int /*step_x*/, const int /*step_y*/,
-                     const int width, const int height, void* prediction,
-                     const ptrdiff_t pred_stride) {
+                     const int vertical_filter_index, const int subpixel_x,
+                     const int subpixel_y, const int width, const int height,
+                     void* prediction, const ptrdiff_t pred_stride) {
   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
-  const int horizontal_taps = NumTapsInFilter(horiz_filter_index);
-  const int vertical_taps = NumTapsInFilter(vert_filter_index);
+  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
 
   // The output of the horizontal filter is guaranteed to fit in 16 bits.
   uint16_t
       intermediate_result[kMaxSuperBlockSizeInPixels *
                           (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
-  const int intermediate_stride = width;
   const int intermediate_height = height + vertical_taps - 1;
 
-  if (width >= 4) {
-    const ptrdiff_t src_stride = reference_stride;
-    const auto* src = static_cast<const uint8_t*>(reference) -
-                      (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* src = static_cast<const uint8_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
 
-    HorizontalPass<true>(src, src_stride, intermediate_result,
-                         intermediate_stride, width, intermediate_height,
-                         subpixel_x, horiz_filter_index);
+  DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, width,
+                                   width, intermediate_height, subpixel_x,
+                                   horiz_filter_index);
 
-    // Vertical filter.
-    auto* dest = static_cast<uint8_t*>(prediction);
-    const ptrdiff_t dest_stride = pred_stride;
-    const int filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
-    const int16x8_t taps =
-        vld1q_s16(kSubPixelFilters[vert_filter_index][filter_id]);
+  // Vertical filter.
+  auto* dest = static_cast<uint8_t*>(prediction);
+  const ptrdiff_t dest_stride = pred_stride;
+  const int filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
+  assert(filter_id != 0);
 
-    if (vertical_taps == 8) {
-      Filter2DVertical<8>(intermediate_result, intermediate_stride, dest,
-                          dest_stride, width, height, taps, 0);
-    } else if (vertical_taps == 6) {
-      Filter2DVertical<6>(intermediate_result, intermediate_stride, dest,
-                          dest_stride, width, height, taps, 0);
-    } else if (vertical_taps == 4) {
-      Filter2DVertical<4>(intermediate_result, intermediate_stride, dest,
-                          dest_stride, width, height, taps, 0);
-    } else {  // |vertical_taps| == 2
-      Filter2DVertical<2>(intermediate_result, intermediate_stride, dest,
-                          dest_stride, width, height, taps, 0);
+  const int16x8_t taps =
+      vmovl_s8(vld1_s8(kHalfSubPixelFilters[vert_filter_index][filter_id]));
+
+  if (vertical_taps == 8) {
+    if (width == 2) {
+      Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else if (width == 4) {
+      Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else {
+      Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, height,
+                          taps);
     }
-  } else {
-    assert(width == 2);
-    // Horizontal filter.
-    const auto* const src = static_cast<const uint8_t*>(reference) -
-                            ((vertical_taps / 2) - 1) * reference_stride -
-                            ((horizontal_taps / 2) - 1);
-
-    HorizontalPass2xH(src, reference_stride, intermediate_result,
-                      intermediate_stride, intermediate_height,
-                      horiz_filter_index, horizontal_taps, subpixel_x);
-
-    // Vertical filter.
-    auto* dest = static_cast<uint8_t*>(prediction);
-    const ptrdiff_t dest_stride = pred_stride;
-
-    VerticalPass2xH(intermediate_result, intermediate_stride, dest, dest_stride,
-                    height, 0, vert_filter_index, vertical_taps, subpixel_y);
+  } else if (vertical_taps == 6) {
+    if (width == 2) {
+      Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else if (width == 4) {
+      Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else {
+      Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, height,
+                          taps);
+    }
+  } else if (vertical_taps == 4) {
+    if (width == 2) {
+      Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else if (width == 4) {
+      Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else {
+      Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, height,
+                          taps);
+    }
+  } else {  // |vertical_taps| == 2
+    if (width == 2) {
+      Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else if (width == 4) {
+      Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else {
+      Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, height,
+                          taps);
+    }
   }
 }
 
-template <int tap_lane0, int tap_lane1>
-inline int16x8_t CombineFilterTapsLong(const int16x8_t sum,
-                                       const int16x8_t src0, int16x8_t src1,
-                                       int16x4_t taps0, int16x4_t taps1) {
-  int32x4_t sum_lo = vmovl_s16(vget_low_s16(sum));
-  int32x4_t sum_hi = vmovl_s16(vget_high_s16(sum));
-  const int16x8_t product0 = vmulq_lane_s16(src0, taps0, tap_lane0);
-  const int16x8_t product1 = vmulq_lane_s16(src1, taps1, tap_lane1);
-  const int32x4_t center_vals_lo =
-      vaddl_s16(vget_low_s16(product0), vget_low_s16(product1));
-  const int32x4_t center_vals_hi =
-      vaddl_s16(vget_high_s16(product0), vget_high_s16(product1));
-
-  sum_lo = vaddq_s32(sum_lo, center_vals_lo);
-  sum_hi = vaddq_s32(sum_hi, center_vals_hi);
-  return vcombine_s16(vrshrn_n_s32(sum_lo, 3), vrshrn_n_s32(sum_hi, 3));
-}
-
-// TODO(b/133525024): Replace usage of this function with version that uses
-// unsigned trick, once cl/263050071 is submitted.
-template <int num_taps>
-inline int16x8_t SumTapsCompound(const int16x8_t* const src,
-                                 const int16x8_t taps) {
-  int16x8_t sum = vdupq_n_s16(1 << (kBitdepth8 + kFilterBits - 1));
-  if (num_taps == 8) {
-    const int16x4_t taps_lo = vget_low_s16(taps);
-    const int16x4_t taps_hi = vget_high_s16(taps);
-    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
-    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 1);
-    sum = vmlaq_lane_s16(sum, src[2], taps_lo, 2);
-    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 1);
-    sum = vmlaq_lane_s16(sum, src[6], taps_hi, 2);
-    sum = vmlaq_lane_s16(sum, src[7], taps_hi, 3);
-
-    // Center taps may sum to as much as 160, which pollutes the sign bit in
-    // int16 types.
-    sum = CombineFilterTapsLong<3, 0>(sum, src[3], src[4], taps_lo, taps_hi);
-  } else if (num_taps == 6) {
-    const int16x4_t taps_lo = vget_low_s16(taps);
-    const int16x4_t taps_hi = vget_high_s16(taps);
-    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
-    sum = vmlaq_lane_s16(sum, src[1], taps_lo, 1);
-    sum = vmlaq_lane_s16(sum, src[4], taps_hi, 0);
-    sum = vmlaq_lane_s16(sum, src[5], taps_hi, 1);
-
-    // Center taps in filter 0 may sum to as much as 148, which pollutes the
-    // sign bit in int16 types. This is not true of filter 1.
-    sum = CombineFilterTapsLong<2, 3>(sum, src[2], src[3], taps_lo, taps_lo);
-  } else if (num_taps == 4) {
-    const int16x4_t taps_lo = vget_low_s16(taps);
-    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
-    sum = vmlaq_lane_s16(sum, src[3], taps_lo, 3);
-
-    // Center taps.
-    sum = vqaddq_s16(sum, vmulq_lane_s16(src[1], taps_lo, 1));
-    sum = vrshrq_n_s16(vqaddq_s16(sum, vmulq_lane_s16(src[2], taps_lo, 2)),
-                       kInterRoundBitsHorizontal);
-  } else {
-    assert(num_taps == 2);
-    // All the taps are positive so there is no concern regarding saturation.
-    const int16x4_t taps_lo = vget_low_s16(taps);
-    sum = vmlaq_lane_s16(sum, src[0], taps_lo, 0);
-    sum = vrshrq_n_s16(vmlaq_lane_s16(sum, src[1], taps_lo, 1),
-                       kInterRoundBitsHorizontal);
+// There are many opportunities for overreading in scaled convolve, because the
+// range of starting points for filter windows is anywhere from 0 to 16 for 8
+// destination pixels, and the window sizes range from 2 to 8. To accommodate
+// this range concisely, we use |grade_x| to mean the most steps in src that can
+// be traversed in a single |step_x| increment, i.e. 1 or 2. When grade_x is 2,
+// we are guaranteed to exceed 8 whole steps in src for every 8 |step_x|
+// increments. The first load covers the initial elements of src_x, while the
+// final load covers the taps.
+template <int grade_x>
+inline uint8x8x3_t LoadSrcVals(const uint8_t* src_x) {
+  uint8x8x3_t ret;
+  const uint8x16_t src_val = vld1q_u8(src_x);
+  ret.val[0] = vget_low_u8(src_val);
+  ret.val[1] = vget_high_u8(src_val);
+  if (grade_x > 1) {
+    ret.val[2] = vld1_u8(src_x + 16);
   }
-  return sum;
+  return ret;
 }
 
-// |grade_x| determines an upper limit on how many whole-pixel steps will be
-// realized with 8 |step_x| increments.
-template <int filter_index, int num_taps, int grade_x>
-inline void ConvolveHorizontalScaled_NEON(const uint8_t* src,
-                                          const ptrdiff_t src_stride,
-                                          const int width, const int subpixel_x,
-                                          const int step_x,
-                                          const int intermediate_height,
-                                          int16_t* dst) {
-  const int dst_stride = kMaxSuperBlockSizeInPixels;
-  const int kernel_offset = (8 - num_taps) / 2;
-  const int ref_x = subpixel_x >> kScaleSubPixelBits;
-  int y = intermediate_height;
-  do {  // y > 0
-    int p = subpixel_x;
-    int prev_p = p;
-    int x = 0;
-    int16x8_t s[(grade_x + 1) * 8];
-    const uint8_t* src_x =
-        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
-    // TODO(petersonab,b/139707209): Fix source buffer overreads.
-    // For example, when |height| == 2 and |num_taps| == 8 then
-    // |intermediate_height| == 9. On the second pass this will load and
-    // transpose 7 rows past where |src| may end.
-    Load8x8(src_x, src_stride, s);
-    Transpose8x8(s);
-    if (grade_x > 1) {
-      Load8x8(src_x + 8, src_stride, &s[8]);
-      Transpose8x8(&s[8]);
-    }
-
-    do {  // x < width
-      int16x8_t result[8];
-      src_x = &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
-      // process 8 src_x steps
-      Load8x8(src_x + 8, src_stride, &s[8]);
-      Transpose8x8(&s[8]);
-      if (grade_x > 1) {
-        Load8x8(src_x + 16, src_stride, &s[16]);
-        Transpose8x8(&s[16]);
-      }
-      // Remainder after whole index increments.
-      int pixel_offset = p & ((1 << kScaleSubPixelBits) - 1);
-      for (int z = 0; z < 8; ++z) {
-        const int16x8_t filter = vld1q_s16(
-            &kSubPixelFilters[filter_index][(p >> 6) & 0xF][kernel_offset]);
-        result[z] = SumTapsCompound<num_taps>(
-            &s[pixel_offset >> kScaleSubPixelBits], filter);
-        pixel_offset += step_x;
-        p += step_x;
-      }
-
-      // Transpose the 8x8 filtered values back to dst.
-      Transpose8x8(result);
-
-      vst1q_s16(&dst[x + 0 * dst_stride], result[0]);
-      vst1q_s16(&dst[x + 1 * dst_stride], result[1]);
-      vst1q_s16(&dst[x + 2 * dst_stride], result[2]);
-      vst1q_s16(&dst[x + 3 * dst_stride], result[3]);
-      vst1q_s16(&dst[x + 4 * dst_stride], result[4]);
-      vst1q_s16(&dst[x + 5 * dst_stride], result[5]);
-      vst1q_s16(&dst[x + 6 * dst_stride], result[6]);
-      vst1q_s16(&dst[x + 7 * dst_stride], result[7]);
-
-      for (int i = 0; i < 8; ++i) {
-        s[i] =
-            s[(p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits) + i];
-        if (grade_x > 1) {
-          s[i + 8] = s[(p >> kScaleSubPixelBits) -
-                       (prev_p >> kScaleSubPixelBits) + i + 8];
-        }
-      }
-
-      prev_p = p;
-      x += 8;
-    } while (x < width);
-
-    src += src_stride * 8;
-    dst += dst_stride * 8;
-    y -= 8;
-  } while (y > 0);
-}
-
+// Pre-transpose the 2 tap filters in |kAbsHalfSubPixelFilters|[3]
 inline uint8x16_t GetPositive2TapFilter(const int tap_index) {
   assert(tap_index < 2);
-  constexpr uint8_t kSubPixel2TapFilterColumns[2][16] = {
-      {128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8},
-      {0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120}};
+  alignas(
+      16) static constexpr uint8_t kAbsHalfSubPixel2TapFilterColumns[2][16] = {
+      {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4},
+      {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}};
 
-  return vld1q_u8(kSubPixel2TapFilterColumns[tap_index]);
+  return vld1q_u8(kAbsHalfSubPixel2TapFilterColumns[tap_index]);
 }
 
+template <int grade_x>
 inline void ConvolveKernelHorizontal2Tap(const uint8_t* src,
                                          const ptrdiff_t src_stride,
                                          const int width, const int subpixel_x,
                                          const int step_x,
                                          const int intermediate_height,
                                          int16_t* intermediate) {
-  const int kIntermediateStride = kMaxSuperBlockSizeInPixels;
   // Account for the 0-taps that precede the 2 nonzero taps.
   const int kernel_offset = 3;
   const int ref_x = subpixel_x >> kScaleSubPixelBits;
   const int step_x8 = step_x << 3;
   const uint8x16_t filter_taps0 = GetPositive2TapFilter(0);
   const uint8x16_t filter_taps1 = GetPositive2TapFilter(1);
-  const uint16x8_t sum = vdupq_n_u16(1 << (kBitdepth8 + kFilterBits - 1));
-  uint16x8_t index_steps = vmulq_n_u16(vmovl_u8(vcreate_u8(0x0706050403020100)),
-                                       static_cast<uint16_t>(step_x));
-
+  const uint16x8_t index_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
-  for (int x = 0, p = subpixel_x; x < width; x += 8, p += step_x8) {
+
+  int p = subpixel_x;
+  if (width <= 4) {
     const uint8_t* src_x =
         &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
-    int16_t* intermediate_x = intermediate + x;
     // Only add steps to the 10-bit truncated p to avoid overflow.
     const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
     const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
@@ -1189,45 +871,86 @@
     // For each x, a lane of tapsK has
     // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
     // on x.
-    const uint8x8_t taps0 = VQTbl1U8(filter_taps0, filter_indices);
-    const uint8x8_t taps1 = VQTbl1U8(filter_taps1, filter_indices);
-    for (int y = 0; y < intermediate_height; ++y) {
+    const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices),
+                               VQTbl1U8(filter_taps1, filter_indices)};
+    int y = 0;
+    do {
       // Load a pool of samples to select from using stepped indices.
-      uint8x16_t src_vals = vld1q_u8(src_x);
+      const uint8x16_t src_vals = vld1q_u8(src_x);
       const uint8x8_t src_indices =
           vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
 
       // For each x, a lane of srcK contains src_x[k].
-      const uint8x8_t src0 = VQTbl1U8(src_vals, src_indices);
-      const uint8x8_t src1 =
-          VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)));
+      const uint8x8_t src[2] = {
+          VQTbl1U8(src_vals, src_indices),
+          VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))};
 
-      const uint16x8_t product0 = vmlal_u8(sum, taps0, src0);
-      // product0 + product1
-      const uint16x8_t result = vmlal_u8(product0, taps1, src1);
+      vst1q_s16(intermediate,
+                vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps),
+                             kInterRoundBitsHorizontal - 1));
+      src_x += src_stride;
+      intermediate += kIntermediateStride;
+    } while (++y < intermediate_height);
+    return;
+  }
 
-      vst1q_s16(intermediate_x, vreinterpretq_s16_u16(vrshrq_n_u16(result, 3)));
+  // |width| >= 8
+  int x = 0;
+  do {
+    const uint8_t* src_x =
+        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+    int16_t* intermediate_x = intermediate + x;
+    // Only add steps to the 10-bit truncated p to avoid overflow.
+    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+    const uint8x8_t filter_indices =
+        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+                filter_index_mask);
+    // This is a special case. The 2-tap filter has no negative taps, so we
+    // can use unsigned values.
+    // For each x, a lane of tapsK has
+    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+    // on x.
+    const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices),
+                               VQTbl1U8(filter_taps1, filter_indices)};
+    int y = 0;
+    do {
+      // Load a pool of samples to select from using stepped indices.
+      const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);
+      const uint8x8_t src_indices =
+          vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+
+      // For each x, a lane of srcK contains src_x[k].
+      const uint8x8_t src[2] = {
+          vtbl3_u8(src_vals, src_indices),
+          vtbl3_u8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))};
+
+      vst1q_s16(intermediate_x,
+                vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps),
+                             kInterRoundBitsHorizontal - 1));
       src_x += src_stride;
       intermediate_x += kIntermediateStride;
-    }
-  }
+    } while (++y < intermediate_height);
+    x += 8;
+    p += step_x8;
+  } while (x < width);
 }
 
+// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[5].
 inline uint8x16_t GetPositive4TapFilter(const int tap_index) {
   assert(tap_index < 4);
-  constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
-      {0, 30, 26, 22, 20, 18, 16, 14, 12, 12, 10, 8, 6, 4, 4, 2},
-      {128, 62, 62, 62, 60, 58, 56, 54, 52, 48, 46, 44, 42, 40, 36, 34},
-      {0, 34, 36, 40, 42, 44, 46, 48, 52, 54, 56, 58, 60, 62, 62, 62},
-      {0, 2, 4, 4, 6, 8, 10, 12, 12, 14, 16, 18, 20, 22, 26, 30}};
+  alignas(
+      16) static constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
+      {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1},
+      {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
+      {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
+      {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}};
 
-  uint8x16_t filter_taps =
-      vld1q_u8(kSubPixel4TapPositiveFilterColumns[tap_index]);
-  return filter_taps;
+  return vld1q_u8(kSubPixel4TapPositiveFilterColumns[tap_index]);
 }
 
 // This filter is only possible when width <= 4.
-inline void ConvolveKernelHorizontalPositive4Tap(
+void ConvolveKernelHorizontalPositive4Tap(
     const uint8_t* src, const ptrdiff_t src_stride, const int subpixel_x,
     const int step_x, const int intermediate_height, int16_t* intermediate) {
   const int kernel_offset = 2;
@@ -1237,69 +960,60 @@
   const uint8x16_t filter_taps1 = GetPositive4TapFilter(1);
   const uint8x16_t filter_taps2 = GetPositive4TapFilter(2);
   const uint8x16_t filter_taps3 = GetPositive4TapFilter(3);
-  uint16x8_t index_steps = vmulq_n_u16(vmovl_u8(vcreate_u8(0x0706050403020100)),
-                                       static_cast<uint16_t>(step_x));
-  int p = subpixel_x;
-  const uint16x8_t base = vdupq_n_u16(1 << (kBitdepth8 + kFilterBits - 1));
+  const uint16x8_t index_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+  const int p = subpixel_x;
   // First filter is special, just a 128 tap on the center.
   const uint8_t* src_x =
       &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
   // Only add steps to the 10-bit truncated p to avoid overflow.
   const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
   const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
-  const uint8x8_t filter_indices =
-      vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
+  const uint8x8_t filter_indices = vand_u8(
+      vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), filter_index_mask);
   // Note that filter_id depends on x.
   // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k].
-  const uint8x8_t taps0 = VQTbl1U8(filter_taps0, filter_indices);
-  const uint8x8_t taps1 = VQTbl1U8(filter_taps1, filter_indices);
-  const uint8x8_t taps2 = VQTbl1U8(filter_taps2, filter_indices);
-  const uint8x8_t taps3 = VQTbl1U8(filter_taps3, filter_indices);
+  const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices),
+                             VQTbl1U8(filter_taps1, filter_indices),
+                             VQTbl1U8(filter_taps2, filter_indices),
+                             VQTbl1U8(filter_taps3, filter_indices)};
 
   const uint8x8_t src_indices =
       vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
-  for (int y = 0; y < intermediate_height; ++y) {
+  int y = 0;
+  do {
     // Load a pool of samples to select from using stepped index vectors.
-    uint8x16_t src_vals = vld1q_u8(src_x);
+    const uint8x16_t src_vals = vld1q_u8(src_x);
 
     // For each x, srcK contains src_x[k] where k=1.
     // Whereas taps come from different arrays, src pixels are drawn from the
     // same contiguous line.
-    const uint8x8_t src0 = VQTbl1U8(src_vals, src_indices);
-    const uint8x8_t src1 =
-        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)));
-    const uint8x8_t src2 =
-        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(2)));
-    const uint8x8_t src3 =
-        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(3)));
+    const uint8x8_t src[4] = {
+        VQTbl1U8(src_vals, src_indices),
+        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1))),
+        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(2))),
+        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(3)))};
 
-    uint16x8_t sum = vmlal_u8(base, taps0, src0);
-    sum = vmlal_u8(sum, taps1, src1);
-    sum = vmlal_u8(sum, taps2, src2);
-    sum = vmlal_u8(sum, taps3, src3);
-
-    vst1_s16(intermediate,
-             vreinterpret_s16_u16(vrshr_n_u16(vget_low_u16(sum), 3)));
+    vst1q_s16(intermediate,
+              vrshrq_n_s16(SumOnePassTaps</*filter_index=*/5>(src, taps),
+                           kInterRoundBitsHorizontal - 1));
 
     src_x += src_stride;
     intermediate += kIntermediateStride;
-  }
+  } while (++y < intermediate_height);
 }
 
+// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[4].
 inline uint8x16_t GetSigned4TapFilter(const int tap_index) {
   assert(tap_index < 4);
-  // The first and fourth taps of each filter are negative. However
-  // 128 does not fit in an 8-bit signed integer. Thus we use subtraction to
-  // keep everything unsigned.
-  constexpr uint8_t kSubPixel4TapSignedFilterColumns[4][16] = {
-      {0, 4, 8, 10, 12, 12, 14, 12, 12, 10, 10, 10, 8, 6, 4, 2},
-      {128, 126, 122, 116, 110, 102, 94, 84, 76, 66, 58, 48, 38, 28, 18, 8},
-      {0, 8, 18, 28, 38, 48, 58, 66, 76, 84, 94, 102, 110, 116, 122, 126},
-      {0, 2, 4, 6, 8, 10, 10, 10, 12, 12, 14, 12, 12, 10, 8, 4}};
+  alignas(16) static constexpr uint8_t
+      kAbsHalfSubPixel4TapSignedFilterColumns[4][16] = {
+          {0, 2, 4, 5, 6, 6, 7, 6, 6, 5, 5, 5, 4, 3, 2, 1},
+          {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
+          {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
+          {0, 1, 2, 3, 4, 5, 5, 5, 6, 6, 7, 6, 6, 5, 4, 2}};
 
-  uint8x16_t filter_taps =
-      vld1q_u8(kSubPixel4TapSignedFilterColumns[tap_index]);
-  return filter_taps;
+  return vld1q_u8(kAbsHalfSubPixel4TapSignedFilterColumns[tap_index]);
 }
 
 // This filter is only possible when width <= 4.
@@ -1313,66 +1027,480 @@
   const uint8x16_t filter_taps1 = GetSigned4TapFilter(1);
   const uint8x16_t filter_taps2 = GetSigned4TapFilter(2);
   const uint8x16_t filter_taps3 = GetSigned4TapFilter(3);
-  const uint16x8_t index_steps = vmulq_n_u16(vmovl_u8(vcreate_u8(0x03020100)),
-                                             static_cast<uint16_t>(step_x));
+  const uint16x4_t index_steps = vmul_n_u16(vcreate_u16(0x0003000200010000),
+                                            static_cast<uint16_t>(step_x));
 
-  const uint16x8_t base = vdupq_n_u16(1 << (kBitdepth8 + kFilterBits - 1));
-  int p = subpixel_x;
+  const int p = subpixel_x;
   const uint8_t* src_x =
       &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
   // Only add steps to the 10-bit truncated p to avoid overflow.
-  const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
-  const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+  const uint16x4_t p_fraction = vdup_n_u16(p & 1023);
+  const uint16x4_t subpel_index_offsets = vadd_u16(index_steps, p_fraction);
+  const uint8x8_t filter_index_offsets = vshrn_n_u16(
+      vcombine_u16(subpel_index_offsets, vdup_n_u16(0)), kFilterIndexShift);
   const uint8x8_t filter_indices =
-      vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
+      vand_u8(filter_index_offsets, filter_index_mask);
   // Note that filter_id depends on x.
   // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k].
-  const uint8x8_t taps0 = VQTbl1U8(filter_taps0, filter_indices);
-  const uint8x8_t taps1 = VQTbl1U8(filter_taps1, filter_indices);
-  const uint8x8_t taps2 = VQTbl1U8(filter_taps2, filter_indices);
-  const uint8x8_t taps3 = VQTbl1U8(filter_taps3, filter_indices);
-  for (int y = 0; y < intermediate_height; ++y) {
+  const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices),
+                             VQTbl1U8(filter_taps1, filter_indices),
+                             VQTbl1U8(filter_taps2, filter_indices),
+                             VQTbl1U8(filter_taps3, filter_indices)};
+
+  const uint8x8_t src_indices_base =
+      vshr_n_u8(filter_index_offsets, kScaleSubPixelBits - kFilterIndexShift);
+
+  const uint8x8_t src_indices[4] = {src_indices_base,
+                                    vadd_u8(src_indices_base, vdup_n_u8(1)),
+                                    vadd_u8(src_indices_base, vdup_n_u8(2)),
+                                    vadd_u8(src_indices_base, vdup_n_u8(3))};
+
+  int y = 0;
+  do {
     // Load a pool of samples to select from using stepped indices.
-    uint8x16_t src_vals = vld1q_u8(src_x);
-    const uint8x8_t src_indices =
-        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+    const uint8x16_t src_vals = vld1q_u8(src_x);
 
     // For each x, srcK contains src_x[k] where k=1.
     // Whereas taps come from different arrays, src pixels are drawn from the
     // same contiguous line.
-    const uint8x8_t src0 = VQTbl1U8(src_vals, src_indices);
-    const uint8x8_t src1 =
-        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)));
-    const uint8x8_t src2 =
-        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(2)));
-    const uint8x8_t src3 =
-        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(3)));
+    const uint8x8_t src[4] = {
+        VQTbl1U8(src_vals, src_indices[0]), VQTbl1U8(src_vals, src_indices[1]),
+        VQTbl1U8(src_vals, src_indices[2]), VQTbl1U8(src_vals, src_indices[3])};
 
-    // Offsetting by base permits a guaranteed positive.
-    uint16x8_t sum = vmlsl_u8(base, taps0, src0);
-    sum = vmlal_u8(sum, taps1, src1);
-    sum = vmlal_u8(sum, taps2, src2);
-    sum = vmlsl_u8(sum, taps3, src3);
-
-    vst1_s16(intermediate,
-             vreinterpret_s16_u16(vrshr_n_u16(vget_low_u16(sum), 3)));
+    vst1q_s16(intermediate,
+              vrshrq_n_s16(SumOnePassTaps</*filter_index=*/4>(src, taps),
+                           kInterRoundBitsHorizontal - 1));
     src_x += src_stride;
     intermediate += kIntermediateStride;
-  }
+  } while (++y < intermediate_height);
 }
 
-void ConvolveCompoundScale2D_NEON(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int horizontal_filter_index, const int vertical_filter_index,
-    const int inter_round_bits_vertical, const int subpixel_x,
-    const int subpixel_y, const int step_x, const int step_y, const int width,
-    const int height, void* prediction, const ptrdiff_t pred_stride) {
+// Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[0].
+inline uint8x16_t GetSigned6TapFilter(const int tap_index) {
+  assert(tap_index < 6);
+  alignas(16) static constexpr uint8_t
+      kAbsHalfSubPixel6TapSignedFilterColumns[6][16] = {
+          {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0},
+          {0, 3, 5, 6, 7, 7, 8, 7, 7, 6, 6, 6, 5, 4, 2, 1},
+          {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
+          {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
+          {0, 1, 2, 4, 5, 6, 6, 6, 7, 7, 8, 7, 7, 6, 5, 3},
+          {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
+
+  return vld1q_u8(kAbsHalfSubPixel6TapSignedFilterColumns[tap_index]);
+}
+
+// This filter is only possible when width >= 8.
+template <int grade_x>
+inline void ConvolveKernelHorizontalSigned6Tap(
+    const uint8_t* src, const ptrdiff_t src_stride, const int width,
+    const int subpixel_x, const int step_x, const int intermediate_height,
+    int16_t* intermediate) {
+  const int kernel_offset = 1;
+  const uint8x8_t one = vdup_n_u8(1);
+  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const int step_x8 = step_x << 3;
+  uint8x16_t filter_taps[6];
+  for (int i = 0; i < 6; ++i) {
+    filter_taps[i] = GetSigned6TapFilter(i);
+  }
+  const uint16x8_t index_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+
+  int x = 0;
+  int p = subpixel_x;
+  do {
+    // Avoid overloading outside the reference boundaries. This means
+    // |trailing_width| can be up to 24.
+    const uint8_t* src_x =
+        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+    int16_t* intermediate_x = intermediate + x;
+    // Only add steps to the 10-bit truncated p to avoid overflow.
+    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+    const uint8x8_t src_indices =
+        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+    uint8x8_t src_lookup[6];
+    src_lookup[0] = src_indices;
+    for (int i = 1; i < 6; ++i) {
+      src_lookup[i] = vadd_u8(src_lookup[i - 1], one);
+    }
+
+    const uint8x8_t filter_indices =
+        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+                filter_index_mask);
+    // For each x, a lane of taps[k] has
+    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+    // on x.
+    uint8x8_t taps[6];
+    for (int i = 0; i < 6; ++i) {
+      taps[i] = VQTbl1U8(filter_taps[i], filter_indices);
+    }
+    int y = 0;
+    do {
+      // Load a pool of samples to select from using stepped indices.
+      const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);
+
+      const uint8x8_t src[6] = {
+          vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]),
+          vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]),
+          vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5])};
+
+      vst1q_s16(intermediate_x,
+                vrshrq_n_s16(SumOnePassTaps</*filter_index=*/0>(src, taps),
+                             kInterRoundBitsHorizontal - 1));
+      src_x += src_stride;
+      intermediate_x += kIntermediateStride;
+    } while (++y < intermediate_height);
+    x += 8;
+    p += step_x8;
+  } while (x < width);
+}
+
+// Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[1]. This filter
+// has mixed positive and negative outer taps which are handled in
+// GetMixed6TapFilter().
+inline uint8x16_t GetPositive6TapFilter(const int tap_index) {
+  assert(tap_index < 6);
+  alignas(16) static constexpr uint8_t
+      kAbsHalfSubPixel6TapPositiveFilterColumns[4][16] = {
+          {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1},
+          {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
+          {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
+          {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14}};
+
+  return vld1q_u8(kAbsHalfSubPixel6TapPositiveFilterColumns[tap_index]);
+}
+
+inline int8x16_t GetMixed6TapFilter(const int tap_index) {
+  assert(tap_index < 2);
+  alignas(
+      16) static constexpr int8_t kHalfSubPixel6TapMixedFilterColumns[2][16] = {
+      {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0},
+      {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}};
+
+  return vld1q_s8(kHalfSubPixel6TapMixedFilterColumns[tap_index]);
+}
+
+// This filter is only possible when width >= 8.
+template <int grade_x>
+inline void ConvolveKernelHorizontalMixed6Tap(
+    const uint8_t* src, const ptrdiff_t src_stride, const int width,
+    const int subpixel_x, const int step_x, const int intermediate_height,
+    int16_t* intermediate) {
+  const int kernel_offset = 1;
+  const uint8x8_t one = vdup_n_u8(1);
+  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const int step_x8 = step_x << 3;
+  uint8x8_t taps[4];
+  int16x8_t mixed_taps[2];
+  uint8x16_t positive_filter_taps[4];
+  for (int i = 0; i < 4; ++i) {
+    positive_filter_taps[i] = GetPositive6TapFilter(i);
+  }
+  int8x16_t mixed_filter_taps[2];
+  mixed_filter_taps[0] = GetMixed6TapFilter(0);
+  mixed_filter_taps[1] = GetMixed6TapFilter(1);
+  const uint16x8_t index_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+
+  int x = 0;
+  int p = subpixel_x;
+  do {
+    const uint8_t* src_x =
+        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+    int16_t* intermediate_x = intermediate + x;
+    // Only add steps to the 10-bit truncated p to avoid overflow.
+    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+    const uint8x8_t src_indices =
+        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+    uint8x8_t src_lookup[6];
+    src_lookup[0] = src_indices;
+    for (int i = 1; i < 6; ++i) {
+      src_lookup[i] = vadd_u8(src_lookup[i - 1], one);
+    }
+
+    const uint8x8_t filter_indices =
+        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+                filter_index_mask);
+    // For each x, a lane of taps[k] has
+    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+    // on x.
+    for (int i = 0; i < 4; ++i) {
+      taps[i] = VQTbl1U8(positive_filter_taps[i], filter_indices);
+    }
+    mixed_taps[0] = vmovl_s8(VQTbl1S8(mixed_filter_taps[0], filter_indices));
+    mixed_taps[1] = vmovl_s8(VQTbl1S8(mixed_filter_taps[1], filter_indices));
+
+    int y = 0;
+    do {
+      // Load a pool of samples to select from using stepped indices.
+      const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);
+
+      int16x8_t sum_mixed = vmulq_s16(
+          mixed_taps[0], ZeroExtend(vtbl3_u8(src_vals, src_lookup[0])));
+      sum_mixed = vmlaq_s16(sum_mixed, mixed_taps[1],
+                            ZeroExtend(vtbl3_u8(src_vals, src_lookup[5])));
+      uint16x8_t sum = vreinterpretq_u16_s16(sum_mixed);
+      sum = vmlal_u8(sum, taps[0], vtbl3_u8(src_vals, src_lookup[1]));
+      sum = vmlal_u8(sum, taps[1], vtbl3_u8(src_vals, src_lookup[2]));
+      sum = vmlal_u8(sum, taps[2], vtbl3_u8(src_vals, src_lookup[3]));
+      sum = vmlal_u8(sum, taps[3], vtbl3_u8(src_vals, src_lookup[4]));
+
+      vst1q_s16(intermediate_x, vrshrq_n_s16(vreinterpretq_s16_u16(sum),
+                                             kInterRoundBitsHorizontal - 1));
+      src_x += src_stride;
+      intermediate_x += kIntermediateStride;
+    } while (++y < intermediate_height);
+    x += 8;
+    p += step_x8;
+  } while (x < width);
+}
+
+// Pre-transpose the 8 tap filters in |kAbsHalfSubPixelFilters|[2].
+inline uint8x16_t GetSigned8TapFilter(const int tap_index) {
+  assert(tap_index < 8);
+  alignas(16) static constexpr uint8_t
+      kAbsHalfSubPixel8TapSignedFilterColumns[8][16] = {
+          {0, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 0},
+          {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1},
+          {0, 3, 6, 9, 11, 11, 12, 12, 12, 11, 10, 9, 7, 5, 3, 1},
+          {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4},
+          {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63},
+          {0, 1, 3, 5, 7, 9, 10, 11, 12, 12, 12, 11, 11, 9, 6, 3},
+          {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1},
+          {0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1}};
+
+  return vld1q_u8(kAbsHalfSubPixel8TapSignedFilterColumns[tap_index]);
+}
+
+// This filter is only possible when width >= 8.
+template <int grade_x>
+inline void ConvolveKernelHorizontalSigned8Tap(
+    const uint8_t* src, const ptrdiff_t src_stride, const int width,
+    const int subpixel_x, const int step_x, const int intermediate_height,
+    int16_t* intermediate) {
+  const uint8x8_t one = vdup_n_u8(1);
+  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const int step_x8 = step_x << 3;
+  uint8x8_t taps[8];
+  uint8x16_t filter_taps[8];
+  for (int i = 0; i < 8; ++i) {
+    filter_taps[i] = GetSigned8TapFilter(i);
+  }
+  const uint16x8_t index_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+  int x = 0;
+  int p = subpixel_x;
+  do {
+    const uint8_t* src_x = &src[(p >> kScaleSubPixelBits) - ref_x];
+    int16_t* intermediate_x = intermediate + x;
+    // Only add steps to the 10-bit truncated p to avoid overflow.
+    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+    const uint8x8_t src_indices =
+        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
+    uint8x8_t src_lookup[8];
+    src_lookup[0] = src_indices;
+    for (int i = 1; i < 8; ++i) {
+      src_lookup[i] = vadd_u8(src_lookup[i - 1], one);
+    }
+
+    const uint8x8_t filter_indices =
+        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+                filter_index_mask);
+    // For each x, a lane of taps[k] has
+    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
+    // on x.
+    for (int i = 0; i < 8; ++i) {
+      taps[i] = VQTbl1U8(filter_taps[i], filter_indices);
+    }
+
+    int y = 0;
+    do {
+      // Load a pool of samples to select from using stepped indices.
+      const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);
+
+      const uint8x8_t src[8] = {
+          vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]),
+          vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]),
+          vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5]),
+          vtbl3_u8(src_vals, src_lookup[6]), vtbl3_u8(src_vals, src_lookup[7])};
+
+      vst1q_s16(intermediate_x,
+                vrshrq_n_s16(SumOnePassTaps</*filter_index=*/2>(src, taps),
+                             kInterRoundBitsHorizontal - 1));
+      src_x += src_stride;
+      intermediate_x += kIntermediateStride;
+    } while (++y < intermediate_height);
+    x += 8;
+    p += step_x8;
+  } while (x < width);
+}
+
+// This function handles blocks of width 2 or 4.
+template <int num_taps, int grade_y, int width, bool is_compound>
+void ConvolveVerticalScale4xH(const int16_t* src, const int subpixel_y,
+                              const int filter_index, const int step_y,
+                              const int height, void* dest,
+                              const ptrdiff_t dest_stride) {
+  constexpr ptrdiff_t src_stride = kIntermediateStride;
+  const int16_t* src_y = src;
+  // |dest| is 16-bit in compound mode, Pixel otherwise.
+  uint16_t* dest16_y = static_cast<uint16_t*>(dest);
+  uint8_t* dest_y = static_cast<uint8_t*>(dest);
+  int16x4_t s[num_taps + grade_y];
+
+  int p = subpixel_y & 1023;
+  int prev_p = p;
+  int y = 0;
+  do {  // y < height
+    for (int i = 0; i < num_taps; ++i) {
+      s[i] = vld1_s16(src_y + i * src_stride);
+    }
+    int filter_id = (p >> 6) & kSubPixelMask;
+    int16x8_t filter =
+        vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+    int16x4_t sums = Sum2DVerticalTaps4<num_taps, is_compound>(s, filter);
+    if (is_compound) {
+      assert(width != 2);
+      const uint16x4_t result = vreinterpret_u16_s16(sums);
+      vst1_u16(dest16_y, result);
+    } else {
+      const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums));
+      if (width == 2) {
+        Store2<0>(dest_y, result);
+      } else {
+        StoreLo4(dest_y, result);
+      }
+    }
+    p += step_y;
+    const int p_diff =
+        (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
+    prev_p = p;
+    // Here we load extra source in case it is needed. If |p_diff| == 0, these
+    // values will be unused, but it's faster to load than to branch.
+    s[num_taps] = vld1_s16(src_y + num_taps * src_stride);
+    if (grade_y > 1) {
+      s[num_taps + 1] = vld1_s16(src_y + (num_taps + 1) * src_stride);
+    }
+    dest16_y += dest_stride;
+    dest_y += dest_stride;
+
+    filter_id = (p >> 6) & kSubPixelMask;
+    filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+    sums = Sum2DVerticalTaps4<num_taps, is_compound>(&s[p_diff], filter);
+    if (is_compound) {
+      assert(width != 2);
+      const uint16x4_t result = vreinterpret_u16_s16(sums);
+      vst1_u16(dest16_y, result);
+    } else {
+      const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums));
+      if (width == 2) {
+        Store2<0>(dest_y, result);
+      } else {
+        StoreLo4(dest_y, result);
+      }
+    }
+    p += step_y;
+    src_y = src + (p >> kScaleSubPixelBits) * src_stride;
+    prev_p = p;
+    dest16_y += dest_stride;
+    dest_y += dest_stride;
+
+    y += 2;
+  } while (y < height);
+}
+
+template <int num_taps, int grade_y, bool is_compound>
+inline void ConvolveVerticalScale(const int16_t* src, const int width,
+                                  const int subpixel_y, const int filter_index,
+                                  const int step_y, const int height,
+                                  void* dest, const ptrdiff_t dest_stride) {
+  constexpr ptrdiff_t src_stride = kIntermediateStride;
+  // A possible improvement is to use arithmetic to decide how many times to
+  // apply filters to same source before checking whether to load new srcs.
+  // However, this will only improve performance with very small step sizes.
+  int16x8_t s[num_taps + grade_y];
+  // |dest| is 16-bit in compound mode, Pixel otherwise.
+  uint16_t* dest16_y;
+  uint8_t* dest_y;
+
+  int x = 0;
+  do {  // x < width
+    const int16_t* src_x = src + x;
+    const int16_t* src_y = src_x;
+    dest16_y = static_cast<uint16_t*>(dest) + x;
+    dest_y = static_cast<uint8_t*>(dest) + x;
+    int p = subpixel_y & 1023;
+    int prev_p = p;
+    int y = 0;
+    do {  // y < height
+      for (int i = 0; i < num_taps; ++i) {
+        s[i] = vld1q_s16(src_y + i * src_stride);
+      }
+      int filter_id = (p >> 6) & kSubPixelMask;
+      int16x8_t filter =
+          vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+      int16x8_t sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(s, filter);
+      if (is_compound) {
+        vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum));
+      } else {
+        vst1_u8(dest_y, vqmovun_s16(sum));
+      }
+      p += step_y;
+      const int p_diff =
+          (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
+      // |grade_y| > 1 always means p_diff > 0, so load vectors that may be
+      // needed. Otherwise, we only need to load one vector because |p_diff|
+      // can't exceed 1.
+      s[num_taps] = vld1q_s16(src_y + num_taps * src_stride);
+      if (grade_y > 1) {
+        s[num_taps + 1] = vld1q_s16(src_y + (num_taps + 1) * src_stride);
+      }
+      dest16_y += dest_stride;
+      dest_y += dest_stride;
+
+      filter_id = (p >> 6) & kSubPixelMask;
+      filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+      sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(&s[p_diff], filter);
+      if (is_compound) {
+        vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum));
+      } else {
+        vst1_u8(dest_y, vqmovun_s16(sum));
+      }
+      p += step_y;
+      src_y = src_x + (p >> kScaleSubPixelBits) * src_stride;
+      prev_p = p;
+      dest16_y += dest_stride;
+      dest_y += dest_stride;
+
+      y += 2;
+    } while (y < height);
+    x += 8;
+  } while (x < width);
+}
+
+template <bool is_compound>
+void ConvolveScale2D_NEON(const void* const reference,
+                          const ptrdiff_t reference_stride,
+                          const int horizontal_filter_index,
+                          const int vertical_filter_index, const int subpixel_x,
+                          const int subpixel_y, const int step_x,
+                          const int step_y, const int width, const int height,
+                          void* prediction, const ptrdiff_t pred_stride) {
+  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+  assert(step_x <= 2048);
+  const int num_vert_taps = GetNumTapsInFilter(vert_filter_index);
   const int intermediate_height =
       (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
        kScaleSubPixelBits) +
-      kSubPixelTaps;
-  // TODO(b/133525024): Decide whether it's worth branching to a special case
-  // when step_x or step_y is 1024.
+      num_vert_taps;
   assert(step_x <= 2048);
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
@@ -1384,49 +1512,71 @@
   // When width > 4, the valid filter index range is always [0, 3].
   // When width <= 4, the valid filter index range is always [3, 5].
   // Similarly for height.
-  const int kIntermediateStride = kMaxSuperBlockSizeInPixels;
   int filter_index = GetFilterIndex(horizontal_filter_index, width);
   int16_t* intermediate = intermediate_result;
-  const auto* src = static_cast<const uint8_t*>(reference);
   const ptrdiff_t src_stride = reference_stride;
-  auto* dest = static_cast<uint16_t*>(prediction);
+  const auto* src = static_cast<const uint8_t*>(reference);
+  const int vert_kernel_offset = (8 - num_vert_taps) / 2;
+  src += vert_kernel_offset * src_stride;
+
+  // Derive the maximum value of |step_x| at which all source values fit in one
+  // 16-byte load. Final index is src_x + |num_taps| - 1 < 16
+  // step_x*7 is the final base subpel index for the shuffle mask for filter
+  // inputs in each iteration on large blocks. When step_x is large, we need a
+  // larger structure and use a larger table lookup in order to gather all
+  // filter inputs.
+  // |num_taps| - 1 is the shuffle index of the final filter input.
+  const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index);
+  const int kernel_start_ceiling = 16 - num_horiz_taps;
+  // This truncated quotient |grade_x_threshold| selects |step_x| such that:
+  // (step_x * 7) >> kScaleSubPixelBits < single load limit
+  const int grade_x_threshold =
+      (kernel_start_ceiling << kScaleSubPixelBits) / 7;
   switch (filter_index) {
     case 0:
-      if (step_x < 1024) {
-        ConvolveHorizontalScaled_NEON<0, 6, 1>(
+      if (step_x > grade_x_threshold) {
+        ConvolveKernelHorizontalSigned6Tap<2>(
             src, src_stride, width, subpixel_x, step_x, intermediate_height,
             intermediate);
       } else {
-        ConvolveHorizontalScaled_NEON<0, 6, 2>(
+        ConvolveKernelHorizontalSigned6Tap<1>(
             src, src_stride, width, subpixel_x, step_x, intermediate_height,
             intermediate);
       }
       break;
     case 1:
-      if (step_x < 1024) {
-        ConvolveHorizontalScaled_NEON<1, 6, 1>(
-            src, src_stride, width, subpixel_x, step_x, intermediate_height,
-            intermediate);
+      if (step_x > grade_x_threshold) {
+        ConvolveKernelHorizontalMixed6Tap<2>(src, src_stride, width, subpixel_x,
+                                             step_x, intermediate_height,
+                                             intermediate);
+
       } else {
-        ConvolveHorizontalScaled_NEON<1, 6, 2>(
-            src, src_stride, width, subpixel_x, step_x, intermediate_height,
-            intermediate);
+        ConvolveKernelHorizontalMixed6Tap<1>(src, src_stride, width, subpixel_x,
+                                             step_x, intermediate_height,
+                                             intermediate);
       }
       break;
     case 2:
-      if (step_x <= 1024) {
-        ConvolveHorizontalScaled_NEON<2, 8, 1>(
+      if (step_x > grade_x_threshold) {
+        ConvolveKernelHorizontalSigned8Tap<2>(
             src, src_stride, width, subpixel_x, step_x, intermediate_height,
             intermediate);
       } else {
-        ConvolveHorizontalScaled_NEON<2, 8, 2>(
+        ConvolveKernelHorizontalSigned8Tap<1>(
             src, src_stride, width, subpixel_x, step_x, intermediate_height,
             intermediate);
       }
       break;
     case 3:
-      ConvolveKernelHorizontal2Tap(src, src_stride, width, subpixel_x, step_x,
-                                   intermediate_height, intermediate);
+      if (step_x > grade_x_threshold) {
+        ConvolveKernelHorizontal2Tap<2>(src, src_stride, width, subpixel_x,
+                                        step_x, intermediate_height,
+                                        intermediate);
+      } else {
+        ConvolveKernelHorizontal2Tap<1>(src, src_stride, width, subpixel_x,
+                                        step_x, intermediate_height,
+                                        intermediate);
+      }
       break;
     case 4:
       assert(width <= 4);
@@ -1441,23 +1591,135 @@
   // Vertical filter.
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
-  const int offset_bits = kBitdepth8 + 2 * kFilterBits - 3;
-  for (int y = 0, p = subpixel_y & 1023; y < height; ++y, p += step_y) {
-    const int filter_id = (p >> 6) & kSubPixelMask;
-    for (int x = 0; x < width; ++x) {
-      // An offset to guarantee the sum is non negative.
-      int sum = 1 << offset_bits;
-      for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum +=
-            kSubPixelFilters[filter_index][filter_id][k] *
-            intermediate[((p >> kScaleSubPixelBits) + k) * kIntermediateStride +
-                         x];
+
+  switch (filter_index) {
+    case 0:
+    case 1:
+      if (step_y <= 1024) {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale4xH<6, 1, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale4xH<6, 1, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<6, 1, is_compound>(
+              intermediate, width, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        }
+      } else {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale4xH<6, 2, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale4xH<6, 2, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<6, 2, is_compound>(
+              intermediate, width, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        }
       }
-      assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
-      dest[x] = static_cast<uint16_t>(
-          RightShiftWithRounding(sum, inter_round_bits_vertical));
-    }
-    dest += pred_stride;
+      break;
+    case 2:
+      if (step_y <= 1024) {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale4xH<8, 1, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale4xH<8, 1, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<8, 1, is_compound>(
+              intermediate, width, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        }
+      } else {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale4xH<8, 2, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale4xH<8, 2, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<8, 2, is_compound>(
+              intermediate, width, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        }
+      }
+      break;
+    case 3:
+      if (step_y <= 1024) {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale4xH<2, 1, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale4xH<2, 1, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<2, 1, is_compound>(
+              intermediate, width, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        }
+      } else {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale4xH<2, 2, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale4xH<2, 2, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<2, 2, is_compound>(
+              intermediate, width, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        }
+      }
+      break;
+    case 4:
+    default:
+      assert(filter_index == 4 || filter_index == 5);
+      assert(height <= 4);
+      if (step_y <= 1024) {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale4xH<4, 1, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale4xH<4, 1, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<4, 1, is_compound>(
+              intermediate, width, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        }
+      } else {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale4xH<4, 2, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale4xH<4, 2, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<4, 2, is_compound>(
+              intermediate, width, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        }
+      }
   }
 }
 
@@ -1465,65 +1727,75 @@
                              const ptrdiff_t reference_stride,
                              const int horizontal_filter_index,
                              const int /*vertical_filter_index*/,
-                             const int /*inter_round_bits_vertical*/,
                              const int subpixel_x, const int /*subpixel_y*/,
-                             const int /*step_x*/, const int /*step_y*/,
                              const int width, const int height,
                              void* prediction, const ptrdiff_t pred_stride) {
-  // For 8 (and 10) bit calculations |inter_round_bits_horizontal| is 3.
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   // Set |src| to the outermost tap.
   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
   auto* dest = static_cast<uint8_t*>(prediction);
 
-  HorizontalPass<false, true>(src, reference_stride, dest, pred_stride, width,
-                              height, subpixel_x, filter_index);
+  DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height,
+                   subpixel_x, filter_index);
 }
 
-template <int min_width, int num_taps>
+// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D
+// Vertical calculations.
+uint16x8_t Compound1DShift(const int16x8_t sum) {
+  return vreinterpretq_u16_s16(
+      vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1));
+}
+
+template <int filter_index, bool is_compound = false,
+          bool negative_outside_taps = false>
 void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride,
-                    uint8_t* dst, const ptrdiff_t dst_stride, const int width,
-                    const int height, const int16x8_t taps) {
-  constexpr int next_row = num_taps - 1;
-  // |src| points to the outermost tap of the first value. When doing fewer than
-  // 8 taps it needs to be adjusted.
-  if (num_taps == 6) {
-    src += src_stride;
-  } else if (num_taps == 4) {
-    src += 2 * src_stride;
-  } else if (num_taps == 2) {
-    src += 3 * src_stride;
-  }
+                    void* const dst, const ptrdiff_t dst_stride,
+                    const int width, const int height,
+                    const uint8x8_t* const taps) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  const int next_row = num_taps - 1;
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+  assert(width >= 8);
 
   int x = 0;
   do {
-    int16x8_t srcs[8];
-    srcs[0] = ZeroExtend(vld1_u8(src + x));
+    const uint8_t* src_x = src + x;
+    uint8x8_t srcs[8];
+    srcs[0] = vld1_u8(src_x);
+    src_x += src_stride;
     if (num_taps >= 4) {
-      srcs[1] = ZeroExtend(vld1_u8(src + x + src_stride));
-      srcs[2] = ZeroExtend(vld1_u8(src + x + 2 * src_stride));
+      srcs[1] = vld1_u8(src_x);
+      src_x += src_stride;
+      srcs[2] = vld1_u8(src_x);
+      src_x += src_stride;
       if (num_taps >= 6) {
-        srcs[3] = ZeroExtend(vld1_u8(src + x + 3 * src_stride));
-        srcs[4] = ZeroExtend(vld1_u8(src + x + 4 * src_stride));
+        srcs[3] = vld1_u8(src_x);
+        src_x += src_stride;
+        srcs[4] = vld1_u8(src_x);
+        src_x += src_stride;
         if (num_taps == 8) {
-          srcs[5] = ZeroExtend(vld1_u8(src + x + 5 * src_stride));
-          srcs[6] = ZeroExtend(vld1_u8(src + x + 6 * src_stride));
+          srcs[5] = vld1_u8(src_x);
+          src_x += src_stride;
+          srcs[6] = vld1_u8(src_x);
+          src_x += src_stride;
         }
       }
     }
 
     int y = 0;
     do {
-      srcs[next_row] =
-          ZeroExtend(vld1_u8(src + x + (y + next_row) * src_stride));
+      srcs[next_row] = vld1_u8(src_x);
+      src_x += src_stride;
 
-      const int16x8_t sums = SumTaps<num_taps>(srcs, taps);
-      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits);
-
-      if (min_width == 4) {
-        StoreLo4(dst + x + y * dst_stride, results);
+      const int16x8_t sums =
+          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+      if (is_compound) {
+        const uint16x8_t results = Compound1DShift(sums);
+        vst1q_u16(dst16 + x + y * dst_stride, results);
       } else {
-        vst1_u8(dst + x + y * dst_stride, results);
+        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+        vst1_u8(dst8 + x + y * dst_stride, results);
       }
 
       srcs[0] = srcs[1];
@@ -1544,6 +1816,394 @@
   } while (x < width);
 }
 
+template <int filter_index, bool is_compound = false,
+          bool negative_outside_taps = false>
+void FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride,
+                       void* const dst, const ptrdiff_t dst_stride,
+                       const int height, const uint8x8_t* const taps) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+
+  uint8x8_t srcs[9];
+
+  if (num_taps == 2) {
+    srcs[2] = vdup_n_u8(0);
+
+    srcs[0] = Load4(src);
+    src += src_stride;
+
+    int y = 0;
+    do {
+      srcs[0] = Load4<1>(src, srcs[0]);
+      src += src_stride;
+      srcs[2] = Load4<0>(src, srcs[2]);
+      src += src_stride;
+      srcs[1] = vext_u8(srcs[0], srcs[2], 4);
+
+      const int16x8_t sums =
+          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+      if (is_compound) {
+        const uint16x8_t results = Compound1DShift(sums);
+
+        vst1q_u16(dst16, results);
+        dst16 += 4 << 1;
+      } else {
+        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+        StoreLo4(dst8, results);
+        dst8 += dst_stride;
+        StoreHi4(dst8, results);
+        dst8 += dst_stride;
+      }
+
+      srcs[0] = srcs[2];
+      y += 2;
+    } while (y < height);
+  } else if (num_taps == 4) {
+    srcs[4] = vdup_n_u8(0);
+
+    srcs[0] = Load4(src);
+    src += src_stride;
+    srcs[0] = Load4<1>(src, srcs[0]);
+    src += src_stride;
+    srcs[2] = Load4(src);
+    src += src_stride;
+    srcs[1] = vext_u8(srcs[0], srcs[2], 4);
+
+    int y = 0;
+    do {
+      srcs[2] = Load4<1>(src, srcs[2]);
+      src += src_stride;
+      srcs[4] = Load4<0>(src, srcs[4]);
+      src += src_stride;
+      srcs[3] = vext_u8(srcs[2], srcs[4], 4);
+
+      const int16x8_t sums =
+          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+      if (is_compound) {
+        const uint16x8_t results = Compound1DShift(sums);
+
+        vst1q_u16(dst16, results);
+        dst16 += 4 << 1;
+      } else {
+        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+        StoreLo4(dst8, results);
+        dst8 += dst_stride;
+        StoreHi4(dst8, results);
+        dst8 += dst_stride;
+      }
+
+      srcs[0] = srcs[2];
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      y += 2;
+    } while (y < height);
+  } else if (num_taps == 6) {
+    srcs[6] = vdup_n_u8(0);
+
+    srcs[0] = Load4(src);
+    src += src_stride;
+    srcs[0] = Load4<1>(src, srcs[0]);
+    src += src_stride;
+    srcs[2] = Load4(src);
+    src += src_stride;
+    srcs[1] = vext_u8(srcs[0], srcs[2], 4);
+    srcs[2] = Load4<1>(src, srcs[2]);
+    src += src_stride;
+    srcs[4] = Load4(src);
+    src += src_stride;
+    srcs[3] = vext_u8(srcs[2], srcs[4], 4);
+
+    int y = 0;
+    do {
+      srcs[4] = Load4<1>(src, srcs[4]);
+      src += src_stride;
+      srcs[6] = Load4<0>(src, srcs[6]);
+      src += src_stride;
+      srcs[5] = vext_u8(srcs[4], srcs[6], 4);
+
+      const int16x8_t sums =
+          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+      if (is_compound) {
+        const uint16x8_t results = Compound1DShift(sums);
+
+        vst1q_u16(dst16, results);
+        dst16 += 4 << 1;
+      } else {
+        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+        StoreLo4(dst8, results);
+        dst8 += dst_stride;
+        StoreHi4(dst8, results);
+        dst8 += dst_stride;
+      }
+
+      srcs[0] = srcs[2];
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      srcs[3] = srcs[5];
+      srcs[4] = srcs[6];
+      y += 2;
+    } while (y < height);
+  } else if (num_taps == 8) {
+    srcs[8] = vdup_n_u8(0);
+
+    srcs[0] = Load4(src);
+    src += src_stride;
+    srcs[0] = Load4<1>(src, srcs[0]);
+    src += src_stride;
+    srcs[2] = Load4(src);
+    src += src_stride;
+    srcs[1] = vext_u8(srcs[0], srcs[2], 4);
+    srcs[2] = Load4<1>(src, srcs[2]);
+    src += src_stride;
+    srcs[4] = Load4(src);
+    src += src_stride;
+    srcs[3] = vext_u8(srcs[2], srcs[4], 4);
+    srcs[4] = Load4<1>(src, srcs[4]);
+    src += src_stride;
+    srcs[6] = Load4(src);
+    src += src_stride;
+    srcs[5] = vext_u8(srcs[4], srcs[6], 4);
+
+    int y = 0;
+    do {
+      srcs[6] = Load4<1>(src, srcs[6]);
+      src += src_stride;
+      srcs[8] = Load4<0>(src, srcs[8]);
+      src += src_stride;
+      srcs[7] = vext_u8(srcs[6], srcs[8], 4);
+
+      const int16x8_t sums =
+          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+      if (is_compound) {
+        const uint16x8_t results = Compound1DShift(sums);
+
+        vst1q_u16(dst16, results);
+        dst16 += 4 << 1;
+      } else {
+        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+        StoreLo4(dst8, results);
+        dst8 += dst_stride;
+        StoreHi4(dst8, results);
+        dst8 += dst_stride;
+      }
+
+      srcs[0] = srcs[2];
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      srcs[3] = srcs[5];
+      srcs[4] = srcs[6];
+      srcs[5] = srcs[7];
+      srcs[6] = srcs[8];
+      y += 2;
+    } while (y < height);
+  }
+}
+
+template <int filter_index, bool negative_outside_taps = false>
+void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride,
+                       void* const dst, const ptrdiff_t dst_stride,
+                       const int height, const uint8x8_t* const taps) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  auto* dst8 = static_cast<uint8_t*>(dst);
+
+  uint8x8_t srcs[9];
+
+  if (num_taps == 2) {
+    srcs[2] = vdup_n_u8(0);
+
+    srcs[0] = Load2(src);
+    src += src_stride;
+
+    int y = 0;
+    do {
+      srcs[0] = Load2<1>(src, srcs[0]);
+      src += src_stride;
+      srcs[0] = Load2<2>(src, srcs[0]);
+      src += src_stride;
+      srcs[0] = Load2<3>(src, srcs[0]);
+      src += src_stride;
+      srcs[2] = Load2<0>(src, srcs[2]);
+      src += src_stride;
+      srcs[1] = vext_u8(srcs[0], srcs[2], 2);
+
+      // This uses srcs[0]..srcs[1].
+      const int16x8_t sums =
+          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+      Store2<0>(dst8, results);
+      dst8 += dst_stride;
+      Store2<1>(dst8, results);
+      if (height == 2) return;
+      dst8 += dst_stride;
+      Store2<2>(dst8, results);
+      dst8 += dst_stride;
+      Store2<3>(dst8, results);
+      dst8 += dst_stride;
+
+      srcs[0] = srcs[2];
+      y += 4;
+    } while (y < height);
+  } else if (num_taps == 4) {
+    srcs[4] = vdup_n_u8(0);
+
+    srcs[0] = Load2(src);
+    src += src_stride;
+    srcs[0] = Load2<1>(src, srcs[0]);
+    src += src_stride;
+    srcs[0] = Load2<2>(src, srcs[0]);
+    src += src_stride;
+
+    int y = 0;
+    do {
+      srcs[0] = Load2<3>(src, srcs[0]);
+      src += src_stride;
+      srcs[4] = Load2<0>(src, srcs[4]);
+      src += src_stride;
+      srcs[1] = vext_u8(srcs[0], srcs[4], 2);
+      srcs[4] = Load2<1>(src, srcs[4]);
+      src += src_stride;
+      srcs[2] = vext_u8(srcs[0], srcs[4], 4);
+      srcs[4] = Load2<2>(src, srcs[4]);
+      src += src_stride;
+      srcs[3] = vext_u8(srcs[0], srcs[4], 6);
+
+      // This uses srcs[0]..srcs[3].
+      const int16x8_t sums =
+          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+      Store2<0>(dst8, results);
+      dst8 += dst_stride;
+      Store2<1>(dst8, results);
+      if (height == 2) return;
+      dst8 += dst_stride;
+      Store2<2>(dst8, results);
+      dst8 += dst_stride;
+      Store2<3>(dst8, results);
+      dst8 += dst_stride;
+
+      srcs[0] = srcs[4];
+      y += 4;
+    } while (y < height);
+  } else if (num_taps == 6) {
+    // During the vertical pass the number of taps is restricted when
+    // |height| <= 4.
+    assert(height > 4);
+    srcs[8] = vdup_n_u8(0);
+
+    srcs[0] = Load2(src);
+    src += src_stride;
+    srcs[0] = Load2<1>(src, srcs[0]);
+    src += src_stride;
+    srcs[0] = Load2<2>(src, srcs[0]);
+    src += src_stride;
+    srcs[0] = Load2<3>(src, srcs[0]);
+    src += src_stride;
+    srcs[4] = Load2(src);
+    src += src_stride;
+    srcs[1] = vext_u8(srcs[0], srcs[4], 2);
+
+    int y = 0;
+    do {
+      srcs[4] = Load2<1>(src, srcs[4]);
+      src += src_stride;
+      srcs[2] = vext_u8(srcs[0], srcs[4], 4);
+      srcs[4] = Load2<2>(src, srcs[4]);
+      src += src_stride;
+      srcs[3] = vext_u8(srcs[0], srcs[4], 6);
+      srcs[4] = Load2<3>(src, srcs[4]);
+      src += src_stride;
+      srcs[8] = Load2<0>(src, srcs[8]);
+      src += src_stride;
+      srcs[5] = vext_u8(srcs[4], srcs[8], 2);
+
+      // This uses srcs[0]..srcs[5].
+      const int16x8_t sums =
+          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+      Store2<0>(dst8, results);
+      dst8 += dst_stride;
+      Store2<1>(dst8, results);
+      dst8 += dst_stride;
+      Store2<2>(dst8, results);
+      dst8 += dst_stride;
+      Store2<3>(dst8, results);
+      dst8 += dst_stride;
+
+      srcs[0] = srcs[4];
+      srcs[1] = srcs[5];
+      srcs[4] = srcs[8];
+      y += 4;
+    } while (y < height);
+  } else if (num_taps == 8) {
+    // During the vertical pass the number of taps is restricted when
+    // |height| <= 4.
+    assert(height > 4);
+    srcs[8] = vdup_n_u8(0);
+
+    srcs[0] = Load2(src);
+    src += src_stride;
+    srcs[0] = Load2<1>(src, srcs[0]);
+    src += src_stride;
+    srcs[0] = Load2<2>(src, srcs[0]);
+    src += src_stride;
+    srcs[0] = Load2<3>(src, srcs[0]);
+    src += src_stride;
+    srcs[4] = Load2(src);
+    src += src_stride;
+    srcs[1] = vext_u8(srcs[0], srcs[4], 2);
+    srcs[4] = Load2<1>(src, srcs[4]);
+    src += src_stride;
+    srcs[2] = vext_u8(srcs[0], srcs[4], 4);
+    srcs[4] = Load2<2>(src, srcs[4]);
+    src += src_stride;
+    srcs[3] = vext_u8(srcs[0], srcs[4], 6);
+
+    int y = 0;
+    do {
+      srcs[4] = Load2<3>(src, srcs[4]);
+      src += src_stride;
+      srcs[8] = Load2<0>(src, srcs[8]);
+      src += src_stride;
+      srcs[5] = vext_u8(srcs[4], srcs[8], 2);
+      srcs[8] = Load2<1>(src, srcs[8]);
+      src += src_stride;
+      srcs[6] = vext_u8(srcs[4], srcs[8], 4);
+      srcs[8] = Load2<2>(src, srcs[8]);
+      src += src_stride;
+      srcs[7] = vext_u8(srcs[4], srcs[8], 6);
+
+      // This uses srcs[0]..srcs[7].
+      const int16x8_t sums =
+          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
+      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
+
+      Store2<0>(dst8, results);
+      dst8 += dst_stride;
+      Store2<1>(dst8, results);
+      dst8 += dst_stride;
+      Store2<2>(dst8, results);
+      dst8 += dst_stride;
+      Store2<3>(dst8, results);
+      dst8 += dst_stride;
+
+      srcs[0] = srcs[4];
+      srcs[1] = srcs[5];
+      srcs[2] = srcs[6];
+      srcs[3] = srcs[7];
+      srcs[4] = srcs[8];
+      y += 4;
+    } while (y < height);
+  }
+}
+
 // This function is a simplified version of Convolve2D_C.
 // It is called when it is single prediction mode, where only vertical
 // filtering is required.
@@ -1553,107 +2213,129 @@
                            const ptrdiff_t reference_stride,
                            const int /*horizontal_filter_index*/,
                            const int vertical_filter_index,
-                           const int /*inter_round_bits_vertical*/,
                            const int /*subpixel_x*/, const int subpixel_y,
-                           const int /*step_x*/, const int /*step_y*/,
                            const int width, const int height, void* prediction,
                            const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(filter_index);
   const ptrdiff_t src_stride = reference_stride;
-  const auto* src =
-      static_cast<const uint8_t*>(reference) - kVerticalOffset * src_stride;
+  const auto* src = static_cast<const uint8_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride;
   auto* dest = static_cast<uint8_t*>(prediction);
   const ptrdiff_t dest_stride = pred_stride;
   const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
-  // First filter is always a copy.
-  if (filter_id == 0) {
-    // Move |src| down the actual values and not the start of the context.
-    src = static_cast<const uint8_t*>(reference);
-    int y = 0;
-    do {
-      memcpy(dest, src, width * sizeof(src[0]));
-      src += src_stride;
-      dest += dest_stride;
-    } while (++y < height);
-    return;
+  assert(filter_id != 0);
+
+  uint8x8_t taps[8];
+  for (int k = 0; k < kSubPixelTaps; ++k) {
+    taps[k] = vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][filter_id][k]);
   }
 
-  // Break up by # of taps
-  // |filter_index| taps  enum InterpolationFilter
-  //        0       6     kInterpolationFilterEightTap
-  //        1       6     kInterpolationFilterEightTapSmooth
-  //        2       8     kInterpolationFilterEightTapSharp
-  //        3       2     kInterpolationFilterBilinear
-  //        4       4     kInterpolationFilterSwitchable
-  //        5       4     !!! SECRET FILTER !!! only for Wx4.
-  if (width >= 4) {
-    if (filter_index == 2) {  // 8 tap.
-      const int16x8_t taps =
-          vld1q_s16(kSubPixelFilters[filter_index][filter_id]);
-      if (width == 4) {
-        FilterVertical<4, 8>(src, src_stride, dest, dest_stride, width, height,
-                             taps);
-      } else {
-        FilterVertical<8, 8>(src, src_stride, dest, dest_stride, width, height,
-                             taps);
-      }
-    } else if (filter_index < 2) {  // 6 tap.
-      const int16x8_t taps =
-          vld1q_s16(kSubPixelFilters[filter_index][filter_id]);
-      if (width == 4) {
-        FilterVertical<4, 6>(src, src_stride, dest, dest_stride, width, height,
-                             taps);
-      } else {
-        FilterVertical<8, 6>(src, src_stride, dest, dest_stride, width, height,
-                             taps);
-      }
-    } else if (filter_index > 3) {  // 4 tap.
-      // Store taps in vget_low_s16(taps).
-      const int16x8_t taps =
-          vld1q_s16(kSubPixelFilters[filter_index][filter_id] + 2);
-      if (width == 4) {
-        FilterVertical<4, 4>(src, src_stride, dest, dest_stride, width, height,
-                             taps);
-      } else {
-        FilterVertical<8, 4>(src, src_stride, dest, dest_stride, width, height,
-                             taps);
-      }
-    } else {  // 2 tap.
-      // Store taps in vget_low_s16(taps).
-      const int16x8_t taps =
-          vld1q_s16(kSubPixelFilters[filter_index][filter_id] + 2);
-      if (width == 4) {
-        FilterVertical<4, 2>(src, src_stride, dest, dest_stride, width, height,
-                             taps);
-      } else {
-        FilterVertical<8, 2>(src, src_stride, dest, dest_stride, width, height,
-                             taps);
-      }
+  if (filter_index == 0) {  // 6 tap.
+    if (width == 2) {
+      FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else if (width == 4) {
+      FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else {
+      FilterVertical<0>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 1);
+    }
+  } else if ((filter_index == 1) &
+             ((filter_id == 1) | (filter_id == 15))) {  // 5 tap.
+    if (width == 2) {
+      FilterVertical2xH<1>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else if (width == 4) {
+      FilterVertical4xH<1>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else {
+      FilterVertical<1>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 1);
+    }
+  } else if ((filter_index == 1) &
+             ((filter_id == 7) | (filter_id == 8) |
+              (filter_id == 9))) {  // 6 tap with weird negative taps.
+    if (width == 2) {
+      FilterVertical2xH<1,
+                        /*negative_outside_taps=*/true>(
+          src, src_stride, dest, dest_stride, height, taps + 1);
+    } else if (width == 4) {
+      FilterVertical4xH<1, /*is_compound=*/false,
+                        /*negative_outside_taps=*/true>(
+          src, src_stride, dest, dest_stride, height, taps + 1);
+    } else {
+      FilterVertical<1, /*is_compound=*/false, /*negative_outside_taps=*/true>(
+          src, src_stride, dest, dest_stride, width, height, taps + 1);
+    }
+  } else if (filter_index == 2) {  // 8 tap.
+    if (width == 2) {
+      FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
+    } else if (width == 4) {
+      FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
+    } else {
+      FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
+                        taps);
+    }
+  } else if (filter_index == 3) {  // 2 tap.
+    if (width == 2) {
+      FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height,
+                           taps + 3);
+    } else if (width == 4) {
+      FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height,
+                           taps + 3);
+    } else {
+      FilterVertical<3>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 3);
+    }
+  } else if (filter_index == 4) {  // 4 tap.
+    // Outside taps are negative.
+    if (width == 2) {
+      FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height,
+                           taps + 2);
+    } else if (width == 4) {
+      FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height,
+                           taps + 2);
+    } else {
+      FilterVertical<4>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 2);
     }
   } else {
-    assert(width == 2);
-    const int taps = NumTapsInFilter(filter_index);
-    src =
-        static_cast<const uint8_t*>(reference) - ((taps / 2) - 1) * src_stride;
-    VerticalPass2xH</*is_2d=*/false>(src, src_stride, dest, pred_stride, height,
-                                     0, filter_index, taps, subpixel_y);
+    // 4 tap. When |filter_index| == 1 the |filter_id| values listed below map
+    // to 4 tap filters.
+    assert(filter_index == 5 ||
+           (filter_index == 1 &&
+            (filter_id == 2 || filter_id == 3 || filter_id == 4 ||
+             filter_id == 5 || filter_id == 6 || filter_id == 10 ||
+             filter_id == 11 || filter_id == 12 || filter_id == 13 ||
+             filter_id == 14)));
+    // According to GetNumTapsInFilter() this has 6 taps but here we are
+    // treating it as though it has 4.
+    if (filter_index == 1) src += src_stride;
+    if (width == 2) {
+      FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height,
+                           taps + 2);
+    } else if (width == 4) {
+      FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height,
+                           taps + 2);
+    } else {
+      FilterVertical<5>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 2);
+    }
   }
 }
 
 void ConvolveCompoundCopy_NEON(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
-    const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t pred_stride) {
+    const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
+    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
   const auto* src = static_cast<const uint8_t*>(reference);
   const ptrdiff_t src_stride = reference_stride;
   auto* dest = static_cast<uint16_t*>(prediction);
-  const int bitdepth = 8;
-  const int compound_round_offset =
-      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
-  const uint16x8_t v_compound_round_offset = vdupq_n_u16(compound_round_offset);
+  constexpr int final_shift =
+      kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
 
   if (width >= 16) {
     int y = 0;
@@ -1661,226 +2343,161 @@
       int x = 0;
       do {
         const uint8x16_t v_src = vld1q_u8(&src[x]);
-        const uint16x8_t v_src_x16_lo = vshll_n_u8(vget_low_u8(v_src), 4);
-        const uint16x8_t v_src_x16_hi = vshll_n_u8(vget_high_u8(v_src), 4);
         const uint16x8_t v_dest_lo =
-            vaddq_u16(v_src_x16_lo, v_compound_round_offset);
+            vshll_n_u8(vget_low_u8(v_src), final_shift);
         const uint16x8_t v_dest_hi =
-            vaddq_u16(v_src_x16_hi, v_compound_round_offset);
+            vshll_n_u8(vget_high_u8(v_src), final_shift);
         vst1q_u16(&dest[x], v_dest_lo);
         x += 8;
         vst1q_u16(&dest[x], v_dest_hi);
         x += 8;
       } while (x < width);
       src += src_stride;
-      dest += pred_stride;
+      dest += width;
     } while (++y < height);
   } else if (width == 8) {
     int y = 0;
     do {
       const uint8x8_t v_src = vld1_u8(&src[0]);
-      const uint16x8_t v_src_x16 = vshll_n_u8(v_src, 4);
-      vst1q_u16(&dest[0], vaddq_u16(v_src_x16, v_compound_round_offset));
+      const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift);
+      vst1q_u16(&dest[0], v_dest);
       src += src_stride;
-      dest += pred_stride;
+      dest += width;
     } while (++y < height);
-  } else if (width == 4) {
-    const uint8x8_t zero = vdup_n_u8(0);
+  } else { /* width == 4 */
+    uint8x8_t v_src = vdup_n_u8(0);
+
     int y = 0;
     do {
-      const uint8x8_t v_src = LoadLo4(&src[0], zero);
-      const uint16x8_t v_src_x16 = vshll_n_u8(v_src, 4);
-      const uint16x8_t v_dest = vaddq_u16(v_src_x16, v_compound_round_offset);
-      vst1_u16(&dest[0], vget_low_u16(v_dest));
+      v_src = Load4<0>(&src[0], v_src);
       src += src_stride;
-      dest += pred_stride;
-    } while (++y < height);
-  } else {  // width == 2
-    assert(width == 2);
-    int y = 0;
-    do {
-      dest[0] = (src[0] << 4) + compound_round_offset;
-      dest[1] = (src[1] << 4) + compound_round_offset;
+      v_src = Load4<1>(&src[0], v_src);
       src += src_stride;
-      dest += pred_stride;
-    } while (++y < height);
+      const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift);
+      vst1q_u16(&dest[0], v_dest);
+      dest += 4 << 1;
+      y += 2;
+    } while (y < height);
   }
 }
 
-// Input 8 bits and output 16 bits.
-template <int min_width, int num_taps>
-void FilterCompoundVertical(const uint8_t* src, const ptrdiff_t src_stride,
-                            uint16_t* dst, const ptrdiff_t dst_stride,
-                            const int width, const int height,
-                            const int16x8_t taps) {
-  constexpr int next_row = num_taps - 1;
-  // |src| points to the outermost tap of the first value. When doing fewer than
-  // 8 taps it needs to be adjusted.
-  if (num_taps == 6) {
-    src += src_stride;
-  } else if (num_taps == 4) {
-    src += 2 * src_stride;
-  } else if (num_taps == 2) {
-    src += 3 * src_stride;
-  }
-
-  const uint16x8_t compound_round_offset = vdupq_n_u16(1 << 12);
-
-  int x = 0;
-  do {
-    int16x8_t srcs[8];
-    srcs[0] = ZeroExtend(vld1_u8(src + x));
-    if (num_taps >= 4) {
-      srcs[1] = ZeroExtend(vld1_u8(src + x + src_stride));
-      srcs[2] = ZeroExtend(vld1_u8(src + x + 2 * src_stride));
-      if (num_taps >= 6) {
-        srcs[3] = ZeroExtend(vld1_u8(src + x + 3 * src_stride));
-        srcs[4] = ZeroExtend(vld1_u8(src + x + 4 * src_stride));
-        if (num_taps == 8) {
-          srcs[5] = ZeroExtend(vld1_u8(src + x + 5 * src_stride));
-          srcs[6] = ZeroExtend(vld1_u8(src + x + 6 * src_stride));
-        }
-      }
-    }
-
-    int y = 0;
-    do {
-      srcs[next_row] =
-          ZeroExtend(vld1_u8(src + x + (y + next_row) * src_stride));
-
-      const uint16x8_t sums = SumTaps8To16<num_taps>(srcs, taps);
-      const uint16x8_t shifted = vrshrq_n_u16(sums, 3);
-      // In order to keep the sum in 16 bits we add an offset to the sum
-      // (1 << (bitdepth + kFilterBits - 1) == 1 << 14). This ensures that the
-      // results will never be negative.
-      // Normally ConvolveCompoundVertical would add |compound_round_offset| at
-      // the end. Instead we use that to compensate for the initial offset.
-      // (1 << (bitdepth + 4)) + (1 << (bitdepth + 3)) == (1 << 12) + (1 << 11)
-      // After taking into account the shift above:
-      // RightShiftWithRounding(LeftShift(sum, bits_shift),
-      //                        inter_round_bits_vertical)
-      // where bits_shift == kFilterBits - kInterRoundBitsHorizontal == 4
-      // and inter_round_bits_vertical == 7
-      // and simplifying it to RightShiftWithRounding(sum, 3)
-      // we see that the initial offset of 1 << 14 >> 3 == 1 << 11 and
-      // |compound_round_offset| can be simplified to 1 << 12.
-      const uint16x8_t offset = vaddq_u16(shifted, compound_round_offset);
-
-      if (min_width == 4) {
-        vst1_u16(dst + x + y * dst_stride, vget_low_u16(offset));
-      } else {
-        vst1q_u16(dst + x + y * dst_stride, offset);
-      }
-
-      srcs[0] = srcs[1];
-      if (num_taps >= 4) {
-        srcs[1] = srcs[2];
-        srcs[2] = srcs[3];
-        if (num_taps >= 6) {
-          srcs[3] = srcs[4];
-          srcs[4] = srcs[5];
-          if (num_taps == 8) {
-            srcs[5] = srcs[6];
-            srcs[6] = srcs[7];
-          }
-        }
-      }
-    } while (++y < height);
-    x += 8;
-  } while (x < width);
-}
-
 void ConvolveCompoundVertical_NEON(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int vertical_filter_index,
-    const int /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
-    const int subpixel_y, const int /*step_x*/, const int /*step_y*/,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t pred_stride) {
+    const int /*subpixel_x*/, const int subpixel_y, const int width,
+    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(filter_index);
   const ptrdiff_t src_stride = reference_stride;
-  const auto* src =
-      static_cast<const uint8_t*>(reference) - kVerticalOffset * src_stride;
+  const auto* src = static_cast<const uint8_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride;
   auto* dest = static_cast<uint16_t*>(prediction);
   const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
+  assert(filter_id != 0);
 
-  if (width >= 4) {
-    const int16x8_t taps = vld1q_s16(kSubPixelFilters[filter_index][filter_id]);
+  uint8x8_t taps[8];
+  for (int k = 0; k < kSubPixelTaps; ++k) {
+    taps[k] = vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][filter_id][k]);
+  }
 
-    if (filter_index == 2) {  // 8 tap.
-      if (width == 4) {
-        FilterCompoundVertical<4, 8>(src, src_stride, dest, pred_stride, width,
-                                     height, taps);
-      } else {
-        FilterCompoundVertical<8, 8>(src, src_stride, dest, pred_stride, width,
-                                     height, taps);
-      }
-    } else if (filter_index < 2) {  // 6 tap.
-      if (width == 4) {
-        FilterCompoundVertical<4, 6>(src, src_stride, dest, pred_stride, width,
-                                     height, taps);
-      } else {
-        FilterCompoundVertical<8, 6>(src, src_stride, dest, pred_stride, width,
-                                     height, taps);
-      }
-    } else if (filter_index == 3) {  // 2 tap.
-      if (width == 4) {
-        FilterCompoundVertical<4, 2>(src, src_stride, dest, pred_stride, width,
-                                     height, taps);
-      } else {
-        FilterCompoundVertical<8, 2>(src, src_stride, dest, pred_stride, width,
-                                     height, taps);
-      }
-    } else if (filter_index > 3) {  // 4 tap.
-      if (width == 4) {
-        FilterCompoundVertical<4, 4>(src, src_stride, dest, pred_stride, width,
-                                     height, taps);
-      } else {
-        FilterCompoundVertical<8, 4>(src, src_stride, dest, pred_stride, width,
-                                     height, taps);
-      }
+  if (filter_index == 0) {  // 6 tap.
+    if (width == 4) {
+      FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 1);
+    } else {
+      FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 1);
+    }
+  } else if ((filter_index == 1) &
+             ((filter_id == 1) | (filter_id == 15))) {  // 5 tap.
+    if (width == 4) {
+      FilterVertical4xH<1, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 1);
+    } else {
+      FilterVertical<1, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 1);
+    }
+  } else if ((filter_index == 1) &
+             ((filter_id == 7) | (filter_id == 8) |
+              (filter_id == 9))) {  // 6 tap with weird negative taps.
+    if (width == 4) {
+      FilterVertical4xH<1, /*is_compound=*/true,
+                        /*negative_outside_taps=*/true>(src, src_stride, dest,
+                                                        4, height, taps + 1);
+    } else {
+      FilterVertical<1, /*is_compound=*/true, /*negative_outside_taps=*/true>(
+          src, src_stride, dest, width, width, height, taps + 1);
+    }
+  } else if (filter_index == 2) {  // 8 tap.
+    if (width == 4) {
+      FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps);
+    } else {
+      FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps);
+    }
+  } else if (filter_index == 3) {  // 2 tap.
+    if (width == 4) {
+      FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 3);
+    } else {
+      FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 3);
+    }
+  } else if (filter_index == 4) {  // 4 tap.
+    if (width == 4) {
+      FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 2);
+    } else {
+      FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 2);
     }
   } else {
-    assert(width == 2);
-    const int taps = NumTapsInFilter(filter_index);
-    src =
-        static_cast<const uint8_t*>(reference) - ((taps / 2) - 1) * src_stride;
-    VerticalPass2xH</*is_2d=*/false, /*is_compound=*/true>(
-        src, src_stride, dest, pred_stride, height, 0, filter_index, taps,
-        subpixel_y);
+    // 4 tap. When |filter_index| == 1 the |filter_id| values listed below map
+    // to 4 tap filters.
+    assert(filter_index == 5 ||
+           (filter_index == 1 &&
+            (filter_id == 2 || filter_id == 3 || filter_id == 4 ||
+             filter_id == 5 || filter_id == 6 || filter_id == 10 ||
+             filter_id == 11 || filter_id == 12 || filter_id == 13 ||
+             filter_id == 14)));
+    // According to GetNumTapsInFilter() this has 6 taps but here we are
+    // treating it as though it has 4.
+    if (filter_index == 1) src += src_stride;
+    if (width == 4) {
+      FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 2);
+    } else {
+      FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 2);
+    }
   }
 }
 
 void ConvolveCompoundHorizontal_NEON(
     const void* const reference, const ptrdiff_t reference_stride,
     const int horizontal_filter_index, const int /*vertical_filter_index*/,
-    const int /*inter_round_bits_vertical*/, const int subpixel_x,
-    const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t pred_stride) {
+    const int subpixel_x, const int /*subpixel_y*/, const int width,
+    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
   auto* dest = static_cast<uint16_t*>(prediction);
 
-  HorizontalPass(src, reference_stride, dest, pred_stride, width, height,
-                 subpixel_x, filter_index);
+  DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>(
+      src, reference_stride, dest, width, width, height, subpixel_x,
+      filter_index);
 }
 
-void ConvolveCompound2D_NEON(const void* const reference,
-                             const ptrdiff_t reference_stride,
-                             const int horizontal_filter_index,
-                             const int vertical_filter_index,
-                             const int inter_round_bits_vertical,
-                             const int subpixel_x, const int subpixel_y,
-                             const int /*step_x*/, const int /*step_y*/,
-                             const int width, const int height,
-                             void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveCompound2D_NEON(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int horizontal_filter_index, const int vertical_filter_index,
+    const int subpixel_x, const int subpixel_y, const int width,
+    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
   uint16_t
       intermediate_result[kMaxSuperBlockSizeInPixels *
                           (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
-  const int intermediate_stride = kMaxSuperBlockSizeInPixels;
 
   // Horizontal filter.
   // Filter types used for width <= 4 are different from those for width > 4.
@@ -1889,66 +2506,586 @@
   // Similarly for height.
   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
-  const int horizontal_taps = NumTapsInFilter(horiz_filter_index);
-  const int vertical_taps = NumTapsInFilter(vert_filter_index);
-  uint16_t* intermediate = intermediate_result;
+  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
   const int intermediate_height = height + vertical_taps - 1;
   const ptrdiff_t src_stride = reference_stride;
-  const auto* src = static_cast<const uint8_t*>(reference) -
-                    kVerticalOffset * src_stride - kHorizontalOffset;
+  const auto* const src = static_cast<const uint8_t*>(reference) -
+                          (vertical_taps / 2 - 1) * src_stride -
+                          kHorizontalOffset;
+
+  DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>(
+      src, src_stride, intermediate_result, width, width, intermediate_height,
+      subpixel_x, horiz_filter_index);
+
+  // Vertical filter.
   auto* dest = static_cast<uint16_t*>(prediction);
-  int filter_id = (subpixel_x >> 6) & kSubPixelMask;
+  const int filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
+  assert(filter_id != 0);
 
-  if (width >= 4) {
-    // TODO(johannkoenig): Use |width| for |intermediate_stride|.
-    src = static_cast<const uint8_t*>(reference) -
-          (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
-    HorizontalPass<true>(src, src_stride, intermediate_result,
-                         intermediate_stride, width, intermediate_height,
-                         subpixel_x, horiz_filter_index);
+  const ptrdiff_t dest_stride = width;
+  const int16x8_t taps =
+      vmovl_s8(vld1_s8(kHalfSubPixelFilters[vert_filter_index][filter_id]));
 
-    // Vertical filter.
-    intermediate = intermediate_result;
-    filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
-
-    const ptrdiff_t dest_stride = pred_stride;
-    const int16x8_t taps =
-        vld1q_s16(kSubPixelFilters[vert_filter_index][filter_id]);
-
-    if (vertical_taps == 8) {
+  if (vertical_taps == 8) {
+    if (width == 4) {
+      Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest,
+                                                   dest_stride, height, taps);
+    } else {
       Filter2DVertical<8, /*is_compound=*/true>(
-          intermediate, intermediate_stride, dest, dest_stride, width, height,
-          taps, inter_round_bits_vertical);
-    } else if (vertical_taps == 6) {
-      Filter2DVertical<6, /*is_compound=*/true>(
-          intermediate, intermediate_stride, dest, dest_stride, width, height,
-          taps, inter_round_bits_vertical);
-    } else if (vertical_taps == 4) {
-      Filter2DVertical<4, /*is_compound=*/true>(
-          intermediate, intermediate_stride, dest, dest_stride, width, height,
-          taps, inter_round_bits_vertical);
-    } else {  // |vertical_taps| == 2
-      Filter2DVertical<2, /*is_compound=*/true>(
-          intermediate, intermediate_stride, dest, dest_stride, width, height,
-          taps, inter_round_bits_vertical);
+          intermediate_result, dest, dest_stride, width, height, taps);
     }
+  } else if (vertical_taps == 6) {
+    if (width == 4) {
+      Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest,
+                                                   dest_stride, height, taps);
+    } else {
+      Filter2DVertical<6, /*is_compound=*/true>(
+          intermediate_result, dest, dest_stride, width, height, taps);
+    }
+  } else if (vertical_taps == 4) {
+    if (width == 4) {
+      Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest,
+                                                   dest_stride, height, taps);
+    } else {
+      Filter2DVertical<4, /*is_compound=*/true>(
+          intermediate_result, dest, dest_stride, width, height, taps);
+    }
+  } else {  // |vertical_taps| == 2
+    if (width == 4) {
+      Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest,
+                                                   dest_stride, height, taps);
+    } else {
+      Filter2DVertical<2, /*is_compound=*/true>(
+          intermediate_result, dest, dest_stride, width, height, taps);
+    }
+  }
+}
+
+inline void HalfAddHorizontal(const uint8_t* src, uint8_t* dst) {
+  const uint8x16_t left = vld1q_u8(src);
+  const uint8x16_t right = vld1q_u8(src + 1);
+  vst1q_u8(dst, vrhaddq_u8(left, right));
+}
+
+template <int width>
+inline void IntraBlockCopyHorizontal(const uint8_t* src,
+                                     const ptrdiff_t src_stride,
+                                     const int height, uint8_t* dst,
+                                     const ptrdiff_t dst_stride) {
+  const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
+  const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
+
+  int y = 0;
+  do {
+    HalfAddHorizontal(src, dst);
+    if (width >= 32) {
+      src += 16;
+      dst += 16;
+      HalfAddHorizontal(src, dst);
+      if (width >= 64) {
+        src += 16;
+        dst += 16;
+        HalfAddHorizontal(src, dst);
+        src += 16;
+        dst += 16;
+        HalfAddHorizontal(src, dst);
+        if (width == 128) {
+          src += 16;
+          dst += 16;
+          HalfAddHorizontal(src, dst);
+          src += 16;
+          dst += 16;
+          HalfAddHorizontal(src, dst);
+          src += 16;
+          dst += 16;
+          HalfAddHorizontal(src, dst);
+          src += 16;
+          dst += 16;
+          HalfAddHorizontal(src, dst);
+        }
+      }
+    }
+    src += src_remainder_stride;
+    dst += dst_remainder_stride;
+  } while (++y < height);
+}
+
+void ConvolveIntraBlockCopyHorizontal_NEON(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+    const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
+    const int height, void* const prediction, const ptrdiff_t pred_stride) {
+  const auto* src = static_cast<const uint8_t*>(reference);
+  auto* dest = static_cast<uint8_t*>(prediction);
+
+  if (width == 128) {
+    IntraBlockCopyHorizontal<128>(src, reference_stride, height, dest,
+                                  pred_stride);
+  } else if (width == 64) {
+    IntraBlockCopyHorizontal<64>(src, reference_stride, height, dest,
+                                 pred_stride);
+  } else if (width == 32) {
+    IntraBlockCopyHorizontal<32>(src, reference_stride, height, dest,
+                                 pred_stride);
+  } else if (width == 16) {
+    IntraBlockCopyHorizontal<16>(src, reference_stride, height, dest,
+                                 pred_stride);
+  } else if (width == 8) {
+    int y = 0;
+    do {
+      const uint8x8_t left = vld1_u8(src);
+      const uint8x8_t right = vld1_u8(src + 1);
+      vst1_u8(dest, vrhadd_u8(left, right));
+
+      src += reference_stride;
+      dest += pred_stride;
+    } while (++y < height);
+  } else if (width == 4) {
+    uint8x8_t left = vdup_n_u8(0);
+    uint8x8_t right = vdup_n_u8(0);
+    int y = 0;
+    do {
+      left = Load4<0>(src, left);
+      right = Load4<0>(src + 1, right);
+      src += reference_stride;
+      left = Load4<1>(src, left);
+      right = Load4<1>(src + 1, right);
+      src += reference_stride;
+
+      const uint8x8_t result = vrhadd_u8(left, right);
+
+      StoreLo4(dest, result);
+      dest += pred_stride;
+      StoreHi4(dest, result);
+      dest += pred_stride;
+      y += 2;
+    } while (y < height);
   } else {
-    src = static_cast<const uint8_t*>(reference) -
-          ((vertical_taps / 2) - 1) * src_stride - ((horizontal_taps / 2) - 1);
+    assert(width == 2);
+    uint8x8_t left = vdup_n_u8(0);
+    uint8x8_t right = vdup_n_u8(0);
+    int y = 0;
+    do {
+      left = Load2<0>(src, left);
+      right = Load2<0>(src + 1, right);
+      src += reference_stride;
+      left = Load2<1>(src, left);
+      right = Load2<1>(src + 1, right);
+      src += reference_stride;
 
-    HorizontalPass2xH(src, src_stride, intermediate_result, intermediate_stride,
-                      intermediate_height, horiz_filter_index, horizontal_taps,
-                      subpixel_x);
+      const uint8x8_t result = vrhadd_u8(left, right);
 
-    VerticalPass2xH</*is_2d=*/true, /*is_compound=*/true>(
-        intermediate_result, intermediate_stride, dest, pred_stride, height,
-        inter_round_bits_vertical, vert_filter_index, vertical_taps,
-        subpixel_y);
+      Store2<0>(dest, result);
+      dest += pred_stride;
+      Store2<1>(dest, result);
+      dest += pred_stride;
+      y += 2;
+    } while (y < height);
+  }
+}
+
+template <int width>
+inline void IntraBlockCopyVertical(const uint8_t* src,
+                                   const ptrdiff_t src_stride, const int height,
+                                   uint8_t* dst, const ptrdiff_t dst_stride) {
+  const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
+  const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
+  uint8x16_t row[8], below[8];
+
+  row[0] = vld1q_u8(src);
+  if (width >= 32) {
+    src += 16;
+    row[1] = vld1q_u8(src);
+    if (width >= 64) {
+      src += 16;
+      row[2] = vld1q_u8(src);
+      src += 16;
+      row[3] = vld1q_u8(src);
+      if (width == 128) {
+        src += 16;
+        row[4] = vld1q_u8(src);
+        src += 16;
+        row[5] = vld1q_u8(src);
+        src += 16;
+        row[6] = vld1q_u8(src);
+        src += 16;
+        row[7] = vld1q_u8(src);
+      }
+    }
+  }
+  src += src_remainder_stride;
+
+  int y = 0;
+  do {
+    below[0] = vld1q_u8(src);
+    if (width >= 32) {
+      src += 16;
+      below[1] = vld1q_u8(src);
+      if (width >= 64) {
+        src += 16;
+        below[2] = vld1q_u8(src);
+        src += 16;
+        below[3] = vld1q_u8(src);
+        if (width == 128) {
+          src += 16;
+          below[4] = vld1q_u8(src);
+          src += 16;
+          below[5] = vld1q_u8(src);
+          src += 16;
+          below[6] = vld1q_u8(src);
+          src += 16;
+          below[7] = vld1q_u8(src);
+        }
+      }
+    }
+    src += src_remainder_stride;
+
+    vst1q_u8(dst, vrhaddq_u8(row[0], below[0]));
+    row[0] = below[0];
+    if (width >= 32) {
+      dst += 16;
+      vst1q_u8(dst, vrhaddq_u8(row[1], below[1]));
+      row[1] = below[1];
+      if (width >= 64) {
+        dst += 16;
+        vst1q_u8(dst, vrhaddq_u8(row[2], below[2]));
+        row[2] = below[2];
+        dst += 16;
+        vst1q_u8(dst, vrhaddq_u8(row[3], below[3]));
+        row[3] = below[3];
+        if (width >= 128) {
+          dst += 16;
+          vst1q_u8(dst, vrhaddq_u8(row[4], below[4]));
+          row[4] = below[4];
+          dst += 16;
+          vst1q_u8(dst, vrhaddq_u8(row[5], below[5]));
+          row[5] = below[5];
+          dst += 16;
+          vst1q_u8(dst, vrhaddq_u8(row[6], below[6]));
+          row[6] = below[6];
+          dst += 16;
+          vst1q_u8(dst, vrhaddq_u8(row[7], below[7]));
+          row[7] = below[7];
+        }
+      }
+    }
+    dst += dst_remainder_stride;
+  } while (++y < height);
+}
+
+void ConvolveIntraBlockCopyVertical_NEON(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+    const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
+    const int height, void* const prediction, const ptrdiff_t pred_stride) {
+  const auto* src = static_cast<const uint8_t*>(reference);
+  auto* dest = static_cast<uint8_t*>(prediction);
+
+  if (width == 128) {
+    IntraBlockCopyVertical<128>(src, reference_stride, height, dest,
+                                pred_stride);
+  } else if (width == 64) {
+    IntraBlockCopyVertical<64>(src, reference_stride, height, dest,
+                               pred_stride);
+  } else if (width == 32) {
+    IntraBlockCopyVertical<32>(src, reference_stride, height, dest,
+                               pred_stride);
+  } else if (width == 16) {
+    IntraBlockCopyVertical<16>(src, reference_stride, height, dest,
+                               pred_stride);
+  } else if (width == 8) {
+    uint8x8_t row, below;
+    row = vld1_u8(src);
+    src += reference_stride;
+
+    int y = 0;
+    do {
+      below = vld1_u8(src);
+      src += reference_stride;
+
+      vst1_u8(dest, vrhadd_u8(row, below));
+      dest += pred_stride;
+
+      row = below;
+    } while (++y < height);
+  } else if (width == 4) {
+    uint8x8_t row = Load4(src);
+    uint8x8_t below = vdup_n_u8(0);
+    src += reference_stride;
+
+    int y = 0;
+    do {
+      below = Load4<0>(src, below);
+      src += reference_stride;
+
+      StoreLo4(dest, vrhadd_u8(row, below));
+      dest += pred_stride;
+
+      row = below;
+    } while (++y < height);
+  } else {
+    assert(width == 2);
+    uint8x8_t row = Load2(src);
+    uint8x8_t below = vdup_n_u8(0);
+    src += reference_stride;
+
+    int y = 0;
+    do {
+      below = Load2<0>(src, below);
+      src += reference_stride;
+
+      Store2<0>(dest, vrhadd_u8(row, below));
+      dest += pred_stride;
+
+      row = below;
+    } while (++y < height);
+  }
+}
+
+template <int width>
+inline void IntraBlockCopy2D(const uint8_t* src, const ptrdiff_t src_stride,
+                             const int height, uint8_t* dst,
+                             const ptrdiff_t dst_stride) {
+  const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
+  const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
+  uint16x8_t row[16];
+  row[0] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+  if (width >= 16) {
+    src += 8;
+    row[1] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+    if (width >= 32) {
+      src += 8;
+      row[2] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+      src += 8;
+      row[3] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+      if (width >= 64) {
+        src += 8;
+        row[4] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+        src += 8;
+        row[5] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+        src += 8;
+        row[6] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+        src += 8;
+        row[7] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+        if (width == 128) {
+          src += 8;
+          row[8] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          src += 8;
+          row[9] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          src += 8;
+          row[10] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          src += 8;
+          row[11] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          src += 8;
+          row[12] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          src += 8;
+          row[13] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          src += 8;
+          row[14] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          src += 8;
+          row[15] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+        }
+      }
+    }
+  }
+  src += src_remainder_stride;
+
+  int y = 0;
+  do {
+    const uint16x8_t below_0 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+    vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[0], below_0), 2));
+    row[0] = below_0;
+    if (width >= 16) {
+      src += 8;
+      dst += 8;
+
+      const uint16x8_t below_1 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+      vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[1], below_1), 2));
+      row[1] = below_1;
+      if (width >= 32) {
+        src += 8;
+        dst += 8;
+
+        const uint16x8_t below_2 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+        vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[2], below_2), 2));
+        row[2] = below_2;
+        src += 8;
+        dst += 8;
+
+        const uint16x8_t below_3 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+        vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[3], below_3), 2));
+        row[3] = below_3;
+        if (width >= 64) {
+          src += 8;
+          dst += 8;
+
+          const uint16x8_t below_4 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[4], below_4), 2));
+          row[4] = below_4;
+          src += 8;
+          dst += 8;
+
+          const uint16x8_t below_5 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[5], below_5), 2));
+          row[5] = below_5;
+          src += 8;
+          dst += 8;
+
+          const uint16x8_t below_6 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[6], below_6), 2));
+          row[6] = below_6;
+          src += 8;
+          dst += 8;
+
+          const uint16x8_t below_7 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+          vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[7], below_7), 2));
+          row[7] = below_7;
+          if (width == 128) {
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_8 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[8], below_8), 2));
+            row[8] = below_8;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_9 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[9], below_9), 2));
+            row[9] = below_9;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_10 =
+                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[10], below_10), 2));
+            row[10] = below_10;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_11 =
+                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[11], below_11), 2));
+            row[11] = below_11;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_12 =
+                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[12], below_12), 2));
+            row[12] = below_12;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_13 =
+                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[13], below_13), 2));
+            row[13] = below_13;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_14 =
+                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[14], below_14), 2));
+            row[14] = below_14;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_15 =
+                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
+            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[15], below_15), 2));
+            row[15] = below_15;
+          }
+        }
+      }
+    }
+    src += src_remainder_stride;
+    dst += dst_remainder_stride;
+  } while (++y < height);
+}
+
+void ConvolveIntraBlockCopy2D_NEON(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+    const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
+    const int height, void* const prediction, const ptrdiff_t pred_stride) {
+  const auto* src = static_cast<const uint8_t*>(reference);
+  auto* dest = static_cast<uint8_t*>(prediction);
+  // Note: allow vertical access to height + 1. Because this function is only
+  // for u/v plane of intra block copy, such access is guaranteed to be within
+  // the prediction block.
+
+  if (width == 128) {
+    IntraBlockCopy2D<128>(src, reference_stride, height, dest, pred_stride);
+  } else if (width == 64) {
+    IntraBlockCopy2D<64>(src, reference_stride, height, dest, pred_stride);
+  } else if (width == 32) {
+    IntraBlockCopy2D<32>(src, reference_stride, height, dest, pred_stride);
+  } else if (width == 16) {
+    IntraBlockCopy2D<16>(src, reference_stride, height, dest, pred_stride);
+  } else if (width == 8) {
+    IntraBlockCopy2D<8>(src, reference_stride, height, dest, pred_stride);
+  } else if (width == 4) {
+    uint8x8_t left = Load4(src);
+    uint8x8_t right = Load4(src + 1);
+    src += reference_stride;
+
+    uint16x4_t row = vget_low_u16(vaddl_u8(left, right));
+
+    int y = 0;
+    do {
+      left = Load4<0>(src, left);
+      right = Load4<0>(src + 1, right);
+      src += reference_stride;
+      left = Load4<1>(src, left);
+      right = Load4<1>(src + 1, right);
+      src += reference_stride;
+
+      const uint16x8_t below = vaddl_u8(left, right);
+
+      const uint8x8_t result = vrshrn_n_u16(
+          vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2);
+      StoreLo4(dest, result);
+      dest += pred_stride;
+      StoreHi4(dest, result);
+      dest += pred_stride;
+
+      row = vget_high_u16(below);
+      y += 2;
+    } while (y < height);
+  } else {
+    uint8x8_t left = Load2(src);
+    uint8x8_t right = Load2(src + 1);
+    src += reference_stride;
+
+    uint16x4_t row = vget_low_u16(vaddl_u8(left, right));
+
+    int y = 0;
+    do {
+      left = Load2<0>(src, left);
+      right = Load2<0>(src + 1, right);
+      src += reference_stride;
+      left = Load2<2>(src, left);
+      right = Load2<2>(src + 1, right);
+      src += reference_stride;
+
+      const uint16x8_t below = vaddl_u8(left, right);
+
+      const uint8x8_t result = vrshrn_n_u16(
+          vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2);
+      Store2<0>(dest, result);
+      dest += pred_stride;
+      Store2<2>(dest, result);
+      dest += pred_stride;
+
+      row = vget_high_u16(below);
+      y += 2;
+    } while (y < height);
   }
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON;
   dsp->convolve[0][0][1][0] = ConvolveVertical_NEON;
@@ -1959,9 +3096,12 @@
   dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON;
   dsp->convolve[0][1][1][1] = ConvolveCompound2D_NEON;
 
-  // TODO(petersonab,b/139707209): Fix source buffer overreads.
-  // dsp->convolve_scale[1] = ConvolveCompoundScale2D_NEON;
-  static_cast<void>(ConvolveCompoundScale2D_NEON);
+  dsp->convolve[1][0][0][1] = ConvolveIntraBlockCopyHorizontal_NEON;
+  dsp->convolve[1][0][1][0] = ConvolveIntraBlockCopyVertical_NEON;
+  dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_NEON;
+
+  dsp->convolve_scale[0] = ConvolveScale2D_NEON<false>;
+  dsp->convolve_scale[1] = ConvolveScale2D_NEON<true>;
 }
 
 }  // namespace
@@ -1972,7 +3112,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 
 namespace libgav1 {
 namespace dsp {
diff --git a/libgav1/src/dsp/arm/convolve_neon.h b/libgav1/src/dsp/arm/convolve_neon.h
index a537650..948ef4d 100644
--- a/libgav1/src/dsp/arm/convolve_neon.h
+++ b/libgav1/src/dsp/arm/convolve_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -30,17 +30,21 @@
 }  // namespace libgav1
 
 #if LIBGAV1_ENABLE_NEON
-#define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_ConvolveCompoundCopy LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_ConvolveCompound2D LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundCopy LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompound2D LIBGAV1_CPU_NEON
 
-// TODO(petersonab,b/139707209): Fix source buffer overreads.
-// #define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopyVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveIntraBlockCopy2D LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp8bpp_ConvolveScale2D LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
diff --git a/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc b/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
index 39b34a9..04952ab 100644
--- a/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
+++ b/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "src/dsp/distance_weighted_blend.h"
-#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -24,138 +24,93 @@
 #include <cstdint>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"
 
 namespace libgav1 {
 namespace dsp {
 namespace {
 
-constexpr int kBitdepth8 = 8;
 constexpr int kInterPostRoundBit = 4;
 
-const int16x8_t kCompoundRoundOffset =
-    vdupq_n_s16((1 << kBitdepth8) + (1 << (kBitdepth8 - 1)));
+inline int16x8_t ComputeWeightedAverage8(const int16x8_t pred0,
+                                         const int16x8_t pred1,
+                                         const int16x4_t weights[2]) {
+  // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
+  const int32x4_t wpred0_lo = vmull_s16(weights[0], vget_low_s16(pred0));
+  const int32x4_t wpred0_hi = vmull_s16(weights[0], vget_high_s16(pred0));
+  const int32x4_t blended_lo =
+      vmlal_s16(wpred0_lo, weights[1], vget_low_s16(pred1));
+  const int32x4_t blended_hi =
+      vmlal_s16(wpred0_hi, weights[1], vget_high_s16(pred1));
 
-inline int16x8_t ComputeWeightedAverage8(const uint16x8_t pred0,
-                                         const uint16x8_t pred1,
-                                         const uint16x4_t weights[2]) {
-  const uint32x4_t wpred0_lo = vmull_u16(weights[0], vget_low_u16(pred0));
-  const uint32x4_t wpred0_hi = vmull_u16(weights[0], vget_high_u16(pred0));
-  const uint32x4_t blended_lo =
-      vmlal_u16(wpred0_lo, weights[1], vget_low_u16(pred1));
-  const uint32x4_t blended_hi =
-      vmlal_u16(wpred0_hi, weights[1], vget_high_u16(pred1));
-
-  const uint16x4_t result_lo =
-      vqrshrn_n_u32(blended_lo, kInterPostRoundBit + 4);
-  const uint16x4_t result_hi =
-      vqrshrn_n_u32(blended_hi, kInterPostRoundBit + 4);
-  return vsubq_s16(vreinterpretq_s16_u16(vcombine_u16(result_lo, result_hi)),
-                   kCompoundRoundOffset);
+  return vcombine_s16(vqrshrn_n_s32(blended_lo, kInterPostRoundBit + 4),
+                      vqrshrn_n_s32(blended_hi, kInterPostRoundBit + 4));
 }
 
-template <int height>
-inline void DistanceWeightedBlend4xH_NEON(const uint16_t* prediction_0,
-                                          const ptrdiff_t prediction_stride_0,
-                                          const uint16_t* prediction_1,
-                                          const ptrdiff_t prediction_stride_1,
-                                          const uint16x4_t weights[2],
-                                          void* const dest,
-                                          const ptrdiff_t dest_stride) {
+template <int width, int height>
+inline void DistanceWeightedBlendSmall_NEON(const int16_t* prediction_0,
+                                            const int16_t* prediction_1,
+                                            const int16x4_t weights[2],
+                                            void* const dest,
+                                            const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
-  const uint16_t* pred_0 = prediction_0;
-  const uint16_t* pred_1 = prediction_1;
+  constexpr int step = 16 / width;
 
-  for (int y = 0; y < height; y += 4) {
-    const uint16x4_t src_00 = vld1_u16(pred_0);
-    const uint16x4_t src_10 = vld1_u16(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
-    const uint16x4_t src_01 = vld1_u16(pred_0);
-    const uint16x4_t src_11 = vld1_u16(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
-    const int16x8_t res01 = ComputeWeightedAverage8(
-        vcombine_u16(src_00, src_01), vcombine_u16(src_10, src_11), weights);
-
-    const uint16x4_t src_02 = vld1_u16(pred_0);
-    const uint16x4_t src_12 = vld1_u16(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
-    const uint16x4_t src_03 = vld1_u16(pred_0);
-    const uint16x4_t src_13 = vld1_u16(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
-    const int16x8_t res23 = ComputeWeightedAverage8(
-        vcombine_u16(src_02, src_03), vcombine_u16(src_12, src_13), weights);
-
-    const uint8x8_t result_01 = vqmovun_s16(res01);
-    const uint8x8_t result_23 = vqmovun_s16(res23);
-    StoreLo4(dst, result_01);
-    dst += dest_stride;
-    StoreHi4(dst, result_01);
-    dst += dest_stride;
-    StoreLo4(dst, result_23);
-    dst += dest_stride;
-    StoreHi4(dst, result_23);
-    dst += dest_stride;
-  }
-}
-
-template <int height>
-inline void DistanceWeightedBlend8xH_NEON(const uint16_t* prediction_0,
-                                          const ptrdiff_t prediction_stride_0,
-                                          const uint16_t* prediction_1,
-                                          const ptrdiff_t prediction_stride_1,
-                                          const uint16x4_t weights[2],
-                                          void* const dest,
-                                          const ptrdiff_t dest_stride) {
-  auto* dst = static_cast<uint8_t*>(dest);
-  const uint16_t* pred_0 = prediction_0;
-  const uint16_t* pred_1 = prediction_1;
-
-  for (int y = 0; y < height; y += 2) {
-    const uint16x8_t src_00 = vld1q_u16(pred_0);
-    const uint16x8_t src_10 = vld1q_u16(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+  for (int y = 0; y < height; y += step) {
+    const int16x8_t src_00 = vld1q_s16(prediction_0);
+    const int16x8_t src_10 = vld1q_s16(prediction_1);
+    prediction_0 += 8;
+    prediction_1 += 8;
     const int16x8_t res0 = ComputeWeightedAverage8(src_00, src_10, weights);
 
-    const uint16x8_t src_01 = vld1q_u16(pred_0);
-    const uint16x8_t src_11 = vld1q_u16(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+    const int16x8_t src_01 = vld1q_s16(prediction_0);
+    const int16x8_t src_11 = vld1q_s16(prediction_1);
+    prediction_0 += 8;
+    prediction_1 += 8;
     const int16x8_t res1 = ComputeWeightedAverage8(src_01, src_11, weights);
 
     const uint8x8_t result0 = vqmovun_s16(res0);
     const uint8x8_t result1 = vqmovun_s16(res1);
-    vst1_u8(dst, result0);
-    dst += dest_stride;
-    vst1_u8(dst, result1);
-    dst += dest_stride;
+    if (width == 4) {
+      StoreLo4(dst, result0);
+      dst += dest_stride;
+      StoreHi4(dst, result0);
+      dst += dest_stride;
+      StoreLo4(dst, result1);
+      dst += dest_stride;
+      StoreHi4(dst, result1);
+      dst += dest_stride;
+    } else {
+      assert(width == 8);
+      vst1_u8(dst, result0);
+      dst += dest_stride;
+      vst1_u8(dst, result1);
+      dst += dest_stride;
+    }
   }
 }
 
-inline void DistanceWeightedBlendLarge_NEON(
-    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
-    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
-    const uint16x4_t weights[2], const int width, const int height,
-    void* const dest, const ptrdiff_t dest_stride) {
+inline void DistanceWeightedBlendLarge_NEON(const int16_t* prediction_0,
+                                            const int16_t* prediction_1,
+                                            const int16x4_t weights[2],
+                                            const int width, const int height,
+                                            void* const dest,
+                                            const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
-  const uint16_t* pred_0 = prediction_0;
-  const uint16_t* pred_1 = prediction_1;
 
   int y = height;
   do {
     int x = 0;
     do {
-      const uint16x8_t src0_lo = vld1q_u16(pred_0 + x);
-      const uint16x8_t src1_lo = vld1q_u16(pred_1 + x);
+      const int16x8_t src0_lo = vld1q_s16(prediction_0 + x);
+      const int16x8_t src1_lo = vld1q_s16(prediction_1 + x);
       const int16x8_t res_lo =
           ComputeWeightedAverage8(src0_lo, src1_lo, weights);
 
-      const uint16x8_t src0_hi = vld1q_u16(pred_0 + x + 8);
-      const uint16x8_t src1_hi = vld1q_u16(pred_1 + x + 8);
+      const int16x8_t src0_hi = vld1q_s16(prediction_0 + x + 8);
+      const int16x8_t src1_hi = vld1q_s16(prediction_1 + x + 8);
       const int16x8_t res_hi =
           ComputeWeightedAverage8(src0_hi, src1_hi, weights);
 
@@ -165,31 +120,33 @@
       x += 16;
     } while (x < width);
     dst += dest_stride;
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+    prediction_0 += width;
+    prediction_1 += width;
   } while (--y != 0);
 }
 
-inline void DistanceWeightedBlend_NEON(
-    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
-    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
-    const uint8_t weight_0, const uint8_t weight_1, const int width,
-    const int height, void* const dest, const ptrdiff_t dest_stride) {
-  uint16x4_t weights[2] = {vdup_n_u16(weight_0), vdup_n_u16(weight_1)};
+inline void DistanceWeightedBlend_NEON(const void* prediction_0,
+                                       const void* prediction_1,
+                                       const uint8_t weight_0,
+                                       const uint8_t weight_1, const int width,
+                                       const int height, void* const dest,
+                                       const ptrdiff_t dest_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int16x4_t weights[2] = {vdup_n_s16(weight_0), vdup_n_s16(weight_1)};
+  // TODO(johannkoenig): Investigate the branching. May be fine to call with a
+  // variable height.
   if (width == 4) {
     if (height == 4) {
-      DistanceWeightedBlend4xH_NEON<4>(prediction_0, prediction_stride_0,
-                                       prediction_1, prediction_stride_1,
-                                       weights, dest, dest_stride);
+      DistanceWeightedBlendSmall_NEON<4, 4>(pred_0, pred_1, weights, dest,
+                                            dest_stride);
     } else if (height == 8) {
-      DistanceWeightedBlend4xH_NEON<8>(prediction_0, prediction_stride_0,
-                                       prediction_1, prediction_stride_1,
-                                       weights, dest, dest_stride);
+      DistanceWeightedBlendSmall_NEON<4, 8>(pred_0, pred_1, weights, dest,
+                                            dest_stride);
     } else {
       assert(height == 16);
-      DistanceWeightedBlend4xH_NEON<16>(prediction_0, prediction_stride_0,
-                                        prediction_1, prediction_stride_1,
-                                        weights, dest, dest_stride);
+      DistanceWeightedBlendSmall_NEON<4, 16>(pred_0, pred_1, weights, dest,
+                                             dest_stride);
     }
     return;
   }
@@ -197,37 +154,32 @@
   if (width == 8) {
     switch (height) {
       case 4:
-        DistanceWeightedBlend8xH_NEON<4>(prediction_0, prediction_stride_0,
-                                         prediction_1, prediction_stride_1,
-                                         weights, dest, dest_stride);
+        DistanceWeightedBlendSmall_NEON<8, 4>(pred_0, pred_1, weights, dest,
+                                              dest_stride);
         return;
       case 8:
-        DistanceWeightedBlend8xH_NEON<8>(prediction_0, prediction_stride_0,
-                                         prediction_1, prediction_stride_1,
-                                         weights, dest, dest_stride);
+        DistanceWeightedBlendSmall_NEON<8, 8>(pred_0, pred_1, weights, dest,
+                                              dest_stride);
         return;
       case 16:
-        DistanceWeightedBlend8xH_NEON<16>(prediction_0, prediction_stride_0,
-                                          prediction_1, prediction_stride_1,
-                                          weights, dest, dest_stride);
+        DistanceWeightedBlendSmall_NEON<8, 16>(pred_0, pred_1, weights, dest,
+                                               dest_stride);
         return;
       default:
         assert(height == 32);
-        DistanceWeightedBlend8xH_NEON<32>(prediction_0, prediction_stride_0,
-                                          prediction_1, prediction_stride_1,
-                                          weights, dest, dest_stride);
+        DistanceWeightedBlendSmall_NEON<8, 32>(pred_0, pred_1, weights, dest,
+                                               dest_stride);
 
         return;
     }
   }
 
-  DistanceWeightedBlendLarge_NEON(prediction_0, prediction_stride_0,
-                                  prediction_1, prediction_stride_1, weights,
-                                  width, height, dest, dest_stride);
+  DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weights, width, height, dest,
+                                  dest_stride);
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
 }
@@ -239,7 +191,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 
 namespace libgav1 {
 namespace dsp {
diff --git a/libgav1/src/dsp/arm/distance_weighted_blend_neon.h b/libgav1/src/dsp/arm/distance_weighted_blend_neon.h
index 6d35956..4d8824c 100644
--- a/libgav1/src/dsp/arm/distance_weighted_blend_neon.h
+++ b/libgav1/src/dsp/arm/distance_weighted_blend_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_DISTANCE_WEIGHTED_BLEND_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -32,7 +32,7 @@
 // If NEON is enabled signal the NEON implementation should be used instead of
 // normal C.
 #if LIBGAV1_ENABLE_NEON
-#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_CPU_NEON
 
 #endif  // LIBGAV1_ENABLE_NEON
 
diff --git a/libgav1/src/dsp/arm/film_grain_neon.cc b/libgav1/src/dsp/arm/film_grain_neon.cc
new file mode 100644
index 0000000..2612466
--- /dev/null
+++ b/libgav1/src/dsp/arm/film_grain_neon.cc
@@ -0,0 +1,1188 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/film_grain.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <new>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/arm/film_grain_neon.h"
+#include "src/dsp/common.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/film_grain_common.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/logging.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace film_grain {
+namespace {
+
+// These functions are overloaded for both possible sizes in order to simplify
+// loading and storing to and from intermediate value types from within a
+// template function.
+inline int16x8_t GetSignedSource8(const int8_t* src) {
+  return vmovl_s8(vld1_s8(src));
+}
+
+inline int16x8_t GetSignedSource8(const uint8_t* src) {
+  return ZeroExtend(vld1_u8(src));
+}
+
+inline void StoreUnsigned8(uint8_t* dest, const uint16x8_t data) {
+  vst1_u8(dest, vmovn_u16(data));
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+inline int16x8_t GetSignedSource8(const int16_t* src) { return vld1q_s16(src); }
+
+inline int16x8_t GetSignedSource8(const uint16_t* src) {
+  return vreinterpretq_s16_u16(vld1q_u16(src));
+}
+
+inline void StoreUnsigned8(uint16_t* dest, const uint16x8_t data) {
+  vst1q_u16(dest, data);
+}
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+// Each element in |sum| represents one destination value's running
+// autoregression formula. The fixed source values in |grain_lo| and |grain_hi|
+// allow for a sliding window in successive calls to this function.
+template <int position_offset>
+inline int32x4x2_t AccumulateWeightedGrain(const int16x8_t grain_lo,
+                                           const int16x8_t grain_hi,
+                                           int16_t coeff, int32x4x2_t sum) {
+  const int16x8_t grain = vextq_s16(grain_lo, grain_hi, position_offset);
+  sum.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(grain), coeff);
+  sum.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(grain), coeff);
+  return sum;
+}
+
+// Because the autoregressive filter requires the output of each pixel to
+// compute pixels that come after in the row, we have to finish the calculations
+// one at a time.
+template <int bitdepth, int auto_regression_coeff_lag, int lane>
+inline void WriteFinalAutoRegression(int8_t* grain_cursor, int32x4x2_t sum,
+                                     const int8_t* coeffs, int pos, int shift) {
+  int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3);
+
+  for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) {
+    result += grain_cursor[lane + delta_col] * coeffs[pos];
+    ++pos;
+  }
+  grain_cursor[lane] =
+      Clip3(grain_cursor[lane] + RightShiftWithRounding(result, shift),
+            GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>());
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+template <int bitdepth, int auto_regression_coeff_lag, int lane>
+inline void WriteFinalAutoRegression(int16_t* grain_cursor, int32x4x2_t sum,
+                                     const int8_t* coeffs, int pos, int shift) {
+  int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3);
+
+  for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) {
+    result += grain_cursor[lane + delta_col] * coeffs[pos];
+    ++pos;
+  }
+  grain_cursor[lane] =
+      Clip3(grain_cursor[lane] + RightShiftWithRounding(result, shift),
+            GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>());
+}
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+// Because the autoregressive filter requires the output of each pixel to
+// compute pixels that come after in the row, we have to finish the calculations
+// one at a time.
+template <int bitdepth, int auto_regression_coeff_lag, int lane>
+inline void WriteFinalAutoRegressionChroma(int8_t* u_grain_cursor,
+                                           int8_t* v_grain_cursor,
+                                           int32x4x2_t sum_u, int32x4x2_t sum_v,
+                                           const int8_t* coeffs_u,
+                                           const int8_t* coeffs_v, int pos,
+                                           int shift) {
+  WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
+      u_grain_cursor, sum_u, coeffs_u, pos, shift);
+  WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
+      v_grain_cursor, sum_v, coeffs_v, pos, shift);
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+template <int bitdepth, int auto_regression_coeff_lag, int lane>
+inline void WriteFinalAutoRegressionChroma(int16_t* u_grain_cursor,
+                                           int16_t* v_grain_cursor,
+                                           int32x4x2_t sum_u, int32x4x2_t sum_v,
+                                           const int8_t* coeffs_u,
+                                           const int8_t* coeffs_v, int pos,
+                                           int shift) {
+  WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
+      u_grain_cursor, sum_u, coeffs_u, pos, shift);
+  WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
+      v_grain_cursor, sum_v, coeffs_v, pos, shift);
+}
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+inline void SetZero(int32x4x2_t* v) {
+  v->val[0] = vdupq_n_s32(0);
+  v->val[1] = vdupq_n_s32(0);
+}
+
+// Computes subsampled luma for use with chroma, by averaging in the x direction
+// or y direction when applicable.
+int16x8_t GetSubsampledLuma(const int8_t* const luma, int subsampling_x,
+                            int subsampling_y, ptrdiff_t stride) {
+  if (subsampling_y != 0) {
+    assert(subsampling_x != 0);
+    const int8x16_t src0 = vld1q_s8(luma);
+    const int8x16_t src1 = vld1q_s8(luma + stride);
+    const int16x8_t ret0 = vcombine_s16(vpaddl_s8(vget_low_s8(src0)),
+                                        vpaddl_s8(vget_high_s8(src0)));
+    const int16x8_t ret1 = vcombine_s16(vpaddl_s8(vget_low_s8(src1)),
+                                        vpaddl_s8(vget_high_s8(src1)));
+    return vrshrq_n_s16(vaddq_s16(ret0, ret1), 2);
+  }
+  if (subsampling_x != 0) {
+    const int8x16_t src = vld1q_s8(luma);
+    return vrshrq_n_s16(
+        vcombine_s16(vpaddl_s8(vget_low_s8(src)), vpaddl_s8(vget_high_s8(src))),
+        1);
+  }
+  return vmovl_s8(vld1_s8(luma));
+}
+
+// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed.
+inline uint16x8_t GetAverageLuma(const uint8_t* const luma, int subsampling_x) {
+  if (subsampling_x != 0) {
+    const uint8x16_t src = vld1q_u8(luma);
+    return vrshrq_n_u16(vpaddlq_u8(src), 1);
+  }
+  return vmovl_u8(vld1_u8(luma));
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+// Computes subsampled luma for use with chroma, by averaging in the x direction
+// or y direction when applicable.
+int16x8_t GetSubsampledLuma(const int16_t* const luma, int subsampling_x,
+                            int subsampling_y, ptrdiff_t stride) {
+  if (subsampling_y != 0) {
+    assert(subsampling_x != 0);
+    int16x8_t src0_lo = vld1q_s16(luma);
+    int16x8_t src0_hi = vld1q_s16(luma + 8);
+    const int16x8_t src1_lo = vld1q_s16(luma + stride);
+    const int16x8_t src1_hi = vld1q_s16(luma + stride + 8);
+    const int16x8_t src0 =
+        vcombine_s16(vpadd_s16(vget_low_s16(src0_lo), vget_high_s16(src0_lo)),
+                     vpadd_s16(vget_low_s16(src0_hi), vget_high_s16(src0_hi)));
+    const int16x8_t src1 =
+        vcombine_s16(vpadd_s16(vget_low_s16(src1_lo), vget_high_s16(src1_lo)),
+                     vpadd_s16(vget_low_s16(src1_hi), vget_high_s16(src1_hi)));
+    return vrshrq_n_s16(vaddq_s16(src0, src1), 2);
+  }
+  if (subsampling_x != 0) {
+    const int16x8_t src_lo = vld1q_s16(luma);
+    const int16x8_t src_hi = vld1q_s16(luma + 8);
+    const int16x8_t ret =
+        vcombine_s16(vpadd_s16(vget_low_s16(src_lo), vget_high_s16(src_lo)),
+                     vpadd_s16(vget_low_s16(src_hi), vget_high_s16(src_hi)));
+    return vrshrq_n_s16(ret, 1);
+  }
+  return vld1q_s16(luma);
+}
+
+// For BlendNoiseWithImageChromaWithCfl, only |subsampling_x| is needed.
+inline uint16x8_t GetAverageLuma(const uint16_t* const luma,
+                                 int subsampling_x) {
+  if (subsampling_x != 0) {
+    const uint16x8x2_t src = vld2q_u16(luma);
+    return vrhaddq_u16(src.val[0], src.val[1]);
+  }
+  return vld1q_u16(luma);
+}
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+template <int bitdepth, typename GrainType, int auto_regression_coeff_lag,
+          bool use_luma>
+void ApplyAutoRegressiveFilterToChromaGrains_NEON(const FilmGrainParams& params,
+                                                  const void* luma_grain_buffer,
+                                                  int subsampling_x,
+                                                  int subsampling_y,
+                                                  void* u_grain_buffer,
+                                                  void* v_grain_buffer) {
+  static_assert(auto_regression_coeff_lag <= 3, "Invalid autoregression lag.");
+  const auto* luma_grain = static_cast<const GrainType*>(luma_grain_buffer);
+  auto* u_grain = static_cast<GrainType*>(u_grain_buffer);
+  auto* v_grain = static_cast<GrainType*>(v_grain_buffer);
+  const int auto_regression_shift = params.auto_regression_shift;
+  const int chroma_width =
+      (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth;
+  const int chroma_height =
+      (subsampling_y == 0) ? kMaxChromaHeight : kMinChromaHeight;
+  // When |chroma_width| == 44, we write 8 at a time from x in [3, 34],
+  // leaving [35, 40] to write at the end.
+  const int chroma_width_remainder =
+      (chroma_width - 2 * kAutoRegressionBorder) & 7;
+
+  int y = kAutoRegressionBorder;
+  luma_grain += kLumaWidth * y;
+  u_grain += chroma_width * y;
+  v_grain += chroma_width * y;
+  do {
+    // Each row is computed 8 values at a time in the following loop. At the
+    // end of the loop, 4 values remain to write. They are given a special
+    // reduced iteration at the end.
+    int x = kAutoRegressionBorder;
+    int luma_x = kAutoRegressionBorder;
+    do {
+      int pos = 0;
+      int32x4x2_t sum_u;
+      int32x4x2_t sum_v;
+      SetZero(&sum_u);
+      SetZero(&sum_v);
+
+      if (auto_regression_coeff_lag > 0) {
+        for (int delta_row = -auto_regression_coeff_lag; delta_row < 0;
+             ++delta_row) {
+          // These loads may overflow to the next row, but they are never called
+          // on the final row of a grain block. Therefore, they will never
+          // exceed the block boundaries.
+          // Note: this could be slightly optimized to a single load in 8bpp,
+          // but requires making a special first iteration and accumulate
+          // function that takes an int8x16_t.
+          const int16x8_t u_grain_lo =
+              GetSignedSource8(u_grain + x + delta_row * chroma_width -
+                               auto_regression_coeff_lag);
+          const int16x8_t u_grain_hi =
+              GetSignedSource8(u_grain + x + delta_row * chroma_width -
+                               auto_regression_coeff_lag + 8);
+          const int16x8_t v_grain_lo =
+              GetSignedSource8(v_grain + x + delta_row * chroma_width -
+                               auto_regression_coeff_lag);
+          const int16x8_t v_grain_hi =
+              GetSignedSource8(v_grain + x + delta_row * chroma_width -
+                               auto_regression_coeff_lag + 8);
+#define ACCUMULATE_WEIGHTED_GRAIN(offset)                                  \
+  sum_u = AccumulateWeightedGrain<offset>(                                 \
+      u_grain_lo, u_grain_hi, params.auto_regression_coeff_u[pos], sum_u); \
+  sum_v = AccumulateWeightedGrain<offset>(                                 \
+      v_grain_lo, v_grain_hi, params.auto_regression_coeff_v[pos++], sum_v)
+
+          ACCUMULATE_WEIGHTED_GRAIN(0);
+          ACCUMULATE_WEIGHTED_GRAIN(1);
+          ACCUMULATE_WEIGHTED_GRAIN(2);
+          // The horizontal |auto_regression_coeff_lag| loop is replaced with
+          // if-statements to give vextq_s16 an immediate param.
+          if (auto_regression_coeff_lag > 1) {
+            ACCUMULATE_WEIGHTED_GRAIN(3);
+            ACCUMULATE_WEIGHTED_GRAIN(4);
+          }
+          if (auto_regression_coeff_lag > 2) {
+            assert(auto_regression_coeff_lag == 3);
+            ACCUMULATE_WEIGHTED_GRAIN(5);
+            ACCUMULATE_WEIGHTED_GRAIN(6);
+          }
+        }
+      }
+
+      if (use_luma) {
+        const int16x8_t luma = GetSubsampledLuma(
+            luma_grain + luma_x, subsampling_x, subsampling_y, kLumaWidth);
+
+        // Luma samples get the final coefficient in the formula, but are best
+        // computed all at once before the final row.
+        const int coeff_u =
+            params.auto_regression_coeff_u[pos + auto_regression_coeff_lag];
+        const int coeff_v =
+            params.auto_regression_coeff_v[pos + auto_regression_coeff_lag];
+
+        sum_u.val[0] = vmlal_n_s16(sum_u.val[0], vget_low_s16(luma), coeff_u);
+        sum_u.val[1] = vmlal_n_s16(sum_u.val[1], vget_high_s16(luma), coeff_u);
+        sum_v.val[0] = vmlal_n_s16(sum_v.val[0], vget_low_s16(luma), coeff_v);
+        sum_v.val[1] = vmlal_n_s16(sum_v.val[1], vget_high_s16(luma), coeff_v);
+      }
+      // At this point in the filter, the source addresses and destination
+      // addresses overlap. Because this is an auto-regressive filter, the
+      // higher lanes cannot be computed without the results of the lower lanes.
+      // Each call to WriteFinalAutoRegression incorporates preceding values
+      // on the final row, and writes a single sample. This allows the next
+      // pixel's value to be computed in the next call.
+#define WRITE_AUTO_REGRESSION_RESULT(lane)                                    \
+  WriteFinalAutoRegressionChroma<bitdepth, auto_regression_coeff_lag, lane>(  \
+      u_grain + x, v_grain + x, sum_u, sum_v, params.auto_regression_coeff_u, \
+      params.auto_regression_coeff_v, pos, auto_regression_shift)
+
+      WRITE_AUTO_REGRESSION_RESULT(0);
+      WRITE_AUTO_REGRESSION_RESULT(1);
+      WRITE_AUTO_REGRESSION_RESULT(2);
+      WRITE_AUTO_REGRESSION_RESULT(3);
+      WRITE_AUTO_REGRESSION_RESULT(4);
+      WRITE_AUTO_REGRESSION_RESULT(5);
+      WRITE_AUTO_REGRESSION_RESULT(6);
+      WRITE_AUTO_REGRESSION_RESULT(7);
+
+      x += 8;
+      luma_x += 8 << subsampling_x;
+    } while (x < chroma_width - kAutoRegressionBorder - chroma_width_remainder);
+
+    // This is the "final iteration" of the above loop over width. We fill in
+    // the remainder of the width, which is less than 8.
+    int pos = 0;
+    int32x4x2_t sum_u;
+    int32x4x2_t sum_v;
+    SetZero(&sum_u);
+    SetZero(&sum_v);
+
+    for (int delta_row = -auto_regression_coeff_lag; delta_row < 0;
+         ++delta_row) {
+      // These loads may overflow to the next row, but they are never called on
+      // the final row of a grain block. Therefore, they will never exceed the
+      // block boundaries.
+      const int16x8_t u_grain_lo = GetSignedSource8(
+          u_grain + x + delta_row * chroma_width - auto_regression_coeff_lag);
+      const int16x8_t u_grain_hi =
+          GetSignedSource8(u_grain + x + delta_row * chroma_width -
+                           auto_regression_coeff_lag + 8);
+      const int16x8_t v_grain_lo = GetSignedSource8(
+          v_grain + x + delta_row * chroma_width - auto_regression_coeff_lag);
+      const int16x8_t v_grain_hi =
+          GetSignedSource8(v_grain + x + delta_row * chroma_width -
+                           auto_regression_coeff_lag + 8);
+
+      ACCUMULATE_WEIGHTED_GRAIN(0);
+      ACCUMULATE_WEIGHTED_GRAIN(1);
+      ACCUMULATE_WEIGHTED_GRAIN(2);
+      // The horizontal |auto_regression_coeff_lag| loop is replaced with
+      // if-statements to give vextq_s16 an immediate param.
+      if (auto_regression_coeff_lag > 1) {
+        ACCUMULATE_WEIGHTED_GRAIN(3);
+        ACCUMULATE_WEIGHTED_GRAIN(4);
+      }
+      if (auto_regression_coeff_lag > 2) {
+        assert(auto_regression_coeff_lag == 3);
+        ACCUMULATE_WEIGHTED_GRAIN(5);
+        ACCUMULATE_WEIGHTED_GRAIN(6);
+      }
+    }
+
+    if (use_luma) {
+      const int16x8_t luma = GetSubsampledLuma(
+          luma_grain + luma_x, subsampling_x, subsampling_y, kLumaWidth);
+
+      // Luma samples get the final coefficient in the formula, but are best
+      // computed all at once before the final row.
+      const int coeff_u =
+          params.auto_regression_coeff_u[pos + auto_regression_coeff_lag];
+      const int coeff_v =
+          params.auto_regression_coeff_v[pos + auto_regression_coeff_lag];
+
+      sum_u.val[0] = vmlal_n_s16(sum_u.val[0], vget_low_s16(luma), coeff_u);
+      sum_u.val[1] = vmlal_n_s16(sum_u.val[1], vget_high_s16(luma), coeff_u);
+      sum_v.val[0] = vmlal_n_s16(sum_v.val[0], vget_low_s16(luma), coeff_v);
+      sum_v.val[1] = vmlal_n_s16(sum_v.val[1], vget_high_s16(luma), coeff_v);
+    }
+
+    WRITE_AUTO_REGRESSION_RESULT(0);
+    WRITE_AUTO_REGRESSION_RESULT(1);
+    WRITE_AUTO_REGRESSION_RESULT(2);
+    WRITE_AUTO_REGRESSION_RESULT(3);
+    if (chroma_width_remainder == 6) {
+      WRITE_AUTO_REGRESSION_RESULT(4);
+      WRITE_AUTO_REGRESSION_RESULT(5);
+    }
+
+    luma_grain += kLumaWidth << subsampling_y;
+    u_grain += chroma_width;
+    v_grain += chroma_width;
+  } while (++y < chroma_height);
+#undef ACCUMULATE_WEIGHTED_GRAIN
+#undef WRITE_AUTO_REGRESSION_RESULT
+}
+
+// Applies an auto-regressive filter to the white noise in luma_grain.
+template <int bitdepth, typename GrainType, int auto_regression_coeff_lag>
+void ApplyAutoRegressiveFilterToLumaGrain_NEON(const FilmGrainParams& params,
+                                               void* luma_grain_buffer) {
+  static_assert(auto_regression_coeff_lag > 0, "");
+  const int8_t* const auto_regression_coeff_y = params.auto_regression_coeff_y;
+  const uint8_t auto_regression_shift = params.auto_regression_shift;
+
+  int y = kAutoRegressionBorder;
+  auto* luma_grain =
+      static_cast<GrainType*>(luma_grain_buffer) + kLumaWidth * y;
+  do {
+    // Each row is computed 8 values at a time in the following loop. At the
+    // end of the loop, 4 values remain to write. They are given a special
+    // reduced iteration at the end.
+    int x = kAutoRegressionBorder;
+    do {
+      int pos = 0;
+      int32x4x2_t sum;
+      SetZero(&sum);
+      for (int delta_row = -auto_regression_coeff_lag; delta_row < 0;
+           ++delta_row) {
+        // These loads may overflow to the next row, but they are never called
+        // on the final row of a grain block. Therefore, they will never exceed
+        // the block boundaries.
+        const int16x8_t src_grain_lo =
+            GetSignedSource8(luma_grain + x + delta_row * kLumaWidth -
+                             auto_regression_coeff_lag);
+        const int16x8_t src_grain_hi =
+            GetSignedSource8(luma_grain + x + delta_row * kLumaWidth -
+                             auto_regression_coeff_lag + 8);
+
+        // A pictorial representation of the auto-regressive filter for
+        // various values of params.auto_regression_coeff_lag. The letter 'O'
+        // represents the current sample. (The filter always operates on the
+        // current sample with filter coefficient 1.) The letters 'X'
+        // represent the neighboring samples that the filter operates on, below
+        // their corresponding "offset" number.
+        //
+        // params.auto_regression_coeff_lag == 3:
+        //   0 1 2 3 4 5 6
+        //   X X X X X X X
+        //   X X X X X X X
+        //   X X X X X X X
+        //   X X X O
+        // params.auto_regression_coeff_lag == 2:
+        //     0 1 2 3 4
+        //     X X X X X
+        //     X X X X X
+        //     X X O
+        // params.auto_regression_coeff_lag == 1:
+        //       0 1 2
+        //       X X X
+        //       X O
+        // params.auto_regression_coeff_lag == 0:
+        //         O
+        // The function relies on the caller to skip the call in the 0 lag
+        // case.
+
+#define ACCUMULATE_WEIGHTED_GRAIN(offset)                           \
+  sum = AccumulateWeightedGrain<offset>(src_grain_lo, src_grain_hi, \
+                                        auto_regression_coeff_y[pos++], sum)
+        ACCUMULATE_WEIGHTED_GRAIN(0);
+        ACCUMULATE_WEIGHTED_GRAIN(1);
+        ACCUMULATE_WEIGHTED_GRAIN(2);
+        // The horizontal |auto_regression_coeff_lag| loop is replaced with
+        // if-statements to give vextq_s16 an immediate param.
+        if (auto_regression_coeff_lag > 1) {
+          ACCUMULATE_WEIGHTED_GRAIN(3);
+          ACCUMULATE_WEIGHTED_GRAIN(4);
+        }
+        if (auto_regression_coeff_lag > 2) {
+          assert(auto_regression_coeff_lag == 3);
+          ACCUMULATE_WEIGHTED_GRAIN(5);
+          ACCUMULATE_WEIGHTED_GRAIN(6);
+        }
+      }
+      // At this point in the filter, the source addresses and destination
+      // addresses overlap. Because this is an auto-regressive filter, the
+      // higher lanes cannot be computed without the results of the lower lanes.
+      // Each call to WriteFinalAutoRegression incorporates preceding values
+      // on the final row, and writes a single sample. This allows the next
+      // pixel's value to be computed in the next call.
+#define WRITE_AUTO_REGRESSION_RESULT(lane)                             \
+  WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>( \
+      luma_grain + x, sum, auto_regression_coeff_y, pos,               \
+      auto_regression_shift)
+
+      WRITE_AUTO_REGRESSION_RESULT(0);
+      WRITE_AUTO_REGRESSION_RESULT(1);
+      WRITE_AUTO_REGRESSION_RESULT(2);
+      WRITE_AUTO_REGRESSION_RESULT(3);
+      WRITE_AUTO_REGRESSION_RESULT(4);
+      WRITE_AUTO_REGRESSION_RESULT(5);
+      WRITE_AUTO_REGRESSION_RESULT(6);
+      WRITE_AUTO_REGRESSION_RESULT(7);
+      x += 8;
+      // Leave the final four pixels for the special iteration below.
+    } while (x < kLumaWidth - kAutoRegressionBorder - 4);
+
+    // Final 4 pixels in the row.
+    int pos = 0;
+    int32x4x2_t sum;
+    SetZero(&sum);
+    for (int delta_row = -auto_regression_coeff_lag; delta_row < 0;
+         ++delta_row) {
+      const int16x8_t src_grain_lo = GetSignedSource8(
+          luma_grain + x + delta_row * kLumaWidth - auto_regression_coeff_lag);
+      const int16x8_t src_grain_hi =
+          GetSignedSource8(luma_grain + x + delta_row * kLumaWidth -
+                           auto_regression_coeff_lag + 8);
+
+      ACCUMULATE_WEIGHTED_GRAIN(0);
+      ACCUMULATE_WEIGHTED_GRAIN(1);
+      ACCUMULATE_WEIGHTED_GRAIN(2);
+      // The horizontal |auto_regression_coeff_lag| loop is replaced with
+      // if-statements to give vextq_s16 an immediate param.
+      if (auto_regression_coeff_lag > 1) {
+        ACCUMULATE_WEIGHTED_GRAIN(3);
+        ACCUMULATE_WEIGHTED_GRAIN(4);
+      }
+      if (auto_regression_coeff_lag > 2) {
+        assert(auto_regression_coeff_lag == 3);
+        ACCUMULATE_WEIGHTED_GRAIN(5);
+        ACCUMULATE_WEIGHTED_GRAIN(6);
+      }
+    }
+    // delta_row == 0
+    WRITE_AUTO_REGRESSION_RESULT(0);
+    WRITE_AUTO_REGRESSION_RESULT(1);
+    WRITE_AUTO_REGRESSION_RESULT(2);
+    WRITE_AUTO_REGRESSION_RESULT(3);
+    luma_grain += kLumaWidth;
+  } while (++y < kLumaHeight);
+
+#undef WRITE_AUTO_REGRESSION_RESULT
+#undef ACCUMULATE_WEIGHTED_GRAIN
+}
+
+void InitializeScalingLookupTable_NEON(
+    int num_points, const uint8_t point_value[], const uint8_t point_scaling[],
+    uint8_t scaling_lut[kScalingLookupTableSize]) {
+  if (num_points == 0) {
+    memset(scaling_lut, 0, sizeof(scaling_lut[0]) * kScalingLookupTableSize);
+    return;
+  }
+  static_assert(sizeof(scaling_lut[0]) == 1, "");
+  memset(scaling_lut, point_scaling[0], point_value[0]);
+  const uint32x4_t steps = vmovl_u16(vcreate_u16(0x0003000200010000));
+  const uint32x4_t offset = vdupq_n_u32(32768);
+  for (int i = 0; i < num_points - 1; ++i) {
+    const int delta_y = point_scaling[i + 1] - point_scaling[i];
+    const int delta_x = point_value[i + 1] - point_value[i];
+    const int delta = delta_y * ((65536 + (delta_x >> 1)) / delta_x);
+    const int delta4 = delta << 2;
+    const uint8x8_t base_point = vdup_n_u8(point_scaling[i]);
+    uint32x4_t upscaled_points0 = vmlaq_n_u32(offset, steps, delta);
+    const uint32x4_t line_increment4 = vdupq_n_u32(delta4);
+    // Get the second set of 4 points by adding 4 steps to the first set.
+    uint32x4_t upscaled_points1 = vaddq_u32(upscaled_points0, line_increment4);
+    // We obtain the next set of 8 points by adding 8 steps to each of the
+    // current 8 points.
+    const uint32x4_t line_increment8 = vshlq_n_u32(line_increment4, 1);
+    int x = 0;
+    do {
+      const uint16x4_t interp_points0 = vshrn_n_u32(upscaled_points0, 16);
+      const uint16x4_t interp_points1 = vshrn_n_u32(upscaled_points1, 16);
+      const uint8x8_t interp_points =
+          vmovn_u16(vcombine_u16(interp_points0, interp_points1));
+      // The spec guarantees that the max value of |point_value[i]| + x is 255.
+      // Writing 8 bytes starting at the final table byte, leaves 7 bytes of
+      // required padding.
+      vst1_u8(&scaling_lut[point_value[i] + x],
+              vadd_u8(interp_points, base_point));
+      upscaled_points0 = vaddq_u32(upscaled_points0, line_increment8);
+      upscaled_points1 = vaddq_u32(upscaled_points1, line_increment8);
+      x += 8;
+    } while (x < delta_x);
+  }
+  const uint8_t last_point_value = point_value[num_points - 1];
+  memset(&scaling_lut[last_point_value], point_scaling[num_points - 1],
+         kScalingLookupTableSize - last_point_value);
+}
+
+inline int16x8_t Clip3(const int16x8_t value, const int16x8_t low,
+                       const int16x8_t high) {
+  const int16x8_t clipped_to_ceiling = vminq_s16(high, value);
+  return vmaxq_s16(low, clipped_to_ceiling);
+}
+
+template <int bitdepth, typename Pixel>
+inline int16x8_t GetScalingFactors(
+    const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* source) {
+  int16_t start_vals[8];
+  if (bitdepth == 8) {
+    start_vals[0] = scaling_lut[source[0]];
+    start_vals[1] = scaling_lut[source[1]];
+    start_vals[2] = scaling_lut[source[2]];
+    start_vals[3] = scaling_lut[source[3]];
+    start_vals[4] = scaling_lut[source[4]];
+    start_vals[5] = scaling_lut[source[5]];
+    start_vals[6] = scaling_lut[source[6]];
+    start_vals[7] = scaling_lut[source[7]];
+    return vld1q_s16(start_vals);
+  }
+  int16_t end_vals[8];
+  // TODO(petersonab): Precompute this into a larger table for direct lookups.
+  int index = source[0] >> 2;
+  start_vals[0] = scaling_lut[index];
+  end_vals[0] = scaling_lut[index + 1];
+  index = source[1] >> 2;
+  start_vals[1] = scaling_lut[index];
+  end_vals[1] = scaling_lut[index + 1];
+  index = source[2] >> 2;
+  start_vals[2] = scaling_lut[index];
+  end_vals[2] = scaling_lut[index + 1];
+  index = source[3] >> 2;
+  start_vals[3] = scaling_lut[index];
+  end_vals[3] = scaling_lut[index + 1];
+  index = source[4] >> 2;
+  start_vals[4] = scaling_lut[index];
+  end_vals[4] = scaling_lut[index + 1];
+  index = source[5] >> 2;
+  start_vals[5] = scaling_lut[index];
+  end_vals[5] = scaling_lut[index + 1];
+  index = source[6] >> 2;
+  start_vals[6] = scaling_lut[index];
+  end_vals[6] = scaling_lut[index + 1];
+  index = source[7] >> 2;
+  start_vals[7] = scaling_lut[index];
+  end_vals[7] = scaling_lut[index + 1];
+  const int16x8_t start = vld1q_s16(start_vals);
+  const int16x8_t end = vld1q_s16(end_vals);
+  int16x8_t remainder = GetSignedSource8(source);
+  remainder = vandq_s16(remainder, vdupq_n_s16(3));
+  const int16x8_t delta = vmulq_s16(vsubq_s16(end, start), remainder);
+  return vaddq_s16(start, vrshrq_n_s16(delta, 2));
+}
+
+inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling,
+                            const int16x8_t scaling_shift_vect) {
+  const int16x8_t upscaled_noise = vmulq_s16(noise, scaling);
+  return vrshlq_s16(upscaled_noise, scaling_shift_vect);
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling,
+                            const int32x4_t scaling_shift_vect) {
+  // TODO(petersonab): Try refactoring scaling lookup table to int16_t and
+  // upscaling by 7 bits to permit high half multiply. This would eliminate
+  // the intermediate 32x4 registers. Also write the averaged values directly
+  // into the table so it doesn't have to be done for every pixel in
+  // the frame.
+  const int32x4_t upscaled_noise_lo =
+      vmull_s16(vget_low_s16(noise), vget_low_s16(scaling));
+  const int32x4_t upscaled_noise_hi =
+      vmull_s16(vget_high_s16(noise), vget_high_s16(scaling));
+  const int16x4_t noise_lo =
+      vmovn_s32(vrshlq_s32(upscaled_noise_lo, scaling_shift_vect));
+  const int16x4_t noise_hi =
+      vmovn_s32(vrshlq_s32(upscaled_noise_hi, scaling_shift_vect));
+  return vcombine_s16(noise_lo, noise_hi);
+}
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+template <int bitdepth, typename GrainType, typename Pixel>
+void BlendNoiseWithImageLuma_NEON(
+    const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift,
+    int width, int height, int start_height,
+    const uint8_t scaling_lut_y[kScalingLookupTableSize],
+    const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y,
+    ptrdiff_t dest_stride_y) {
+  const auto* noise_image =
+      static_cast<const Array2D<GrainType>*>(noise_image_ptr);
+  const auto* in_y_row = static_cast<const Pixel*>(source_plane_y);
+  source_stride_y /= sizeof(Pixel);
+  auto* out_y_row = static_cast<Pixel*>(dest_plane_y);
+  dest_stride_y /= sizeof(Pixel);
+  const int16x8_t floor = vdupq_n_s16(min_value);
+  const int16x8_t ceiling = vdupq_n_s16(max_luma);
+  // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe
+  // for 16 bit signed integers. In higher bitdepths, however, we have to
+  // expand to 32 to protect the sign bit.
+  const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift);
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift);
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      // This operation on the unsigned input is safe in 8bpp because the vector
+      // is widened before it is reinterpreted.
+      const int16x8_t orig = GetSignedSource8(&in_y_row[x]);
+      const int16x8_t scaling =
+          GetScalingFactors<bitdepth, Pixel>(scaling_lut_y, &in_y_row[x]);
+      int16x8_t noise =
+          GetSignedSource8(&(noise_image[kPlaneY][y + start_height][x]));
+
+      if (bitdepth == 8) {
+        noise = ScaleNoise(noise, scaling, scaling_shift_vect16);
+      } else {
+#if LIBGAV1_MAX_BITDEPTH >= 10
+        noise = ScaleNoise(noise, scaling, scaling_shift_vect32);
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+      }
+      const int16x8_t combined = vaddq_s16(orig, noise);
+      // In 8bpp, when params_.clip_to_restricted_range == false, we can replace
+      // clipping with vqmovun_s16, but it's not likely to be worth copying the
+      // function for just that case, though the gain would be very small.
+      StoreUnsigned8(&out_y_row[x],
+                     vreinterpretq_u16_s16(Clip3(combined, floor, ceiling)));
+      x += 8;
+    } while (x < width);
+    in_y_row += source_stride_y;
+    out_y_row += dest_stride_y;
+  } while (++y < height);
+}
+
+template <int bitdepth, typename GrainType, typename Pixel>
+inline int16x8_t BlendChromaValsWithCfl(
+    const Pixel* average_luma_buffer,
+    const uint8_t scaling_lut[kScalingLookupTableSize],
+    const Pixel* chroma_cursor, const GrainType* noise_image_cursor,
+    const int16x8_t scaling_shift_vect16,
+    const int32x4_t scaling_shift_vect32) {
+  const int16x8_t scaling =
+      GetScalingFactors<bitdepth, Pixel>(scaling_lut, average_luma_buffer);
+  const int16x8_t orig = GetSignedSource8(chroma_cursor);
+  int16x8_t noise = GetSignedSource8(noise_image_cursor);
+  if (bitdepth == 8) {
+    noise = ScaleNoise(noise, scaling, scaling_shift_vect16);
+  } else {
+    noise = ScaleNoise(noise, scaling, scaling_shift_vect32);
+  }
+  return vaddq_s16(orig, noise);
+}
+
+template <int bitdepth, typename GrainType, typename Pixel>
+LIBGAV1_ALWAYS_INLINE void BlendChromaPlaneWithCfl_NEON(
+    const Array2D<GrainType>& noise_image, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, int scaling_shift,
+    const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* in_y_row,
+    ptrdiff_t source_stride_y, const Pixel* in_chroma_row,
+    ptrdiff_t source_stride_chroma, Pixel* out_chroma_row,
+    ptrdiff_t dest_stride) {
+  const int16x8_t floor = vdupq_n_s16(min_value);
+  const int16x8_t ceiling = vdupq_n_s16(max_chroma);
+  Pixel luma_buffer[16];
+  memset(luma_buffer, 0, sizeof(luma_buffer));
+  // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe
+  // for 16 bit signed integers. In higher bitdepths, however, we have to
+  // expand to 32 to protect the sign bit.
+  const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift);
+  const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift);
+
+  const int chroma_height = (height + subsampling_y) >> subsampling_y;
+  const int chroma_width = (width + subsampling_x) >> subsampling_x;
+  const int safe_chroma_width = chroma_width & ~7;
+
+  // Writing to this buffer avoids the cost of doing 8 lane lookups in a row
+  // in GetScalingFactors.
+  Pixel average_luma_buffer[8];
+  assert(start_height % 2 == 0);
+  start_height >>= subsampling_y;
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      const int luma_x = x << subsampling_x;
+      // TODO(petersonab): Consider specializing by subsampling_x. In the 444
+      // case &in_y_row[x] can be passed to GetScalingFactors directly.
+      const uint16x8_t average_luma =
+          GetAverageLuma(&in_y_row[luma_x], subsampling_x);
+      StoreUnsigned8(average_luma_buffer, average_luma);
+
+      const int16x8_t blended =
+          BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>(
+              average_luma_buffer, scaling_lut, &in_chroma_row[x],
+              &(noise_image[y + start_height][x]), scaling_shift_vect16,
+              scaling_shift_vect32);
+
+      // In 8bpp, when params_.clip_to_restricted_range == false, we can replace
+      // clipping with vqmovun_s16, but it's not likely to be worth copying the
+      // function for just that case.
+      StoreUnsigned8(&out_chroma_row[x],
+                     vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+      x += 8;
+    } while (x < safe_chroma_width);
+
+    if (x < chroma_width) {
+      const int luma_x = x << subsampling_x;
+      const int valid_range = width - luma_x;
+      memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0]));
+      luma_buffer[valid_range] = in_y_row[width - 1];
+      const uint16x8_t average_luma =
+          GetAverageLuma(luma_buffer, subsampling_x);
+      StoreUnsigned8(average_luma_buffer, average_luma);
+
+      const int16x8_t blended =
+          BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>(
+              average_luma_buffer, scaling_lut, &in_chroma_row[x],
+              &(noise_image[y + start_height][x]), scaling_shift_vect16,
+              scaling_shift_vect32);
+      // In 8bpp, when params_.clip_to_restricted_range == false, we can replace
+      // clipping with vqmovun_s16, but it's not likely to be worth copying the
+      // function for just that case.
+      StoreUnsigned8(&out_chroma_row[x],
+                     vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+    }
+
+    in_y_row += source_stride_y << subsampling_y;
+    in_chroma_row += source_stride_chroma;
+    out_chroma_row += dest_stride;
+  } while (++y < chroma_height);
+}
+
+// This function is for the case params_.chroma_scaling_from_luma == true.
+// This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y.
+template <int bitdepth, typename GrainType, typename Pixel>
+void BlendNoiseWithImageChromaWithCfl_NEON(
+    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
+    int min_value, int max_chroma, int width, int height, int start_height,
+    int subsampling_x, int subsampling_y,
+    const uint8_t scaling_lut[kScalingLookupTableSize],
+    const void* source_plane_y, ptrdiff_t source_stride_y,
+    const void* source_plane_uv, ptrdiff_t source_stride_uv,
+    void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
+  const auto* noise_image =
+      static_cast<const Array2D<GrainType>*>(noise_image_ptr);
+  const auto* in_y = static_cast<const Pixel*>(source_plane_y);
+  source_stride_y /= sizeof(Pixel);
+
+  const auto* in_uv = static_cast<const Pixel*>(source_plane_uv);
+  source_stride_uv /= sizeof(Pixel);
+  auto* out_uv = static_cast<Pixel*>(dest_plane_uv);
+  dest_stride_uv /= sizeof(Pixel);
+  // Looping over one plane at a time is faster in higher resolutions, despite
+  // re-computing luma.
+  BlendChromaPlaneWithCfl_NEON<bitdepth, GrainType, Pixel>(
+      noise_image[plane], min_value, max_chroma, width, height, start_height,
+      subsampling_x, subsampling_y, params.chroma_scaling, scaling_lut, in_y,
+      source_stride_y, in_uv, source_stride_uv, out_uv, dest_stride_uv);
+}
+
+}  // namespace
+
+namespace low_bitdepth {
+namespace {
+
+inline int16x8_t BlendChromaValsNoCfl(
+    const uint8_t scaling_lut[kScalingLookupTableSize],
+    const uint8_t* chroma_cursor, const int8_t* noise_image_cursor,
+    const int16x8_t& average_luma, const int16x8_t& scaling_shift_vect,
+    const int16x8_t& offset, int luma_multiplier, int chroma_multiplier) {
+  uint8_t merged_buffer[8];
+  const int16x8_t orig = GetSignedSource8(chroma_cursor);
+  const int16x8_t weighted_luma = vmulq_n_s16(average_luma, luma_multiplier);
+  const int16x8_t weighted_chroma = vmulq_n_s16(orig, chroma_multiplier);
+  // Maximum value of |combined_u| is 127*255 = 0x7E81.
+  const int16x8_t combined = vhaddq_s16(weighted_luma, weighted_chroma);
+  // Maximum value of u_offset is (255 << 5) = 0x1FE0.
+  // 0x7E81 + 0x1FE0 = 0x9E61, therefore another halving add is required.
+  const uint8x8_t merged = vqshrun_n_s16(vhaddq_s16(offset, combined), 4);
+  vst1_u8(merged_buffer, merged);
+  const int16x8_t scaling =
+      GetScalingFactors<8, uint8_t>(scaling_lut, merged_buffer);
+  int16x8_t noise = GetSignedSource8(noise_image_cursor);
+  noise = ScaleNoise(noise, scaling, scaling_shift_vect);
+  return vaddq_s16(orig, noise);
+}
+
+LIBGAV1_ALWAYS_INLINE void BlendChromaPlane8bpp_NEON(
+    const Array2D<int8_t>& noise_image, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, int scaling_shift, int chroma_offset,
+    int chroma_multiplier, int luma_multiplier,
+    const uint8_t scaling_lut[kScalingLookupTableSize], const uint8_t* in_y_row,
+    ptrdiff_t source_stride_y, const uint8_t* in_chroma_row,
+    ptrdiff_t source_stride_chroma, uint8_t* out_chroma_row,
+    ptrdiff_t dest_stride) {
+  const int16x8_t floor = vdupq_n_s16(min_value);
+  const int16x8_t ceiling = vdupq_n_s16(max_chroma);
+  // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe
+  // for 16 bit signed integers. In higher bitdepths, however, we have to
+  // expand to 32 to protect the sign bit.
+  const int16x8_t scaling_shift_vect = vdupq_n_s16(-scaling_shift);
+
+  const int chroma_height = (height + subsampling_y) >> subsampling_y;
+  const int chroma_width = (width + subsampling_x) >> subsampling_x;
+  const int safe_chroma_width = chroma_width & ~7;
+  uint8_t luma_buffer[16];
+  const int16x8_t offset = vdupq_n_s16(chroma_offset << 5);
+
+  start_height >>= subsampling_y;
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      const int luma_x = x << subsampling_x;
+      const int16x8_t average_luma = vreinterpretq_s16_u16(
+          GetAverageLuma(&in_y_row[luma_x], subsampling_x));
+      const int16x8_t blended = BlendChromaValsNoCfl(
+          scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+          average_luma, scaling_shift_vect, offset, luma_multiplier,
+          chroma_multiplier);
+      // In 8bpp, when params_.clip_to_restricted_range == false, we can
+      // replace clipping with vqmovun_s16, but the gain would be small.
+      StoreUnsigned8(&out_chroma_row[x],
+                     vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+
+      x += 8;
+    } while (x < safe_chroma_width);
+
+    if (x < chroma_width) {
+      // Begin right edge iteration. Same as the normal iterations, but the
+      // |average_luma| computation requires a duplicated luma value at the
+      // end.
+      const int luma_x = x << subsampling_x;
+      const int valid_range = width - luma_x;
+      memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0]));
+      luma_buffer[valid_range] = in_y_row[width - 1];
+
+      const int16x8_t average_luma =
+          vreinterpretq_s16_u16(GetAverageLuma(luma_buffer, subsampling_x));
+      const int16x8_t blended = BlendChromaValsNoCfl(
+          scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+          average_luma, scaling_shift_vect, offset, luma_multiplier,
+          chroma_multiplier);
+      StoreUnsigned8(&out_chroma_row[x],
+                     vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+      // End of right edge iteration.
+    }
+
+    in_y_row += source_stride_y << subsampling_y;
+    in_chroma_row += source_stride_chroma;
+    out_chroma_row += dest_stride;
+  } while (++y < chroma_height);
+}
+
+// This function is for the case params_.chroma_scaling_from_luma == false.
+void BlendNoiseWithImageChroma8bpp_NEON(
+    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
+    int min_value, int max_chroma, int width, int height, int start_height,
+    int subsampling_x, int subsampling_y,
+    const uint8_t scaling_lut[kScalingLookupTableSize],
+    const void* source_plane_y, ptrdiff_t source_stride_y,
+    const void* source_plane_uv, ptrdiff_t source_stride_uv,
+    void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
+  assert(plane == kPlaneU || plane == kPlaneV);
+  const auto* noise_image =
+      static_cast<const Array2D<int8_t>*>(noise_image_ptr);
+  const auto* in_y = static_cast<const uint8_t*>(source_plane_y);
+  const auto* in_uv = static_cast<const uint8_t*>(source_plane_uv);
+  auto* out_uv = static_cast<uint8_t*>(dest_plane_uv);
+
+  const int offset = (plane == kPlaneU) ? params.u_offset : params.v_offset;
+  const int luma_multiplier =
+      (plane == kPlaneU) ? params.u_luma_multiplier : params.v_luma_multiplier;
+  const int multiplier =
+      (plane == kPlaneU) ? params.u_multiplier : params.v_multiplier;
+  BlendChromaPlane8bpp_NEON(noise_image[plane], min_value, max_chroma, width,
+                            height, start_height, subsampling_x, subsampling_y,
+                            params.chroma_scaling, offset, multiplier,
+                            luma_multiplier, scaling_lut, in_y, source_stride_y,
+                            in_uv, source_stride_uv, out_uv, dest_stride_uv);
+}
+
+inline void WriteOverlapLine8bpp_NEON(const int8_t* noise_stripe_row,
+                                      const int8_t* noise_stripe_row_prev,
+                                      int plane_width,
+                                      const int8x8_t grain_coeff,
+                                      const int8x8_t old_coeff,
+                                      int8_t* noise_image_row) {
+  int x = 0;
+  do {
+    // Note that these reads may exceed noise_stripe_row's width by up to 7
+    // bytes.
+    const int8x8_t source_grain = vld1_s8(noise_stripe_row + x);
+    const int8x8_t source_old = vld1_s8(noise_stripe_row_prev + x);
+    const int16x8_t weighted_grain = vmull_s8(grain_coeff, source_grain);
+    const int16x8_t grain = vmlal_s8(weighted_grain, old_coeff, source_old);
+    // Note that this write may exceed noise_image_row's width by up to 7 bytes.
+    vst1_s8(noise_image_row + x, vqrshrn_n_s16(grain, 5));
+    x += 8;
+  } while (x < plane_width);
+}
+
+void ConstructNoiseImageOverlap8bpp_NEON(const void* noise_stripes_buffer,
+                                         int width, int height,
+                                         int subsampling_x, int subsampling_y,
+                                         void* noise_image_buffer) {
+  const auto* noise_stripes =
+      static_cast<const Array2DView<int8_t>*>(noise_stripes_buffer);
+  auto* noise_image = static_cast<Array2D<int8_t>*>(noise_image_buffer);
+  const int plane_width = (width + subsampling_x) >> subsampling_x;
+  const int plane_height = (height + subsampling_y) >> subsampling_y;
+  const int stripe_height = 32 >> subsampling_y;
+  const int stripe_mask = stripe_height - 1;
+  int y = stripe_height;
+  int luma_num = 1;
+  if (subsampling_y == 0) {
+    const int8x8_t first_row_grain_coeff = vdup_n_s8(17);
+    const int8x8_t first_row_old_coeff = vdup_n_s8(27);
+    const int8x8_t second_row_grain_coeff = first_row_old_coeff;
+    const int8x8_t second_row_old_coeff = first_row_grain_coeff;
+    for (; y < (plane_height & ~stripe_mask); ++luma_num, y += stripe_height) {
+      const int8_t* noise_stripe = (*noise_stripes)[luma_num];
+      const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+      WriteOverlapLine8bpp_NEON(
+          noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width,
+          first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]);
+
+      WriteOverlapLine8bpp_NEON(&noise_stripe[plane_width],
+                                &noise_stripe_prev[(32 + 1) * plane_width],
+                                plane_width, second_row_grain_coeff,
+                                second_row_old_coeff, (*noise_image)[y + 1]);
+    }
+    // Either one partial stripe remains (remaining_height  > 0),
+    // OR image is less than one stripe high (remaining_height < 0),
+    // OR all stripes are completed (remaining_height == 0).
+    const int remaining_height = plane_height - y;
+    if (remaining_height <= 0) {
+      return;
+    }
+    const int8_t* noise_stripe = (*noise_stripes)[luma_num];
+    const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+    WriteOverlapLine8bpp_NEON(
+        noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width,
+        first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]);
+
+    if (remaining_height > 1) {
+      WriteOverlapLine8bpp_NEON(&noise_stripe[plane_width],
+                                &noise_stripe_prev[(32 + 1) * plane_width],
+                                plane_width, second_row_grain_coeff,
+                                second_row_old_coeff, (*noise_image)[y + 1]);
+    }
+  } else {  // subsampling_y == 1
+    const int8x8_t first_row_grain_coeff = vdup_n_s8(22);
+    const int8x8_t first_row_old_coeff = vdup_n_s8(23);
+    for (; y < plane_height; ++luma_num, y += stripe_height) {
+      const int8_t* noise_stripe = (*noise_stripes)[luma_num];
+      const int8_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+      WriteOverlapLine8bpp_NEON(
+          noise_stripe, &noise_stripe_prev[16 * plane_width], plane_width,
+          first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]);
+    }
+  }
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+
+  // LumaAutoRegressionFunc
+  dsp->film_grain.luma_auto_regression[0] =
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 1>;
+  dsp->film_grain.luma_auto_regression[1] =
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 2>;
+  dsp->film_grain.luma_auto_regression[2] =
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 3>;
+
+  // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag]
+  // Chroma autoregression should never be called when lag is 0 and use_luma
+  // is false.
+  dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
+  dsp->film_grain.chroma_auto_regression[0][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, false>;
+  dsp->film_grain.chroma_auto_regression[0][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, false>;
+  dsp->film_grain.chroma_auto_regression[0][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, false>;
+  dsp->film_grain.chroma_auto_regression[1][0] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 0, true>;
+  dsp->film_grain.chroma_auto_regression[1][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, true>;
+  dsp->film_grain.chroma_auto_regression[1][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, true>;
+  dsp->film_grain.chroma_auto_regression[1][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, true>;
+
+  dsp->film_grain.construct_noise_image_overlap =
+      ConstructNoiseImageOverlap8bpp_NEON;
+
+  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON;
+
+  dsp->film_grain.blend_noise_luma =
+      BlendNoiseWithImageLuma_NEON<8, int8_t, uint8_t>;
+  dsp->film_grain.blend_noise_chroma[0] = BlendNoiseWithImageChroma8bpp_NEON;
+  dsp->film_grain.blend_noise_chroma[1] =
+      BlendNoiseWithImageChromaWithCfl_NEON<8, int8_t, uint8_t>;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+
+  // LumaAutoRegressionFunc
+  dsp->film_grain.luma_auto_regression[0] =
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 1>;
+  dsp->film_grain.luma_auto_regression[1] =
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 2>;
+  dsp->film_grain.luma_auto_regression[2] =
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 3>;
+
+  // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag][subsampling]
+  // Chroma autoregression should never be called when lag is 0 and use_luma
+  // is false.
+  dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
+  dsp->film_grain.chroma_auto_regression[0][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, false>;
+  dsp->film_grain.chroma_auto_regression[0][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, false>;
+  dsp->film_grain.chroma_auto_regression[0][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, false>;
+  dsp->film_grain.chroma_auto_regression[1][0] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 0, true>;
+  dsp->film_grain.chroma_auto_regression[1][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, true>;
+  dsp->film_grain.chroma_auto_regression[1][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, true>;
+  dsp->film_grain.chroma_auto_regression[1][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, true>;
+
+  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON;
+
+  dsp->film_grain.blend_noise_luma =
+      BlendNoiseWithImageLuma_NEON<10, int16_t, uint16_t>;
+  dsp->film_grain.blend_noise_chroma[1] =
+      BlendNoiseWithImageChromaWithCfl_NEON<10, int16_t, uint16_t>;
+}
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+}  // namespace film_grain
+
+void FilmGrainInit_NEON() {
+  film_grain::low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  film_grain::high_bitdepth::Init10bpp();
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+}
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void FilmGrainInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/film_grain_neon.h b/libgav1/src/dsp/arm/film_grain_neon.h
new file mode 100644
index 0000000..44b3d1d
--- /dev/null
+++ b/libgav1/src/dsp/arm/film_grain_neon.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initialize members of Dsp::film_grain. This function is not thread-safe.
+void FilmGrainInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainAutoregressionLuma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainAutoregressionLuma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseImageOverlap LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChroma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_FILM_GRAIN_NEON_H_
diff --git a/libgav1/src/dsp/arm/intra_edge_neon.cc b/libgav1/src/dsp/arm/intra_edge_neon.cc
index eff5a23..00b186a 100644
--- a/libgav1/src/dsp/arm/intra_edge_neon.cc
+++ b/libgav1/src/dsp/arm/intra_edge_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intra_edge.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -23,6 +23,8 @@
 #include <cassert>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"  // RightShiftWithRounding()
 
 namespace libgav1 {
@@ -275,7 +277,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->intra_edge_filter = IntraEdgeFilter_NEON;
   dsp->intra_edge_upsampler = IntraEdgeUpsampler_NEON;
@@ -288,7 +290,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/intra_edge_neon.h b/libgav1/src/dsp/arm/intra_edge_neon.h
index 5ecba8a..d3bb243 100644
--- a/libgav1/src/dsp/arm/intra_edge_neon.h
+++ b/libgav1/src/dsp/arm/intra_edge_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_INTRA_EDGE_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -31,8 +31,8 @@
 }  // namespace libgav1
 
 #if LIBGAV1_ENABLE_NEON
-#define LIBGAV1_Dsp8bpp_IntraEdgeFilter LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_IntraEdgeUpsampler LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_IntraEdgeFilter LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_IntraEdgeUpsampler LIBGAV1_CPU_NEON
 
 #endif  // LIBGAV1_ENABLE_NEON
 
diff --git a/libgav1/src/dsp/arm/intrapred_cfl_neon.cc b/libgav1/src/dsp/arm/intrapred_cfl_neon.cc
index a4e4f05..45fe33b 100644
--- a/libgav1/src/dsp/arm/intrapred_cfl_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_cfl_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -24,6 +24,8 @@
 #include <cstdint>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"
 
 namespace libgav1 {
@@ -36,35 +38,36 @@
   return vreinterpretq_u8_u16(vdupq_n_u16(combined_values));
 }
 
-int SumVector(uint32x2_t a) {
+uint32_t SumVector(uint32x2_t a) {
 #if defined(__aarch64__)
   return vaddv_u32(a);
 #else
   const uint64x1_t b = vpaddl_u32(a);
-  return vget_lane_u64(b, 0);
+  return vget_lane_u32(vreinterpret_u32_u64(b), 0);
 #endif  // defined(__aarch64__)
 }
 
-int SumVector(uint32x4_t a) {
+uint32_t SumVector(uint32x4_t a) {
 #if defined(__aarch64__)
   return vaddvq_u32(a);
 #else
   const uint64x2_t b = vpaddlq_u32(a);
   const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b));
-  return vget_lane_u64(c, 0);
+  return vget_lane_u32(vreinterpret_u32_u64(c), 0);
 #endif  // defined(__aarch64__)
 }
 
 // Divide by the number of elements.
-int Average(const int sum, const int width, const int height) {
+uint32_t Average(const uint32_t sum, const int width, const int height) {
   return RightShiftWithRounding(sum, FloorLog2(width) + FloorLog2(height));
 }
 
 // Subtract |val| from every element in |a|.
-void BlockSubtract(const int val,
+void BlockSubtract(const uint32_t val,
                    int16_t a[kCflLumaBufferStride][kCflLumaBufferStride],
                    const int width, const int height) {
-  const int16x8_t val_v = vdupq_n_s16(val);
+  assert(val <= INT16_MAX);
+  const int16x8_t val_v = vdupq_n_s16(static_cast<int16_t>(val));
 
   for (int y = 0; y < height; ++y) {
     if (width == 4) {
@@ -97,7 +100,7 @@
     const int max_luma_width, const int max_luma_height,
     const void* const source, const ptrdiff_t stride) {
   const auto* src = static_cast<const uint8_t*>(source);
-  int sum;
+  uint32_t sum;
   if (block_width == 4) {
     assert(max_luma_width >= 8);
     uint32x2_t running_sum = vdup_n_u32(0);
@@ -193,7 +196,7 @@
     sum = SumVector(running_sum);
   }
 
-  const int average = Average(sum, block_width, block_height);
+  const uint32_t average = Average(sum, block_width, block_height);
   BlockSubtract(average, luma, block_width, block_height);
 }
 
@@ -203,15 +206,15 @@
     const int max_luma_width, const int max_luma_height,
     const void* const source, const ptrdiff_t stride) {
   const auto* src = static_cast<const uint8_t*>(source);
-  int sum;
+  uint32_t sum;
   if (block_width == 4) {
     assert(max_luma_width >= 4);
     uint32x4_t running_sum = vdupq_n_u32(0);
+    uint8x8_t row = vdup_n_u8(0);
 
     for (int y = 0; y < block_height; y += 2) {
-      uint8x8_t row = vdup_n_u8(0);
-      row = LoadLo4(src, row);
-      row = LoadHi4(src + stride, row);
+      row = Load4<0>(src, row);
+      row = Load4<1>(src + stride, row);
       if (y < (max_luma_height - 1)) {
         src += stride << 1;
       }
@@ -272,7 +275,7 @@
     sum = SumVector(running_sum);
   }
 
-  const int average = Average(sum, block_width, block_height);
+  const uint32_t average = Average(sum, block_width, block_height);
   BlockSubtract(average, luma, block_width, block_height);
 }
 
@@ -281,8 +284,7 @@
                           const int16x8_t dc) {
   const int16x8_t la = vmulq_n_s16(luma, alpha);
   // Subtract the sign bit to round towards zero.
-  const int16x8_t sub_sign = vsubq_s16(
-      la, vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_s16(la), 15)));
+  const int16x8_t sub_sign = vsraq_n_s16(la, la, 15);
   // Shift and accumulate.
   const int16x8_t result = vrsraq_n_s16(dc, sub_sign, 6);
   return vqmovun_s16(result);
@@ -367,7 +369,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
 
   dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] =
@@ -466,7 +468,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/intrapred_directional_neon.cc b/libgav1/src/dsp/arm/intrapred_directional_neon.cc
index a4714d6..805ba81 100644
--- a/libgav1/src/dsp/arm/intrapred_directional_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_directional_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -26,6 +26,8 @@
 #include <cstring>  // memset
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"
 
 namespace libgav1 {
@@ -898,7 +900,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->directional_intra_predictor_zone1 = DirectionalIntraPredictorZone1_NEON;
   dsp->directional_intra_predictor_zone2 = DirectionalIntraPredictorZone2_NEON;
@@ -913,7 +915,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/intrapred_filter_intra_neon.cc b/libgav1/src/dsp/arm/intrapred_filter_intra_neon.cc
index 0887c85..411708e 100644
--- a/libgav1/src/dsp/arm/intrapred_filter_intra_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_filter_intra_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -24,6 +24,8 @@
 #include <cstdint>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"
 
 namespace libgav1 {
@@ -46,42 +48,42 @@
 //
 // We take this into account when summing the values by subtracting the product
 // of the first row.
-const uint8_t kTransposedTaps[kNumFilterIntraPredictors][7][8] = {
-    {{6, 5, 3, 3, 4, 3, 3, 3},  // Original values are negative.
-     {10, 2, 1, 1, 6, 2, 2, 1},
-     {0, 10, 1, 1, 0, 6, 2, 2},
-     {0, 0, 10, 2, 0, 0, 6, 2},
-     {0, 0, 0, 10, 0, 0, 0, 6},
-     {12, 9, 7, 5, 2, 2, 2, 3},
-     {0, 0, 0, 0, 12, 9, 7, 5}},
-    {{10, 6, 4, 2, 10, 6, 4, 2},  // Original values are negative.
-     {16, 0, 0, 0, 16, 0, 0, 0},
-     {0, 16, 0, 0, 0, 16, 0, 0},
-     {0, 0, 16, 0, 0, 0, 16, 0},
-     {0, 0, 0, 16, 0, 0, 0, 16},
-     {10, 6, 4, 2, 0, 0, 0, 0},
-     {0, 0, 0, 0, 10, 6, 4, 2}},
-    {{8, 8, 8, 8, 4, 4, 4, 4},  // Original values are negative.
-     {8, 0, 0, 0, 4, 0, 0, 0},
-     {0, 8, 0, 0, 0, 4, 0, 0},
-     {0, 0, 8, 0, 0, 0, 4, 0},
-     {0, 0, 0, 8, 0, 0, 0, 4},
-     {16, 16, 16, 16, 0, 0, 0, 0},
-     {0, 0, 0, 0, 16, 16, 16, 16}},
-    {{2, 1, 1, 0, 1, 1, 1, 1},  // Original values are negative.
-     {8, 3, 2, 1, 4, 3, 2, 2},
-     {0, 8, 3, 2, 0, 4, 3, 2},
-     {0, 0, 8, 3, 0, 0, 4, 3},
-     {0, 0, 0, 8, 0, 0, 0, 4},
-     {10, 6, 4, 2, 3, 4, 4, 3},
-     {0, 0, 0, 0, 10, 6, 4, 3}},
-    {{12, 10, 9, 8, 10, 9, 8, 7},  // Original values are negative.
-     {14, 0, 0, 0, 12, 1, 0, 0},
-     {0, 14, 0, 0, 0, 12, 0, 0},
-     {0, 0, 14, 0, 0, 0, 12, 1},
-     {0, 0, 0, 14, 0, 0, 0, 12},
-     {14, 12, 11, 10, 0, 0, 1, 1},
-     {0, 0, 0, 0, 14, 12, 11, 9}}};
+alignas(8) constexpr uint8_t kTransposedTaps[kNumFilterIntraPredictors][7][8] =
+    {{{6, 5, 3, 3, 4, 3, 3, 3},  // Original values are negative.
+      {10, 2, 1, 1, 6, 2, 2, 1},
+      {0, 10, 1, 1, 0, 6, 2, 2},
+      {0, 0, 10, 2, 0, 0, 6, 2},
+      {0, 0, 0, 10, 0, 0, 0, 6},
+      {12, 9, 7, 5, 2, 2, 2, 3},
+      {0, 0, 0, 0, 12, 9, 7, 5}},
+     {{10, 6, 4, 2, 10, 6, 4, 2},  // Original values are negative.
+      {16, 0, 0, 0, 16, 0, 0, 0},
+      {0, 16, 0, 0, 0, 16, 0, 0},
+      {0, 0, 16, 0, 0, 0, 16, 0},
+      {0, 0, 0, 16, 0, 0, 0, 16},
+      {10, 6, 4, 2, 0, 0, 0, 0},
+      {0, 0, 0, 0, 10, 6, 4, 2}},
+     {{8, 8, 8, 8, 4, 4, 4, 4},  // Original values are negative.
+      {8, 0, 0, 0, 4, 0, 0, 0},
+      {0, 8, 0, 0, 0, 4, 0, 0},
+      {0, 0, 8, 0, 0, 0, 4, 0},
+      {0, 0, 0, 8, 0, 0, 0, 4},
+      {16, 16, 16, 16, 0, 0, 0, 0},
+      {0, 0, 0, 0, 16, 16, 16, 16}},
+     {{2, 1, 1, 0, 1, 1, 1, 1},  // Original values are negative.
+      {8, 3, 2, 1, 4, 3, 2, 2},
+      {0, 8, 3, 2, 0, 4, 3, 2},
+      {0, 0, 8, 3, 0, 0, 4, 3},
+      {0, 0, 0, 8, 0, 0, 0, 4},
+      {10, 6, 4, 2, 3, 4, 4, 3},
+      {0, 0, 0, 0, 10, 6, 4, 3}},
+     {{12, 10, 9, 8, 10, 9, 8, 7},  // Original values are negative.
+      {14, 0, 0, 0, 12, 1, 0, 0},
+      {0, 14, 0, 0, 0, 12, 0, 0},
+      {0, 0, 14, 0, 0, 0, 12, 1},
+      {0, 0, 0, 14, 0, 0, 0, 12},
+      {14, 12, 11, 10, 0, 0, 1, 1},
+      {0, 0, 0, 0, 14, 12, 11, 9}}};
 
 void FilterIntraPredictor_NEON(void* const dest, ptrdiff_t stride,
                                const void* const top_row,
@@ -150,7 +152,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->filter_intra_predictor = FilterIntraPredictor_NEON;
 }
@@ -163,7 +165,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/intrapred_neon.cc b/libgav1/src/dsp/arm/intrapred_neon.cc
index 14ca346..c967d82 100644
--- a/libgav1/src/dsp/arm/intrapred_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -25,6 +25,7 @@
 
 #include "src/dsp/arm/common_neon.h"
 #include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -158,11 +159,10 @@
   const auto* const ref_0_u8 = static_cast<const uint8_t*>(ref_0);
   const auto* const ref_1_u8 = static_cast<const uint8_t*>(ref_1);
   if (ref_0_size_log2 == 2) {
-    uint8x8_t val = vdup_n_u8(0);
-    val = LoadLo4(ref_0_u8, val);
+    uint8x8_t val = Load4(ref_0_u8);
     if (use_ref_1) {
       if (ref_1_size_log2 == 2) {  // 4x4
-        val = LoadHi4(ref_1_u8, val);
+        val = Load4<1>(ref_1_u8, val);
         return Sum(vpaddl_u8(val));
       } else if (ref_1_size_log2 == 3) {  // 4x8
         const uint8x8_t val_1 = vld1_u8(ref_1_u8);
@@ -171,9 +171,7 @@
         return Sum(vadd_u16(sum_0, sum_1));
       } else if (ref_1_size_log2 == 4) {  // 4x16
         const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
-        const uint16x8_t sum_0 = vmovl_u8(val);
-        const uint16x8_t sum_1 = vpaddlq_u8(val_1);
-        return Sum(vaddq_u16(sum_0, sum_1));
+        return Sum(vaddw_u8(vpaddlq_u8(val_1), val));
       }
     }
     // 4x1
@@ -183,8 +181,7 @@
     const uint8x8_t val_0 = vld1_u8(ref_0_u8);
     if (use_ref_1) {
       if (ref_1_size_log2 == 2) {  // 8x4
-        uint8x8_t val_1 = vdup_n_u8(0);
-        val_1 = LoadLo4(ref_1_u8, val_1);
+        const uint8x8_t val_1 = Load4(ref_1_u8);
         const uint16x4_t sum_0 = vpaddl_u8(val_0);
         const uint16x4_t sum_1 = vpaddl_u8(val_1);
         return Sum(vadd_u16(sum_0, sum_1));
@@ -195,12 +192,9 @@
         return Sum(vadd_u16(sum_0, sum_1));
       } else if (ref_1_size_log2 == 4) {  // 8x16
         const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
-        const uint16x8_t sum_0 = vmovl_u8(val_0);
-        const uint16x8_t sum_1 = vpaddlq_u8(val_1);
-        return Sum(vaddq_u16(sum_0, sum_1));
+        return Sum(vaddw_u8(vpaddlq_u8(val_1), val_0));
       } else if (ref_1_size_log2 == 5) {  // 8x32
-        const uint16x8_t sum_0 = vmovl_u8(val_0);
-        return Sum(vaddq_u16(sum_0, LoadAndAdd32(ref_1_u8)));
+        return Sum(vaddw_u8(LoadAndAdd32(ref_1_u8), val_0));
       }
     }
     // 8x1
@@ -209,16 +203,11 @@
     const uint8x16_t val_0 = vld1q_u8(ref_0_u8);
     if (use_ref_1) {
       if (ref_1_size_log2 == 2) {  // 16x4
-        uint8x8_t val_1 = vdup_n_u8(0);
-        val_1 = LoadLo4(ref_1_u8, val_1);
-        const uint16x8_t sum_0 = vmovl_u8(val_1);
-        const uint16x8_t sum_u16 = vpaddlq_u8(val_0);
-        return Sum(vaddq_u16(sum_0, sum_u16));
+        const uint8x8_t val_1 = Load4(ref_1_u8);
+        return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1));
       } else if (ref_1_size_log2 == 3) {  // 16x8
         const uint8x8_t val_1 = vld1_u8(ref_1_u8);
-        const uint16x8_t sum_0 = vpaddlq_u8(val_0);
-        const uint16x8_t sum_1 = vmovl_u8(val_1);
-        return Sum(vaddq_u16(sum_0, sum_1));
+        return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1));
       } else if (ref_1_size_log2 == 4) {  // 16x16
         const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
         return Sum(Add(val_0, val_1));
@@ -239,8 +228,7 @@
     if (use_ref_1) {
       if (ref_1_size_log2 == 3) {  // 32x8
         const uint8x8_t val_1 = vld1_u8(ref_1_u8);
-        const uint16x8_t sum_1 = vmovl_u8(val_1);
-        return Sum(vaddq_u16(sum_0, sum_1));
+        return Sum(vaddw_u8(sum_0, val_1));
       } else if (ref_1_size_log2 == 4) {  // 32x16
         const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
         const uint16x8_t sum_1 = vpaddlq_u8(val_1);
@@ -340,8 +328,7 @@
   const uint16x8_t top_left_x2 = vdupq_n_u16(top_row_u8[-1] + top_row_u8[-1]);
   uint8x8_t top;
   if (width == 4) {
-    top = vdup_n_u8(0);
-    top = LoadLo4(top_row_u8, top);
+    top = Load4(top_row_u8);
   } else {  // width == 8
     top = vld1_u8(top_row_u8);
   }
@@ -388,6 +375,8 @@
 inline uint8x16_t XLeTopLeft(const uint8x16_t x_dist,
                              const uint16x8_t top_left_dist_low,
                              const uint16x8_t top_left_dist_high) {
+  // TODO(johannkoenig): cle() should work with vmovn(top_left_dist) instead of
+  // using movl(x_dist).
   const uint8x8_t x_le_top_left_low =
       vmovn_u16(vcleq_u16(vmovl_u8(vget_low_u8(x_dist)), top_left_dist_low));
   const uint8x8_t x_le_top_left_high =
@@ -536,7 +525,7 @@
 };
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   // 4x4
   dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] =
@@ -976,7 +965,7 @@
 };
 
 void Init10bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
   assert(dsp != nullptr);
   dsp->intra_predictors[kTransformSize4x4][kIntraPredictorDcTop] =
       DcDefs::_4x4::DcTop;
@@ -1144,7 +1133,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/intrapred_neon.h b/libgav1/src/dsp/arm/intrapred_neon.h
index fa56228..16f858c 100644
--- a/libgav1/src/dsp/arm/intrapred_neon.h
+++ b/libgav1/src/dsp/arm/intrapred_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -38,381 +38,381 @@
 
 #if LIBGAV1_ENABLE_NEON
 // 8 bit
-#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_CPU_NEON
 
 // 4x4
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 4x8
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 4x16
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 8x4
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 8x8
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 8x16
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 8x32
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 16x4
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 16x8
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 16x16
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 16x32
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 16x64
-#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
 // 32x8
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 32x16
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 32x32
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_CPU_NEON
 
 // 32x64
-#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
 // 64x16
-#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
 // 64x32
-#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
 // 64x64
-#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
 
 // 10 bit
 // 4x4
-#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 4x8
-#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 4x16
-#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 8x4
-#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 8x8
-#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 8x16
-#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 8x32
-#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 16x4
-#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 16x8
-#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 16x16
-#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcLeft \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 16x32
-#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcLeft \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 16x64
-#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcLeft \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 32x8
-#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcLeft LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 32x16
-#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcLeft \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 32x32
-#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcLeft \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 32x64
-#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcLeft \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 64x16
-#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcLeft \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 64x32
-#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcLeft \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_NEON
 
 // 64x64
-#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcLeft \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_
diff --git a/libgav1/src/dsp/arm/intrapred_smooth_neon.cc b/libgav1/src/dsp/arm/intrapred_smooth_neon.cc
index 242587a..abc93e8 100644
--- a/libgav1/src/dsp/arm/intrapred_smooth_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_smooth_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -25,6 +25,7 @@
 
 #include "src/dsp/arm/common_neon.h"
 #include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -51,6 +52,9 @@
     69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16,
     15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4};
 
+// TODO(b/150459137): Keeping the intermediate values in uint16_t would allow
+// processing more values at once. At the high end, it could do 4x4 or 8x2 at a
+// time.
 inline uint16x4_t CalculatePred(const uint16x4_t weighted_top,
                                 const uint16x4_t weighted_left,
                                 const uint16x4_t weighted_bl,
@@ -73,10 +77,8 @@
   uint8_t* dst = static_cast<uint8_t*>(dest);
 
   uint8x8_t top_v;
-  // TODO(johannkoenig): Process 16 values (4x4 / 8x2) at a time.
   if (width == 4) {
-    top_v = vdup_n_u8(0);
-    top_v = LoadLo4(top, top_v);
+    top_v = Load4(top);
   } else {  // width == 8
     top_v = vld1_u8(top);
   }
@@ -237,8 +239,7 @@
 
   uint8x8_t top_v;
   if (width == 4) {
-    top_v = vdup_n_u8(0);
-    top_v = LoadLo4(top, top_v);
+    top_v = Load4(top);
   } else {  // width == 8
     top_v = vld1_u8(top);
   }
@@ -441,7 +442,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   // 4x4
   dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] =
@@ -604,7 +605,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/inverse_transform_neon.cc b/libgav1/src/dsp/arm/inverse_transform_neon.cc
index 70b7ff6..5ad53f6 100644
--- a/libgav1/src/dsp/arm/inverse_transform_neon.cc
+++ b/libgav1/src/dsp/arm/inverse_transform_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/inverse_transform.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -24,6 +24,8 @@
 #include <cstdint>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/array_2d.h"
 #include "src/utils/common.h"
 #include "src/utils/compiler_attributes.h"
@@ -386,12 +388,11 @@
                                                          const bool flip) {
   const int16_t cos128 = Cos128(angle);
   const int16_t sin128 = Sin128(angle);
-  const int32x4_t x0 = vmlsl_n_s16(vdupq_n_s32(0), vget_low_s16(*b), sin128);
-  const int32x4_t x0_hi =
-      vmlsl_n_s16(vdupq_n_s32(0), vget_high_s16(*b), sin128);
-  const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
-  const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12);
-  const int16x8_t x = vcombine_s16(x1, x1_hi);
+  // For this function, the max value returned by Sin128() is 4091, which fits
+  // inside 12 bits.  This leaves room for the sign bit and the 3 left shifted
+  // bits.
+  assert(sin128 <= 0xfff);
+  const int16x8_t x = vqrdmulhq_s16(*b, vdupq_n_s16(-sin128 << 3));
   const int16x8_t y = vqrdmulhq_s16(*b, vdupq_n_s16(cos128 << 3));
   if (flip) {
     *a = y;
@@ -471,6 +472,40 @@
   return true;
 }
 
+template <int height>
+LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, const void* source,
+                                           int non_zero_coeff_count,
+                                           int width) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  const int16_t cos128 = Cos128(32);
+
+  // Calculate dc values for first row.
+  if (width == 4) {
+    const int16x4_t v_src = vld1_s16(src);
+    const int16x4_t xy = vqrdmulh_s16(v_src, vdup_n_s16(cos128 << 3));
+    vst1_s16(dst, xy);
+  } else {
+    int i = 0;
+    do {
+      const int16x8_t v_src = vld1q_s16(&src[i]);
+      const int16x8_t xy = vqrdmulhq_s16(v_src, vdupq_n_s16(cos128 << 3));
+      vst1q_s16(&dst[i], xy);
+      i += 8;
+    } while (i < width);
+  }
+
+  // Copy first row to the rest of the block.
+  for (int y = 1; y < height; ++y) {
+    memcpy(&dst[y * width], &src[(y - 1) * width], width * sizeof(dst[0]));
+  }
+  return true;
+}
+
 template <ButterflyRotationFunc bufferfly_rotation,
           bool is_fast_bufferfly = false>
 LIBGAV1_ALWAYS_INLINE void Dct4Stages(int16x8_t* s) {
@@ -666,13 +701,14 @@
 // Process dct16 rows or columns, depending on the transpose flag.
 template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
 LIBGAV1_ALWAYS_INLINE void Dct16_NEON(void* dest, const void* source,
-                                      int32_t step, bool transpose) {
+                                      int32_t step, bool is_row,
+                                      int row_shift) {
   auto* const dst = static_cast<int16_t*>(dest);
   const auto* const src = static_cast<const int16_t*>(source);
   int16x8_t s[16], x[16];
 
   if (stage_is_rectangular) {
-    if (transpose) {
+    if (is_row) {
       int16x8_t input[4];
       LoadSrc<16, 4>(src, step, 0, input);
       Transpose8x4To4x8(input, x);
@@ -681,7 +717,7 @@
     } else {
       LoadSrc<8, 16>(src, step, 0, x);
     }
-  } else if (transpose) {
+  } else if (is_row) {
     for (int idx = 0; idx < 16; idx += 8) {
       int16x8_t input[8];
       LoadSrc<16, 8>(src, step, idx, input);
@@ -714,8 +750,15 @@
   Dct8Stages<bufferfly_rotation>(s);
   Dct16Stages<bufferfly_rotation>(s);
 
+  if (is_row) {
+    const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
+    for (int i = 0; i < 16; ++i) {
+      s[i] = vqrshlq_s16(s[i], v_row_shift);
+    }
+  }
+
   if (stage_is_rectangular) {
-    if (transpose) {
+    if (is_row) {
       int16x8_t output[4];
       Transpose4x8To8x4(s, output);
       StoreDst<16, 4>(dst, step, 0, output);
@@ -724,7 +767,7 @@
     } else {
       StoreDst<8, 16>(dst, step, 0, s);
     }
-  } else if (transpose) {
+  } else if (is_row) {
     for (int idx = 0; idx < 16; idx += 8) {
       int16x8_t output[8];
       Transpose8x8(&s[idx], output);
@@ -827,13 +870,13 @@
 
 // Process dct32 rows or columns, depending on the transpose flag.
 LIBGAV1_ALWAYS_INLINE void Dct32_NEON(void* dest, const void* source,
-                                      const int32_t step,
-                                      const bool transpose) {
+                                      const int32_t step, const bool is_row,
+                                      int row_shift) {
   auto* const dst = static_cast<int16_t*>(dest);
   const auto* const src = static_cast<const int16_t*>(source);
   int16x8_t s[32], x[32];
 
-  if (transpose) {
+  if (is_row) {
     for (int idx = 0; idx < 32; idx += 8) {
       int16x8_t input[8];
       LoadSrc<16, 8>(src, step, idx, input);
@@ -886,10 +929,14 @@
   Dct16Stages<ButterflyRotation_8>(s);
   Dct32Stages<ButterflyRotation_8>(s);
 
-  if (transpose) {
+  if (is_row) {
+    const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
     for (int idx = 0; idx < 32; idx += 8) {
       int16x8_t output[8];
       Transpose8x8(&s[idx], output);
+      for (int i = 0; i < 8; ++i) {
+        output[i] = vqrshlq_s16(output[i], v_row_shift);
+      }
       StoreDst<16, 8>(dst, step, idx, output);
     }
   } else {
@@ -899,12 +946,13 @@
 
 // Allow the compiler to call this function instead of force inlining. Tests
 // show the performance is slightly faster.
-void Dct64_NEON(void* dest, const void* source, int32_t step, bool transpose) {
+void Dct64_NEON(void* dest, const void* source, int32_t step, bool is_row,
+                int row_shift) {
   auto* const dst = static_cast<int16_t*>(dest);
   const auto* const src = static_cast<const int16_t*>(source);
   int16x8_t s[64], x[32];
 
-  if (transpose) {
+  if (is_row) {
     // The last 32 values of every row are always zero if the |tx_width| is
     // 64.
     for (int idx = 0; idx < 32; idx += 8) {
@@ -1105,10 +1153,14 @@
   }
   //-- end dct 64 stages
 
-  if (transpose) {
+  if (is_row) {
+    const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
     for (int idx = 0; idx < 64; idx += 8) {
       int16x8_t output[8];
       Transpose8x8(&s[idx], output);
+      for (int i = 0; i < 8; ++i) {
+        output[i] = vqrshlq_s16(output[i], v_row_shift);
+      }
       StoreDst<16, 8>(dst, step, idx, output);
     }
   } else {
@@ -1141,29 +1193,24 @@
     }
   }
 
-  const int16x4_t kAdst4Multiplier_0 = vdup_n_s16(kAdst4Multiplier[0]);
-  const int16x4_t kAdst4Multiplier_1 = vdup_n_s16(kAdst4Multiplier[1]);
-  const int16x4_t kAdst4Multiplier_2 = vdup_n_s16(kAdst4Multiplier[2]);
-  const int16x4_t kAdst4Multiplier_3 = vdup_n_s16(kAdst4Multiplier[3]);
-
   // stage 1.
-  s[5] = vmull_s16(kAdst4Multiplier_1, vget_low_s16(x[3]));
-  s[6] = vmull_s16(kAdst4Multiplier_3, vget_low_s16(x[3]));
+  s[5] = vmull_n_s16(vget_low_s16(x[3]), kAdst4Multiplier[1]);
+  s[6] = vmull_n_s16(vget_low_s16(x[3]), kAdst4Multiplier[3]);
 
   // stage 2.
   const int32x4_t a7 = vsubl_s16(vget_low_s16(x[0]), vget_low_s16(x[2]));
   const int32x4_t b7 = vaddw_s16(a7, vget_low_s16(x[3]));
 
   // stage 3.
-  s[0] = vmull_s16(kAdst4Multiplier_0, vget_low_s16(x[0]));
-  s[1] = vmull_s16(kAdst4Multiplier_1, vget_low_s16(x[0]));
+  s[0] = vmull_n_s16(vget_low_s16(x[0]), kAdst4Multiplier[0]);
+  s[1] = vmull_n_s16(vget_low_s16(x[0]), kAdst4Multiplier[1]);
   // s[0] = s[0] + s[3]
-  s[0] = vmlal_s16(s[0], kAdst4Multiplier_3, vget_low_s16(x[2]));
+  s[0] = vmlal_n_s16(s[0], vget_low_s16(x[2]), kAdst4Multiplier[3]);
   // s[1] = s[1] - s[4]
-  s[1] = vmlsl_s16(s[1], kAdst4Multiplier_0, vget_low_s16(x[2]));
+  s[1] = vmlsl_n_s16(s[1], vget_low_s16(x[2]), kAdst4Multiplier[0]);
 
-  s[3] = vmull_s16(kAdst4Multiplier_2, vget_low_s16(x[1]));
-  s[2] = vmulq_s32(vmovl_s16(kAdst4Multiplier_2), b7);
+  s[3] = vmull_n_s16(vget_low_s16(x[1]), kAdst4Multiplier[2]);
+  s[2] = vmulq_n_s32(b7, kAdst4Multiplier[2]);
 
   // stage 4.
   s[0] = vaddq_s32(s[0], s[5]);
@@ -1200,6 +1247,82 @@
   }
 }
 
+alignas(8) constexpr int16_t kAdst4DcOnlyMultiplier[4] = {1321, 2482, 3344,
+                                                          2482};
+
+LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, const void* source,
+                                       int non_zero_coeff_count,
+                                       bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int32x4_t s[2];
+
+  const int16x4_t v_src0 = vdup_n_s16(src[0]);
+  const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
+  const int16x4_t v_src_round =
+      vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+  const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
+  const int16x4_t kAdst4DcOnlyMultipliers = vld1_s16(kAdst4DcOnlyMultiplier);
+  s[1] = vdupq_n_s32(0);
+
+  // s0*k0 s0*k1 s0*k2 s0*k1
+  s[0] = vmull_s16(kAdst4DcOnlyMultipliers, v_src);
+  // 0     0     0     s0*k0
+  s[1] = vextq_s32(s[1], s[0], 1);
+
+  const int32x4_t x3 = vaddq_s32(s[0], s[1]);
+  const int16x4_t dst_0 = vqrshrn_n_s32(x3, 12);
+
+  // vqrshlq_s16 will shift right if shift value is negative.
+  vst1_s16(dst, vqrshl_s16(dst_0, vdup_n_s16(-row_shift)));
+
+  return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, const void* source,
+                                             int non_zero_coeff_count,
+                                             int width) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  int32x4_t s[4];
+
+  int i = 0;
+  do {
+    const int16x4_t v_src = vld1_s16(&src[i]);
+
+    s[0] = vmull_n_s16(v_src, kAdst4Multiplier[0]);
+    s[1] = vmull_n_s16(v_src, kAdst4Multiplier[1]);
+    s[2] = vmull_n_s16(v_src, kAdst4Multiplier[2]);
+
+    const int32x4_t x0 = s[0];
+    const int32x4_t x1 = s[1];
+    const int32x4_t x2 = s[2];
+    const int32x4_t x3 = vaddq_s32(s[0], s[1]);
+    const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12);
+    const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12);
+    const int16x4_t dst_2 = vqrshrn_n_s32(x2, 12);
+    const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12);
+
+    vst1_s16(&dst[i], dst_0);
+    vst1_s16(&dst[i + width * 1], dst_1);
+    vst1_s16(&dst[i + width * 2], dst_2);
+    vst1_s16(&dst[i + width * 3], dst_3);
+
+    i += 4;
+  } while (i < width);
+
+  return true;
+}
+
 template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
 LIBGAV1_ALWAYS_INLINE void Adst8_NEON(void* dest, const void* source,
                                       int32_t step, bool transpose) {
@@ -1290,15 +1413,132 @@
   }
 }
 
+LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, const void* source,
+                                       int non_zero_coeff_count,
+                                       bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[8];
+
+  const int16x8_t v_src = vdupq_n_s16(src[0]);
+  const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0);
+  const int16x8_t v_src_round =
+      vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
+  // stage 1.
+  s[1] = vbslq_s16(v_mask, v_src_round, v_src);
+
+  // stage 2.
+  ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
+
+  // stage 3.
+  s[4] = s[0];
+  s[5] = s[1];
+
+  // stage 4.
+  ButterflyRotation_4(&s[4], &s[5], 48, true);
+
+  // stage 5.
+  s[2] = s[0];
+  s[3] = s[1];
+  s[6] = s[4];
+  s[7] = s[5];
+
+  // stage 6.
+  ButterflyRotation_4(&s[2], &s[3], 32, true);
+  ButterflyRotation_4(&s[6], &s[7], 32, true);
+
+  // stage 7.
+  int16x8_t x[8];
+  x[0] = s[0];
+  x[1] = vqnegq_s16(s[4]);
+  x[2] = s[6];
+  x[3] = vqnegq_s16(s[2]);
+  x[4] = s[3];
+  x[5] = vqnegq_s16(s[7]);
+  x[6] = s[5];
+  x[7] = vqnegq_s16(s[1]);
+
+  for (int i = 0; i < 8; ++i) {
+    // vqrshlq_s16 will shift right if shift value is negative.
+    x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift));
+    vst1q_lane_s16(&dst[i], x[i], 0);
+  }
+
+  return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, const void* source,
+                                             int non_zero_coeff_count,
+                                             int width) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[8];
+
+  int i = 0;
+  do {
+    const int16x8_t v_src = vld1q_s16(&src[i]);
+    // stage 1.
+    s[1] = v_src;
+
+    // stage 2.
+    ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
+
+    // stage 3.
+    s[4] = s[0];
+    s[5] = s[1];
+
+    // stage 4.
+    ButterflyRotation_4(&s[4], &s[5], 48, true);
+
+    // stage 5.
+    s[2] = s[0];
+    s[3] = s[1];
+    s[6] = s[4];
+    s[7] = s[5];
+
+    // stage 6.
+    ButterflyRotation_4(&s[2], &s[3], 32, true);
+    ButterflyRotation_4(&s[6], &s[7], 32, true);
+
+    // stage 7.
+    int16x8_t x[8];
+    x[0] = s[0];
+    x[1] = vqnegq_s16(s[4]);
+    x[2] = s[6];
+    x[3] = vqnegq_s16(s[2]);
+    x[4] = s[3];
+    x[5] = vqnegq_s16(s[7]);
+    x[6] = s[5];
+    x[7] = vqnegq_s16(s[1]);
+
+    for (int j = 0; j < 8; ++j) {
+      vst1_s16(&dst[j * width], vget_low_s16(x[j]));
+    }
+    i += 4;
+    dst += 4;
+  } while (i < width);
+
+  return true;
+}
+
 template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
 LIBGAV1_ALWAYS_INLINE void Adst16_NEON(void* dest, const void* source,
-                                       int32_t step, bool transpose) {
+                                       int32_t step, bool is_row,
+                                       int row_shift) {
   auto* const dst = static_cast<int16_t*>(dest);
   const auto* const src = static_cast<const int16_t*>(source);
   int16x8_t s[16], x[16];
 
   if (stage_is_rectangular) {
-    if (transpose) {
+    if (is_row) {
       int16x8_t input[4];
       LoadSrc<16, 4>(src, step, 0, input);
       Transpose8x4To4x8(input, x);
@@ -1308,7 +1548,7 @@
       LoadSrc<8, 16>(src, step, 0, x);
     }
   } else {
-    if (transpose) {
+    if (is_row) {
       for (int idx = 0; idx < 16; idx += 8) {
         int16x8_t input[8];
         LoadSrc<16, 8>(src, step, idx, input);
@@ -1414,20 +1654,31 @@
   x[15] = vqnegq_s16(s[1]);
 
   if (stage_is_rectangular) {
-    if (transpose) {
+    if (is_row) {
+      const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
       int16x8_t output[4];
       Transpose4x8To8x4(x, output);
+      for (int i = 0; i < 4; ++i) {
+        output[i] = vqrshlq_s16(output[i], v_row_shift);
+      }
       StoreDst<16, 4>(dst, step, 0, output);
       Transpose4x8To8x4(&x[8], output);
+      for (int i = 0; i < 4; ++i) {
+        output[i] = vqrshlq_s16(output[i], v_row_shift);
+      }
       StoreDst<16, 4>(dst, step, 8, output);
     } else {
       StoreDst<8, 16>(dst, step, 0, x);
     }
   } else {
-    if (transpose) {
+    if (is_row) {
+      const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
       for (int idx = 0; idx < 16; idx += 8) {
         int16x8_t output[8];
         Transpose8x8(&x[idx], output);
+        for (int i = 0; i < 8; ++i) {
+          output[i] = vqrshlq_s16(output[i], v_row_shift);
+        }
         StoreDst<16, 8>(dst, step, idx, output);
       }
     } else {
@@ -1436,6 +1687,122 @@
   }
 }
 
+LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(int16x8_t* s, int16x8_t* x) {
+  // stage 2.
+  ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true);
+
+  // stage 3.
+  s[8] = s[0];
+  s[9] = s[1];
+
+  // stage 4.
+  ButterflyRotation_4(&s[8], &s[9], 56, true);
+
+  // stage 5.
+  s[4] = s[0];
+  s[12] = s[8];
+  s[5] = s[1];
+  s[13] = s[9];
+
+  // stage 6.
+  ButterflyRotation_4(&s[4], &s[5], 48, true);
+  ButterflyRotation_4(&s[12], &s[13], 48, true);
+
+  // stage 7.
+  s[2] = s[0];
+  s[6] = s[4];
+  s[10] = s[8];
+  s[14] = s[12];
+  s[3] = s[1];
+  s[7] = s[5];
+  s[11] = s[9];
+  s[15] = s[13];
+
+  // stage 8.
+  ButterflyRotation_4(&s[2], &s[3], 32, true);
+  ButterflyRotation_4(&s[6], &s[7], 32, true);
+  ButterflyRotation_4(&s[10], &s[11], 32, true);
+  ButterflyRotation_4(&s[14], &s[15], 32, true);
+
+  // stage 9.
+  x[0] = s[0];
+  x[1] = vqnegq_s16(s[8]);
+  x[2] = s[12];
+  x[3] = vqnegq_s16(s[4]);
+  x[4] = s[6];
+  x[5] = vqnegq_s16(s[14]);
+  x[6] = s[10];
+  x[7] = vqnegq_s16(s[2]);
+  x[8] = s[3];
+  x[9] = vqnegq_s16(s[11]);
+  x[10] = s[15];
+  x[11] = vqnegq_s16(s[7]);
+  x[12] = s[5];
+  x[13] = vqnegq_s16(s[13]);
+  x[14] = s[9];
+  x[15] = vqnegq_s16(s[1]);
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, const void* source,
+                                        int non_zero_coeff_count,
+                                        bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  int16x8_t s[16];
+  int16x8_t x[16];
+
+  const int16x8_t v_src = vdupq_n_s16(src[0]);
+  const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0);
+  const int16x8_t v_src_round =
+      vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
+  // stage 1.
+  s[1] = vbslq_s16(v_mask, v_src_round, v_src);
+
+  Adst16DcOnlyInternal(s, x);
+
+  for (int i = 0; i < 16; ++i) {
+    // vqrshlq_s16 will shift right if shift value is negative.
+    x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift));
+    vst1q_lane_s16(&dst[i], x[i], 0);
+  }
+
+  return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest, const void* source,
+                                              int non_zero_coeff_count,
+                                              int width) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  int i = 0;
+  do {
+    int16x8_t s[16];
+    int16x8_t x[16];
+    const int16x8_t v_src = vld1q_s16(&src[i]);
+    // stage 1.
+    s[1] = v_src;
+
+    Adst16DcOnlyInternal(s, x);
+
+    for (int j = 0; j < 16; ++j) {
+      vst1_s16(&dst[j * width], vget_low_s16(x[j]));
+    }
+    i += 4;
+    dst += 4;
+  } while (i < width);
+
+  return true;
+}
+
 //------------------------------------------------------------------------------
 // Identity Transforms.
 
@@ -1472,46 +1839,115 @@
   }
 }
 
-LIBGAV1_ALWAYS_INLINE void Identity4ColumnStoreToFrame(
+LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, const void* source,
+                                           int non_zero_coeff_count,
+                                           bool should_round, int tx_height) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  const int16x4_t v_src0 = vdup_n_s16(src[0]);
+  const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
+  const int16x4_t v_src_round =
+      vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+  const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
+  const int shift = tx_height < 16 ? 0 : 1;
+  const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+  const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier);
+  const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+  const int32x4_t v_src_mult_lo = vmlal_s16(v_dual_round, v_src, v_multiplier);
+  const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift);
+  vst1_lane_s16(&dst[0], vqmovn_s32(dst_0), 0);
+  return true;
+}
+
+template <int identity_size>
+LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
     const int tx_width, const int tx_height, const int16_t* source) {
   const int stride = frame.columns();
   uint8_t* dst = frame[start_y] + start_x;
 
-  if (tx_width == 4) {
-    const uint8x8_t zero = vdup_n_u8(0);
-    for (int i = 0; i < tx_height; ++i) {
-      const int16x4_t residual = vld1_s16(&source[i * tx_width]);
-      const int16x4_t residual_fraction =
-          vqrdmulh_n_s16(residual, kIdentity4MultiplierFraction << 3);
-      const int16x4_t v_dst_i = vqadd_s16(residual, residual_fraction);
-      const int16x8_t frame_data =
-          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
-      const int16x4_t a = vrshr_n_s16(v_dst_i, 4);
-      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
-      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
-      StoreLo4(dst, d);
-      dst += stride;
+  if (identity_size < 32) {
+    if (tx_width == 4) {
+      uint8x8_t frame_data = vdup_n_u8(0);
+      int i = 0;
+      do {
+        const int16x4_t v_src = vld1_s16(&source[i * tx_width]);
+
+        int16x4_t v_dst_i;
+        if (identity_size == 4) {
+          const int16x4_t v_src_fraction =
+              vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3);
+          v_dst_i = vqadd_s16(v_src, v_src_fraction);
+        } else if (identity_size == 8) {
+          v_dst_i = vqadd_s16(v_src, v_src);
+        } else {  // identity_size == 16
+          const int16x4_t v_src_mult =
+              vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 4);
+          const int16x4_t v_srcx2 = vqadd_s16(v_src, v_src);
+          v_dst_i = vqadd_s16(v_srcx2, v_src_mult);
+        }
+
+        frame_data = Load4<0>(dst, frame_data);
+        const int16x4_t a = vrshr_n_s16(v_dst_i, 4);
+        const uint16x8_t b =
+            vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data);
+        const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
+        StoreLo4(dst, d);
+        dst += stride;
+      } while (++i < tx_height);
+    } else {
+      int i = 0;
+      do {
+        const int row = i * tx_width;
+        int j = 0;
+        do {
+          const int16x8_t v_src = vld1q_s16(&source[row + j]);
+
+          int16x8_t v_dst_i;
+          if (identity_size == 4) {
+            const int16x8_t v_src_fraction =
+                vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3);
+            v_dst_i = vqaddq_s16(v_src, v_src_fraction);
+          } else if (identity_size == 8) {
+            v_dst_i = vqaddq_s16(v_src, v_src);
+          } else {  // identity_size == 16
+            const int16x8_t v_src_mult =
+                vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 4);
+            const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src);
+            v_dst_i = vqaddq_s16(v_src_mult, v_srcx2);
+          }
+
+          const uint8x8_t frame_data = vld1_u8(dst + j);
+          const int16x8_t a = vrshrq_n_s16(v_dst_i, 4);
+          const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
+          const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
+          vst1_u8(dst + j, d);
+          j += 8;
+        } while (j < tx_width);
+        dst += stride;
+      } while (++i < tx_height);
     }
   } else {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const int row = i * tx_width;
       int j = 0;
       do {
-        const int16x8_t residual = vld1q_s16(&source[row + j]);
-        const int16x8_t residual_fraction =
-            vqrdmulhq_n_s16(residual, kIdentity4MultiplierFraction << 3);
-        const int16x8_t v_dst_i = vqaddq_s16(residual, residual_fraction);
-        const int16x8_t frame_data =
-            vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
-        const int16x8_t a = vrshrq_n_s16(v_dst_i, 4);
-        const int16x8_t b = vqaddq_s16(a, frame_data);
-        const uint8x8_t d = vqmovun_s16(b);
+        const int16x8_t v_dst_i = vld1q_s16(&source[row + j]);
+        const uint8x8_t frame_data = vld1_u8(dst + j);
+        const int16x8_t a = vrshrq_n_s16(v_dst_i, 2);
+        const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
+        const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
         vst1_u8(dst + j, d);
         j += 8;
       } while (j < tx_width);
       dst += stride;
-    }
+    } while (++i < tx_height);
   }
 }
 
@@ -1522,25 +1958,27 @@
   uint8_t* dst = frame[start_y] + start_x;
 
   if (tx_width == 4) {
-    const uint8x8_t zero = vdup_n_u8(0);
-    for (int i = 0; i < tx_height; ++i) {
+    uint8x8_t frame_data = vdup_n_u8(0);
+    int i = 0;
+    do {
       const int16x4_t v_src = vld1_s16(&source[i * tx_width]);
       const int16x4_t v_src_mult =
           vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3);
-      const int16x8_t frame_data =
-          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
       const int16x4_t v_dst_row = vqadd_s16(v_src, v_src_mult);
       const int16x4_t v_src_mult2 =
           vqrdmulh_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3);
       const int16x4_t v_dst_col = vqadd_s16(v_dst_row, v_src_mult2);
+      frame_data = Load4<0>(dst, frame_data);
       const int16x4_t a = vrshr_n_s16(v_dst_col, 4);
-      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
-      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
+      const uint16x8_t b =
+          vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data);
+      const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
       StoreLo4(dst, d);
       dst += stride;
-    }
+    } while (++i < tx_height);
   } else {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const int row = i * tx_width;
       int j = 0;
       do {
@@ -1550,17 +1988,16 @@
         const int16x8_t v_dst_row = vqaddq_s16(v_src_round, v_src_round);
         const int16x8_t v_src_mult2 =
             vqrdmulhq_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3);
-        const int16x8_t frame_data =
-            vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
         const int16x8_t v_dst_col = vqaddq_s16(v_dst_row, v_src_mult2);
+        const uint8x8_t frame_data = vld1_u8(dst + j);
         const int16x8_t a = vrshrq_n_s16(v_dst_col, 4);
-        const int16x8_t b = vqaddq_s16(a, frame_data);
-        const uint8x8_t d = vqmovun_s16(b);
+        const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
+        const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
         vst1_u8(dst + j, d);
         j += 8;
       } while (j < tx_width);
       dst += stride;
-    }
+    } while (++i < tx_height);
   }
 }
 
@@ -1593,41 +2030,24 @@
   }
 }
 
-LIBGAV1_ALWAYS_INLINE void Identity8ColumnStoreToFrame_NEON(
-    Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
-  const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
-
-  if (tx_width == 4) {
-    const uint8x8_t zero = vdup_n_u8(0);
-    for (int i = 0; i < tx_height; ++i) {
-      const int16x4_t residual = vld1_s16(&source[i * tx_width]);
-      const int16x4_t v_dst_i = vqadd_s16(residual, residual);
-      const int16x8_t frame_data =
-          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
-      const int16x4_t a = vrshr_n_s16(v_dst_i, 4);
-      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
-      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
-      StoreLo4(dst, d);
-      dst += stride;
-    }
-  } else {
-    for (int i = 0; i < tx_height; ++i) {
-      const int row = i * tx_width;
-      for (int j = 0; j < tx_width; j += 8) {
-        const int16x8_t residual = vld1q_s16(&source[row + j]);
-        const int16x8_t v_dst_i = vqaddq_s16(residual, residual);
-        const int16x8_t frame_data =
-            vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
-        const int16x8_t a = vrshrq_n_s16(v_dst_i, 4);
-        const int16x8_t b = vqaddq_s16(a, frame_data);
-        const uint8x8_t d = vqmovun_s16(b);
-        vst1_u8(dst + j, d);
-      }
-      dst += stride;
-    }
+LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, const void* source,
+                                           int non_zero_coeff_count,
+                                           bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
   }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  const int16x4_t v_src0 = vdup_n_s16(src[0]);
+  const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
+  const int16x4_t v_src_round =
+      vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+  const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
+  const int32x4_t v_srcx2 = vaddl_s16(v_src, v_src);
+  const int32x4_t dst_0 = vqrshlq_s32(v_srcx2, vdupq_n_s32(-row_shift));
+  vst1_lane_s16(&dst[0], vqmovn_s32(dst_0), 0);
+  return true;
 }
 
 LIBGAV1_ALWAYS_INLINE void Identity16Row_NEON(void* dest, const void* source,
@@ -1635,16 +2055,15 @@
   auto* const dst = static_cast<int16_t*>(dest);
   const auto* const src = static_cast<const int16_t*>(source);
   const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
-  const int16x4_t v_multiplier = vdup_n_s16(kIdentity16Multiplier);
   const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
 
   for (int i = 0; i < 4; ++i) {
     for (int j = 0; j < 2; ++j) {
       const int16x8_t v_src = vld1q_s16(&src[i * step + j * 8]);
       const int32x4_t v_src_mult_lo =
-          vmlal_s16(v_dual_round, vget_low_s16(v_src), v_multiplier);
-      const int32x4_t v_src_mult_hi =
-          vmlal_s16(v_dual_round, vget_high_s16(v_src), v_multiplier);
+          vmlal_n_s16(v_dual_round, vget_low_s16(v_src), kIdentity16Multiplier);
+      const int32x4_t v_src_mult_hi = vmlal_n_s16(
+          v_dual_round, vget_high_s16(v_src), kIdentity16Multiplier);
       const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
       const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift);
       vst1q_s16(&dst[i * step + j * 8],
@@ -1653,49 +2072,29 @@
   }
 }
 
-LIBGAV1_ALWAYS_INLINE void Identity16ColumnStoreToFrame_NEON(
-    Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
-  const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
-
-  if (tx_width == 4) {
-    const uint8x8_t zero = vdup_n_u8(0);
-    for (int i = 0; i < tx_height; ++i) {
-      const int16x4_t v_src = vld1_s16(&source[i * tx_width]);
-      const int16x4_t v_src_mult =
-          vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 4);
-      const int16x8_t frame_data =
-          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
-
-      const int16x4_t v_srcx2 = vqadd_s16(v_src, v_src);
-      const int16x4_t v_dst_i = vqadd_s16(v_srcx2, v_src_mult);
-
-      const int16x4_t a = vrshr_n_s16(v_dst_i, 4);
-      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
-      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
-      StoreLo4(dst, d);
-      dst += stride;
-    }
-  } else {
-    for (int i = 0; i < tx_height; ++i) {
-      const int row = i * tx_width;
-      for (int j = 0; j < tx_width; j += 8) {
-        const int16x8_t v_src = vld1q_s16(&source[row + j]);
-        const int16x8_t v_src_mult =
-            vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 4);
-        const int16x8_t frame_data =
-            vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
-        const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src);
-        const int16x8_t v_dst_i = vqaddq_s16(v_src_mult, v_srcx2);
-        const int16x8_t a = vrshrq_n_s16(v_dst_i, 4);
-        const int16x8_t b = vqaddq_s16(a, frame_data);
-        const uint8x8_t d = vqmovun_s16(b);
-        vst1_u8(dst + j, d);
-      }
-      dst += stride;
-    }
+LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, const void* source,
+                                            int non_zero_coeff_count,
+                                            bool should_round, int shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
   }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  const int16x4_t v_src0 = vdup_n_s16(src[0]);
+  const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
+  const int16x4_t v_src_round =
+      vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+  const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
+  const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
+  const int16x4_t v_multiplier = vdup_n_s16(kIdentity16Multiplier);
+  const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
+  const int32x4_t v_src_mult_lo =
+      vmlal_s16(v_dual_round, (v_src), v_multiplier);
+  const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift);
+  vst1_lane_s16(&dst[0], vqmovn_s32(dst_0), 0);
+  return true;
 }
 
 LIBGAV1_ALWAYS_INLINE void Identity32Row16_NEON(void* dest, const void* source,
@@ -1717,27 +2116,22 @@
   }
 }
 
-LIBGAV1_ALWAYS_INLINE void Identity32ColumnStoreToFrame(
-    Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
-  const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
-
-  for (int i = 0; i < tx_height; ++i) {
-    const int row = i * tx_width;
-    int j = 0;
-    do {
-      const int16x8_t v_dst_i = vld1q_s16(&source[row + j]);
-      const int16x8_t frame_data =
-          vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst + j)));
-      const int16x8_t a = vrshrq_n_s16(v_dst_i, 2);
-      const int16x8_t b = vqaddq_s16(a, frame_data);
-      const uint8x8_t d = vqmovun_s16(b);
-      vst1_u8(dst + j, d);
-      j += 8;
-    } while (j < tx_width);
-    dst += stride;
+LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest, const void* source,
+                                            int non_zero_coeff_count) {
+  if (non_zero_coeff_count > 1) {
+    return false;
   }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  const int16x4_t v_src0 = vdup_n_s16(src[0]);
+  const int16x4_t v_src = vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
+  // When combining the identity32 multiplier with the row shift, the
+  // calculation for tx_height equal to 16 can be simplified from
+  // ((A * 4) + 1) >> 1) to (A * 2).
+  const int16x4_t v_dst_0 = vqadd_s16(v_src, v_src);
+  vst1_lane_s16(&dst[0], v_dst_0, 0);
+  return true;
 }
 
 //------------------------------------------------------------------------------
@@ -1857,13 +2251,11 @@
   // Store to frame.
   uint8x8_t frame_data = vdup_n_u8(0);
   for (int row = 0; row < 4; row += 2) {
-    frame_data = LoadLo4(dst, frame_data);
-    frame_data = LoadHi4(dst + dst_stride, frame_data);
-    const int16x8_t a = vreinterpretq_s16_u16(vmovl_u8(frame_data));
+    frame_data = Load4<0>(dst, frame_data);
+    frame_data = Load4<1>(dst + dst_stride, frame_data);
     const int16x8_t residual = vcombine_s16(s[row], s[row + 1]);
-    // Saturate to prevent overflowing int16_t
-    const int16x8_t b = vqaddq_s16(a, residual);
-    frame_data = vqmovun_s16(b);
+    const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(residual), frame_data);
+    frame_data = vqmovun_s16(vreinterpretq_s16_u16(b));
     StoreLo4(dst, frame_data);
     dst += dst_stride;
     StoreHi4(dst, frame_data);
@@ -1958,41 +2350,42 @@
   }
 }
 
-template <bool enable_flip_rows = false>
+template <int tx_height, bool enable_flip_rows = false>
 LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source,
-    TransformType tx_type) {
+    const int tx_width, const int16_t* source, TransformType tx_type) {
   const bool flip_rows =
       enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
   const int stride = frame.columns();
   uint8_t* dst = frame[start_y] + start_x;
 
-  if (tx_width == 4) {
-    const uint8x8_t zero = vdup_n_u8(0);
+  // Enable for 4x4, 4x8, 4x16
+  if (tx_height < 32 && tx_width == 4) {
+    uint8x8_t frame_data = vdup_n_u8(0);
     for (int i = 0; i < tx_height; ++i) {
       const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4;
       const int16x4_t residual = vld1_s16(&source[row]);
-      const int16x8_t frame_data =
-          vreinterpretq_s16_u16(vmovl_u8(LoadLo4(dst, zero)));
+      frame_data = Load4<0>(dst, frame_data);
       const int16x4_t a = vrshr_n_s16(residual, 4);
-      const int16x4_t b = vqadd_s16(a, vget_low_s16(frame_data));
-      const uint8x8_t d = vqmovun_s16(vcombine_s16(b, b));
+      const uint16x8_t b =
+          vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data);
+      const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
       StoreLo4(dst, d);
       dst += stride;
     }
-  } else if (tx_width == 8) {
+    // Enable for 8x4, 8x8, 8x16, 8x32
+  } else if (tx_height < 64 && tx_width == 8) {
     for (int i = 0; i < tx_height; ++i) {
       const int row = flip_rows ? (tx_height - i - 1) * 8 : i * 8;
       const int16x8_t residual = vld1q_s16(&source[row]);
-      const int16x8_t frame_data =
-          vreinterpretq_s16_u16(vmovl_u8(vld1_u8(dst)));
+      const uint8x8_t frame_data = vld1_u8(dst);
       const int16x8_t a = vrshrq_n_s16(residual, 4);
-      const int16x8_t b = vqaddq_s16(a, frame_data);
-      const uint8x8_t d = vqmovun_s16(b);
+      const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
+      const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
       vst1_u8(dst, d);
       dst += stride;
     }
+    // Remaining widths >= 16.
   } else {
     for (int i = 0; i < tx_height; ++i) {
       const int y = start_y + i;
@@ -2005,13 +2398,13 @@
         const uint8x16_t frame_data = vld1q_u8(frame[y] + x);
         const int16x8_t a = vrshrq_n_s16(residual, 4);
         const int16x8_t a_hi = vrshrq_n_s16(residual_hi, 4);
-        const int16x8_t d =
-            vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(frame_data)));
-        const int16x8_t d_hi =
-            vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(frame_data)));
-        const int16x8_t e = vqaddq_s16(a, d);
-        const int16x8_t e_hi = vqaddq_s16(a_hi, d_hi);
-        vst1q_u8(frame[y] + x, vcombine_u8(vqmovun_s16(e), vqmovun_s16(e_hi)));
+        const uint16x8_t b =
+            vaddw_u8(vreinterpretq_u16_s16(a), vget_low_u8(frame_data));
+        const uint16x8_t b_hi =
+            vaddw_u8(vreinterpretq_u16_s16(a_hi), vget_high_u8(frame_data));
+        vst1q_u8(frame[y] + x,
+                 vcombine_u8(vqmovun_s16(vreinterpretq_s16_u16(b)),
+                             vqmovun_s16(vreinterpretq_s16_u16(b_hi))));
         j += 16;
       } while (j < tx_width);
     }
@@ -2022,7 +2415,7 @@
                             void* src_buffer, int start_x, int start_y,
                             void* dst_frame, bool is_row,
                             int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
@@ -2036,7 +2429,8 @@
       return;
     }
 
-    const int num_rows = tx_height;
+    const int num_rows =
+        GetNumRows<4>(tx_type, tx_height, non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<4>(src, num_rows);
     }
@@ -2065,27 +2459,29 @@
     FlipColumns<4>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d dct4 columns in parallel.
-    Dct4_NEON<ButterflyRotation_4, false>(&src[0], &src[0], tx_width,
-                                          /*transpose=*/false);
-  } else {
-    // Process 8 1d dct4 columns in parallel per iteration.
-    int i = 0;
-    do {
-      Dct4_NEON<ButterflyRotation_8, true>(&src[i], &src[i], tx_width,
-                                           /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!DctDcOnlyColumn<4>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d dct4 columns in parallel.
+      Dct4_NEON<ButterflyRotation_4, false>(&src[0], &src[0], tx_width,
+                                            /*transpose=*/false);
+    } else {
+      // Process 8 1d dct4 columns in parallel per iteration.
+      int i = 0;
+      do {
+        Dct4_NEON<ButterflyRotation_8, true>(&src[i], &src[i], tx_width,
+                                             /*transpose=*/false);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
-  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 4, src, tx_type);
+  StoreToFrameWithRound<4>(frame, start_x, start_y, tx_width, src, tx_type);
 }
 
 void Dct8TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                             void* src_buffer, int start_x, int start_y,
                             void* dst_frame, bool is_row,
                             int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
@@ -2099,7 +2495,8 @@
       return;
     }
 
-    const int num_rows = tx_height;
+    const int num_rows =
+        GetNumRows<8>(tx_type, tx_height, non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<8>(src, num_rows);
     }
@@ -2128,27 +2525,29 @@
     FlipColumns<8>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d dct8 columns in parallel.
-    Dct8_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
-                                         /*transpose=*/false);
-  } else {
-    // Process 8 1d dct8 columns in parallel per iteration.
-    int i = 0;
-    do {
-      Dct8_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
-                                            /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!DctDcOnlyColumn<8>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d dct8 columns in parallel.
+      Dct8_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                           /*transpose=*/false);
+    } else {
+      // Process 8 1d dct8 columns in parallel per iteration.
+      int i = 0;
+      do {
+        Dct8_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                              /*transpose=*/false);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
-  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 8, src, tx_type);
+  StoreToFrameWithRound<8>(frame, start_x, start_y, tx_width, src, tx_type);
 }
 
 void Dct16TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                              void* src_buffer, int start_x, int start_y,
                              void* dst_frame, bool is_row,
                              int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
@@ -2162,7 +2561,8 @@
       return;
     }
 
-    const int num_rows = std::min(tx_height, 32);
+    const int num_rows =
+        GetNumRows<16>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<16>(src, num_rows);
     }
@@ -2170,18 +2570,16 @@
     if (num_rows <= 4) {
       // Process 4 1d dct16 rows in parallel.
       Dct16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 16,
-                                            /*transpose=*/true);
+                                            /*is_row=*/true, row_shift);
     } else {
       int i = 0;
       do {
         // Process 8 1d dct16 rows in parallel per iteration.
         Dct16_NEON<ButterflyRotation_8, false>(&src[i * 16], &src[i * 16], 16,
-                                               /*transpose=*/true);
+                                               /*is_row=*/true, row_shift);
         i += 8;
       } while (i < num_rows);
     }
-    // row_shift is always non zero here.
-    RowShift<16>(src, num_rows, row_shift);
 
     return;
   }
@@ -2191,27 +2589,30 @@
     FlipColumns<16>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d dct16 columns in parallel.
-    Dct16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
-                                          /*transpose=*/false);
-  } else {
-    int i = 0;
-    do {
-      // Process 8 1d dct16 columns in parallel per iteration.
-      Dct16_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
-                                             /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!DctDcOnlyColumn<16>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d dct16 columns in parallel.
+      Dct16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                            /*is_row=*/false, /*row_shift=*/0);
+    } else {
+      int i = 0;
+      do {
+        // Process 8 1d dct16 columns in parallel per iteration.
+        Dct16_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                               /*is_row=*/false,
+                                               /*row_shift=*/0);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
-  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 16, src, tx_type);
+  StoreToFrameWithRound<16>(frame, start_x, start_y, tx_width, src, tx_type);
 }
 
 void Dct32TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                              void* src_buffer, int start_x, int start_y,
                              void* dst_frame, bool is_row,
                              int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
@@ -2225,38 +2626,38 @@
       return;
     }
 
-    const int num_rows = std::min(tx_height, 32);
+    const int num_rows =
+        GetNumRows<32>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<32>(src, num_rows);
     }
     // Process 8 1d dct32 rows in parallel per iteration.
     int i = 0;
     do {
-      Dct32_NEON(&src[i * 32], &src[i * 32], 32, /*transpose=*/true);
+      Dct32_NEON(&src[i * 32], &src[i * 32], 32, /*is_row=*/true, row_shift);
       i += 8;
     } while (i < num_rows);
 
-    // row_shift is always non zero here.
-    RowShift<32>(src, num_rows, row_shift);
-
     return;
   }
 
   assert(!is_row);
-  // Process 8 1d dct32 columns in parallel per iteration.
-  int i = 0;
-  do {
-    Dct32_NEON(&src[i], &src[i], tx_width, /*transpose=*/false);
-    i += 8;
-  } while (i < tx_width);
-  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 32, src, tx_type);
+  if (!DctDcOnlyColumn<32>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    // Process 8 1d dct32 columns in parallel per iteration.
+    int i = 0;
+    do {
+      Dct32_NEON(&src[i], &src[i], tx_width, /*is_row=*/false, /*row_shift=*/0);
+      i += 8;
+    } while (i < tx_width);
+  }
+  StoreToFrameWithRound<32>(frame, start_x, start_y, tx_width, src, tx_type);
 }
 
 void Dct64TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                              void* src_buffer, int start_x, int start_y,
                              void* dst_frame, bool is_row,
                              int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
@@ -2270,44 +2671,53 @@
       return;
     }
 
-    const int num_rows = std::min(tx_height, 32);
+    const int num_rows =
+        GetNumRows<32>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<64>(src, num_rows);
     }
     // Process 8 1d dct64 rows in parallel per iteration.
     int i = 0;
     do {
-      Dct64_NEON(&src[i * 64], &src[i * 64], 64, /*transpose=*/true);
+      Dct64_NEON(&src[i * 64], &src[i * 64], 64, /*is_row=*/true, row_shift);
       i += 8;
     } while (i < num_rows);
-    // row_shift is always non zero here.
-    RowShift<64>(src, num_rows, row_shift);
 
     return;
   }
 
   assert(!is_row);
-  // Process 8 1d dct64 columns in parallel per iteration.
-  int i = 0;
-  do {
-    Dct64_NEON(&src[i], &src[i], tx_width, /*transpose=*/false);
-    i += 8;
-  } while (i < tx_width);
-  StoreToFrameWithRound(frame, start_x, start_y, tx_width, 64, src, tx_type);
+  if (!DctDcOnlyColumn<64>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    // Process 8 1d dct64 columns in parallel per iteration.
+    int i = 0;
+    do {
+      Dct64_NEON(&src[i], &src[i], tx_width, /*is_row=*/false, /*row_shift=*/0);
+      i += 8;
+    } while (i < tx_width);
+  }
+  StoreToFrameWithRound<64>(frame, start_x, start_y, tx_width, src, tx_type);
 }
 
 void Adst4TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                              void* src_buffer, int start_x, int start_y,
                              void* dst_frame, bool is_row,
                              int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
+    const uint8_t row_shift = static_cast<uint8_t>(tx_height == 16);
     const bool should_round = (tx_height == 8);
+
+    if (Adst4DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                    row_shift)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<4>(tx_type, tx_height, non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<4>(src, num_rows);
     }
@@ -2331,29 +2741,40 @@
     FlipColumns<4>(src, tx_width);
   }
 
-  // Process 4 1d adst4 columns in parallel per iteration.
-  int i = 0;
-  do {
-    Adst4_NEON<false>(&src[i], &src[i], tx_width, /*transpose=*/false);
-    i += 4;
-  } while (i < tx_width);
+  if (!Adst4DcOnlyColumn(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    // Process 4 1d adst4 columns in parallel per iteration.
+    int i = 0;
+    do {
+      Adst4_NEON<false>(&src[i], &src[i], tx_width, /*transpose=*/false);
+      i += 4;
+    } while (i < tx_width);
+  }
 
-  StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
-                                                   tx_width, 4, src, tx_type);
+  StoreToFrameWithRound<4, /*enable_flip_rows=*/true>(frame, start_x, start_y,
+                                                      tx_width, src, tx_type);
 }
 
 void Adst8TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                              void* src_buffer, int start_x, int start_y,
                              void* dst_frame, bool is_row,
                              int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
-    if (kShouldRound[tx_size]) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (Adst8DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                    row_shift)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<8>(tx_type, tx_height, non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<8>(src, num_rows);
     }
 
@@ -2371,7 +2792,6 @@
         i += 8;
       } while (i < num_rows);
     }
-    const uint8_t row_shift = kTransformRowShift[tx_size];
     if (row_shift > 0) {
       RowShift<8>(src, num_rows, row_shift);
     }
@@ -2383,56 +2803,62 @@
     FlipColumns<8>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d adst8 columns in parallel.
-    Adst8_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
-                                          /*transpose=*/false);
-  } else {
-    // Process 8 1d adst8 columns in parallel per iteration.
-    int i = 0;
-    do {
-      Adst8_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
-                                             /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!Adst8DcOnlyColumn(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d adst8 columns in parallel.
+      Adst8_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                            /*transpose=*/false);
+    } else {
+      // Process 8 1d adst8 columns in parallel per iteration.
+      int i = 0;
+      do {
+        Adst8_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                               /*transpose=*/false);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
-  StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
-                                                   tx_width, 8, src, tx_type);
+  StoreToFrameWithRound<8, /*enable_flip_rows=*/true>(frame, start_x, start_y,
+                                                      tx_width, src, tx_type);
 }
 
 void Adst16TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                               void* src_buffer, int start_x, int start_y,
                               void* dst_frame, bool is_row,
                               int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (Adst16DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                     row_shift)) {
+      return;
+    }
+
     const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+        GetNumRows<16>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<16>(src, num_rows);
     }
 
     if (num_rows <= 4) {
       // Process 4 1d adst16 rows in parallel.
       Adst16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 16,
-                                             /*transpose=*/true);
+                                             /*is_row=*/true, row_shift);
     } else {
       int i = 0;
       do {
         // Process 8 1d adst16 rows in parallel per iteration.
         Adst16_NEON<ButterflyRotation_8, false>(&src[i * 16], &src[i * 16], 16,
-                                                /*transpose=*/true);
+                                                /*is_row=*/true, row_shift);
         i += 8;
       } while (i < num_rows);
     }
-    const uint8_t row_shift = kTransformRowShift[tx_size];
-    // row_shift is always non zero here.
-    RowShift<16>(src, num_rows, row_shift);
-
     return;
   }
 
@@ -2441,28 +2867,31 @@
     FlipColumns<16>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d adst16 columns in parallel.
-    Adst16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
-                                           /*transpose=*/false);
-  } else {
-    int i = 0;
-    do {
-      // Process 8 1d adst16 columns in parallel per iteration.
-      Adst16_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
-                                              /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!Adst16DcOnlyColumn(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d adst16 columns in parallel.
+      Adst16_NEON<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                             /*is_row=*/false, /*row_shift=*/0);
+    } else {
+      int i = 0;
+      do {
+        // Process 8 1d adst16 columns in parallel per iteration.
+        Adst16_NEON<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                                /*is_row=*/false,
+                                                /*row_shift=*/0);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
-  StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
-                                                   tx_width, 16, src, tx_type);
+  StoreToFrameWithRound<16, /*enable_flip_rows=*/true>(frame, start_x, start_y,
+                                                       tx_width, src, tx_type);
 }
 
 void Identity4TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                                  void* src_buffer, int start_x, int start_y,
                                  void* dst_frame, bool is_row,
                                  int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
@@ -2475,8 +2904,15 @@
       return;
     }
 
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     const bool should_round = (tx_height == 8);
+
+    if (Identity4DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                        tx_height)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<4>(tx_type, tx_height, non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<4>(src, num_rows);
     }
@@ -2496,11 +2932,12 @@
     return;
   }
   assert(!is_row);
+  const int height = (non_zero_coeff_count == 1) ? 1 : tx_height;
   // Special case: Process row calculations during column transform call.
   if (tx_type == kTransformTypeIdentityIdentity &&
       (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) {
-    Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width,
-                                   /*tx_height=*/4, src);
+    Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width, height,
+                                   src);
     return;
   }
 
@@ -2508,15 +2945,14 @@
     FlipColumns<4>(src, tx_width);
   }
 
-  Identity4ColumnStoreToFrame(frame, start_x, start_y, tx_width,
-                              /*tx_height=*/4, src);
+  IdentityColumnStoreToFrame<4>(frame, start_x, start_y, tx_width, height, src);
 }
 
 void Identity8TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                                  void* src_buffer, int start_x, int start_y,
                                  void* dst_frame, bool is_row,
                                  int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
@@ -2528,8 +2964,17 @@
         tx_size == kTransformSize8x4) {
       return;
     }
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
-    if (kShouldRound[tx_size]) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (Identity8DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                        row_shift)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<8>(tx_type, tx_height, non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<8>(src, num_rows);
     }
 
@@ -2558,24 +3003,31 @@
   if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<8>(src, tx_width);
   }
-
-  Identity8ColumnStoreToFrame_NEON(frame, start_x, start_y, tx_width,
-                                   /*tx_height=*/8, src);
+  const int height = (non_zero_coeff_count == 1) ? 1 : tx_height;
+  IdentityColumnStoreToFrame<8>(frame, start_x, start_y, tx_width, height, src);
 }
 
 void Identity16TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
                                   void* src_buffer, int start_x, int start_y,
                                   void* dst_frame, bool is_row,
                                   int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (Identity16DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                         row_shift)) {
+      return;
+    }
+
     const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+        GetNumRows<16>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<16>(src, num_rows);
     }
     for (int i = 0; i < num_rows; i += 4) {
@@ -2589,22 +3041,21 @@
   if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<16>(src, tx_width);
   }
-  Identity16ColumnStoreToFrame_NEON(frame, start_x, start_y, tx_width,
-                                    /*tx_height=*/16, src);
+  const int height = (non_zero_coeff_count == 1) ? 1 : tx_height;
+  IdentityColumnStoreToFrame<16>(frame, start_x, start_y, tx_width, height,
+                                 src);
 }
 
-void Identity32TransformLoop_NEON(TransformType /*tx_type*/,
-                                  TransformSize tx_size, void* src_buffer,
-                                  int start_x, int start_y, void* dst_frame,
-                                  bool is_row, int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+void Identity32TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
+                                  void* src_buffer, int start_x, int start_y,
+                                  void* dst_frame, bool is_row,
+                                  int non_zero_coeff_count) {
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
-
     // When combining the identity32 multiplier with the row shift, the
     // calculations for tx_height == 8 and tx_height == 32 can be simplified
     // from ((A * 4) + 2) >> 2) to A.
@@ -2612,7 +3063,16 @@
       return;
     }
 
-    // Process kTransformSize32x16
+    // Process kTransformSize32x16.  The src is always rounded before the
+    // identity transform and shifted by 1 afterwards.
+
+    if (Identity32DcOnly(&src[0], &src[0], non_zero_coeff_count)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<32>(tx_type, tx_height, non_zero_coeff_count);
+
     assert(tx_size == kTransformSize32x16);
     ApplyRounding<32>(src, num_rows);
     for (int i = 0; i < num_rows; i += 4) {
@@ -2622,8 +3082,9 @@
   }
 
   assert(!is_row);
-  Identity32ColumnStoreToFrame(frame, start_x, start_y, tx_width,
-                               /*tx_height=*/32, src);
+  const int height = (non_zero_coeff_count == 1) ? 1 : tx_height;
+  IdentityColumnStoreToFrame<32>(frame, start_x, start_y, tx_width, height,
+                                 src);
 }
 
 void Wht4TransformLoop_NEON(TransformType tx_type, TransformSize tx_size,
@@ -2651,7 +3112,7 @@
 //------------------------------------------------------------------------------
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   // Maximum transform size for Dct is 64.
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformDct] =
@@ -2695,7 +3156,7 @@
 
 }  // namespace dsp
 }  // namespace libgav1
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/inverse_transform_neon.h b/libgav1/src/dsp/arm/inverse_transform_neon.h
index 3c0e051..af647e8 100644
--- a/libgav1/src/dsp/arm/inverse_transform_neon.h
+++ b/libgav1/src/dsp/arm/inverse_transform_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -31,22 +31,22 @@
 }  // namespace libgav1
 
 #if LIBGAV1_ENABLE_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_INVERSE_TRANSFORM_NEON_H_
diff --git a/libgav1/src/dsp/arm/loop_filter_neon.cc b/libgav1/src/dsp/arm/loop_filter_neon.cc
index ab013b1..146c983 100644
--- a/libgav1/src/dsp/arm/loop_filter_neon.cc
+++ b/libgav1/src/dsp/arm/loop_filter_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/loop_filter.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -24,6 +24,8 @@
 #include <cstdint>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -131,12 +133,10 @@
                       const int hev_thresh) {
   uint8_t* dst = static_cast<uint8_t*>(dest);
 
-  const uint8x8_t zero = vdup_n_u8(0);
-
-  const uint8x8_t p1_v = LoadLo4(dst - 2 * stride, zero);
-  const uint8x8_t p0_v = LoadLo4(dst - stride, zero);
-  const uint8x8_t p0q0 = LoadHi4(dst, p0_v);
-  const uint8x8_t p1q1 = LoadHi4(dst + stride, p1_v);
+  const uint8x8_t p1_v = Load4(dst - 2 * stride);
+  const uint8x8_t p0_v = Load4(dst - stride);
+  const uint8x8_t p0q0 = Load4<1>(dst, p0_v);
+  const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v);
 
   uint8x8_t hev_mask;
   uint8x8_t needs_filter4_mask;
@@ -187,11 +187,10 @@
 
   // |p1q0| and |p0q1| are named for the values they will contain after the
   // transpose.
-  const uint8x8_t zero = vdup_n_u8(0);
-  const uint8x8_t row0 = LoadLo4(dst, zero);
-  uint8x8_t p1q0 = LoadHi4(dst + stride, row0);
-  const uint8x8_t row2 = LoadLo4(dst + 2 * stride, zero);
-  uint8x8_t p0q1 = LoadHi4(dst + 3 * stride, row2);
+  const uint8x8_t row0 = Load4(dst);
+  uint8x8_t p1q0 = Load4<1>(dst + stride, row0);
+  const uint8x8_t row2 = Load4(dst + 2 * stride);
+  uint8x8_t p0q1 = Load4<1>(dst + 3 * stride, row2);
 
   Transpose4x4(&p1q0, &p0q1);
   // Rearrange.
@@ -283,39 +282,35 @@
                                      inner_thresh, outer_thresh);
 }
 
-inline void Filter6(const uint8x8_t p2q2_u8, const uint8x8_t p1q1_u8,
-                    const uint8x8_t p0q0_u8, uint8x8_t* const p1q1_output,
+inline void Filter6(const uint8x8_t p2q2, const uint8x8_t p1q1,
+                    const uint8x8_t p0q0, uint8x8_t* const p1q1_output,
                     uint8x8_t* const p0q0_output) {
-  const uint16x8_t p2q2 = vmovl_u8(p2q2_u8);
-  const uint16x8_t p1q1 = vmovl_u8(p1q1_u8);
-  const uint16x8_t p0q0 = vmovl_u8(p0q0_u8);
-
   // Sum p1 and q1 output from opposite directions
   // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
   //      ^^^^^^^^
   // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3)
   //                                 ^^^^^^^^
-  const uint16x8_t p2q2_double = vaddq_u16(p2q2, p2q2);
-  uint16x8_t sum = vaddq_u16(p2q2_double, p2q2);
+  const uint16x8_t p2q2_double = vaddl_u8(p2q2, p2q2);
+  uint16x8_t sum = vaddw_u8(p2q2_double, p2q2);
 
   // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
   //                 ^^^^^^^^
   // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3)
   //                      ^^^^^^^^
-  sum = vaddq_u16(vaddq_u16(p1q1, p1q1), sum);
+  sum = vaddq_u16(vaddl_u8(p1q1, p1q1), sum);
 
   // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
   //                            ^^^^^^^^
   // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3)
   //           ^^^^^^^^
-  sum = vaddq_u16(vaddq_u16(p0q0, p0q0), sum);
+  sum = vaddq_u16(vaddl_u8(p0q0, p0q0), sum);
 
   // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
   //                                       ^^
   // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3)
   //      ^^
-  const uint16x8_t q0p0 = vcombine_u16(vget_high_u16(p0q0), vget_low_u16(p0q0));
-  sum = vaddq_u16(sum, q0p0);
+  const uint8x8_t q0p0 = Transpose32(p0q0);
+  sum = vaddw_u8(sum, q0p0);
 
   *p1q1_output = vrshrn_n_u16(sum, 3);
 
@@ -323,8 +318,8 @@
   // p0 = p1 - (2 * p2) + q0 + q1
   // q0 = q1 - (2 * q2) + p0 + p1
   sum = vsubq_u16(sum, p2q2_double);
-  const uint16x8_t q1p1 = vcombine_u16(vget_high_u16(p1q1), vget_low_u16(p1q1));
-  sum = vaddq_u16(vaddq_u16(q0p0, q1p1), sum);
+  const uint8x8_t q1p1 = Transpose32(p1q1);
+  sum = vaddq_u16(vaddl_u8(q0p0, q1p1), sum);
 
   *p0q0_output = vrshrn_n_u16(sum, 3);
 }
@@ -334,13 +329,12 @@
                       const int hev_thresh) {
   auto* dst = static_cast<uint8_t*>(dest);
 
-  const uint8x8_t zero = vdup_n_u8(0);
-  const uint8x8_t p2_v = LoadLo4(dst - 3 * stride, zero);
-  const uint8x8_t p1_v = LoadLo4(dst - 2 * stride, zero);
-  const uint8x8_t p0_v = LoadLo4(dst - stride, zero);
-  const uint8x8_t p0q0 = LoadHi4(dst, p0_v);
-  const uint8x8_t p1q1 = LoadHi4(dst + stride, p1_v);
-  const uint8x8_t p2q2 = LoadHi4(dst + 2 * stride, p2_v);
+  const uint8x8_t p2_v = Load4(dst - 3 * stride);
+  const uint8x8_t p1_v = Load4(dst - 2 * stride);
+  const uint8x8_t p0_v = Load4(dst - stride);
+  const uint8x8_t p0q0 = Load4<1>(dst, p0_v);
+  const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v);
+  const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v);
 
   uint8x8_t needs_filter6_mask, is_flat3_mask, hev_mask;
   Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
@@ -372,6 +366,7 @@
 #if defined(__aarch64__)
   if (vaddv_u8(vand_u8(is_flat3_mask, needs_filter6_mask)) == 0) {
     // Filter6() does not apply.
+    const uint8x8_t zero = vdup_n_u8(0);
     f6_p1q1 = zero;
     f6_p0q0 = zero;
   } else {
@@ -523,58 +518,53 @@
                    p1q1, inner_thresh, outer_thresh);
 }
 
-inline void Filter8(const uint8x8_t p3q3_u8, const uint8x8_t p2q2_u8,
-                    const uint8x8_t p1q1_u8, const uint8x8_t p0q0_u8,
+inline void Filter8(const uint8x8_t p3q3, const uint8x8_t p2q2,
+                    const uint8x8_t p1q1, const uint8x8_t p0q0,
                     uint8x8_t* const p2q2_output, uint8x8_t* const p1q1_output,
                     uint8x8_t* const p0q0_output) {
-  const uint16x8_t p3q3 = vmovl_u8(p3q3_u8);
-  const uint16x8_t p2q2 = vmovl_u8(p2q2_u8);
-  const uint16x8_t p1q1 = vmovl_u8(p1q1_u8);
-  const uint16x8_t p0q0 = vmovl_u8(p0q0_u8);
-
   // Sum p2 and q2 output from opposite directions
   // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
   //      ^^^^^^^^
   // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
   //                                ^^^^^^^^
-  uint16x8_t sum = vaddq_u16(vaddq_u16(p3q3, p3q3), p3q3);
+  uint16x8_t sum = vaddw_u8(vaddl_u8(p3q3, p3q3), p3q3);
 
   // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
   //                 ^^^^^^^^
   // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
   //                     ^^^^^^^^
-  sum = vaddq_u16(vaddq_u16(p2q2, p2q2), sum);
+  sum = vaddq_u16(vaddl_u8(p2q2, p2q2), sum);
 
   // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
   //                            ^^^^^^^
   // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
   //           ^^^^^^^
-  sum = vaddq_u16(vaddq_u16(p1q1, p0q0), sum);
+  sum = vaddq_u16(vaddl_u8(p1q1, p0q0), sum);
 
   // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
   //                                      ^^
   // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
   //      ^^
-  const uint16x8_t q0p0 = vcombine_u16(vget_high_u16(p0q0), vget_low_u16(p0q0));
-  sum = vaddq_u16(q0p0, sum);
+  const uint8x8_t q0p0 = Transpose32(p0q0);
+  sum = vaddw_u8(sum, q0p0);
 
   *p2q2_output = vrshrn_n_u16(sum, 3);
 
   // Convert to p1 and q1 output:
   // p1 = p2 - p3 - p2 + p1 + q1
   // q1 = q2 - q3 - q2 + q0 + p1
-  sum = vsubq_u16(sum, vaddq_u16(p3q3, p2q2));
-  const uint16x8_t q1p1 = vcombine_u16(vget_high_u16(p1q1), vget_low_u16(p1q1));
-  sum = vaddq_u16(vaddq_u16(p1q1, q1p1), sum);
+  sum = vsubq_u16(sum, vaddl_u8(p3q3, p2q2));
+  const uint8x8_t q1p1 = Transpose32(p1q1);
+  sum = vaddq_u16(vaddl_u8(p1q1, q1p1), sum);
 
   *p1q1_output = vrshrn_n_u16(sum, 3);
 
   // Convert to p0 and q0 output:
   // p0 = p1 - p3 - p1 + p0 + q2
   // q0 = q1 - q3 - q1 + q0 + p2
-  sum = vsubq_u16(sum, vaddq_u16(p3q3, p1q1));
-  const uint16x8_t q2p2 = vcombine_u16(vget_high_u16(p2q2), vget_low_u16(p2q2));
-  sum = vaddq_u16(vaddq_u16(p0q0, q2p2), sum);
+  sum = vsubq_u16(sum, vaddl_u8(p3q3, p1q1));
+  const uint8x8_t q2p2 = Transpose32(p2q2);
+  sum = vaddq_u16(vaddl_u8(p0q0, q2p2), sum);
 
   *p0q0_output = vrshrn_n_u16(sum, 3);
 }
@@ -584,15 +574,14 @@
                       const int hev_thresh) {
   auto* dst = static_cast<uint8_t*>(dest);
 
-  const uint8x8_t zero = vdup_n_u8(0);
-  const uint8x8_t p3_v = LoadLo4(dst - 4 * stride, zero);
-  const uint8x8_t p2_v = LoadLo4(dst - 3 * stride, zero);
-  const uint8x8_t p1_v = LoadLo4(dst - 2 * stride, zero);
-  const uint8x8_t p0_v = LoadLo4(dst - stride, zero);
-  const uint8x8_t p0q0 = LoadHi4(dst, p0_v);
-  const uint8x8_t p1q1 = LoadHi4(dst + stride, p1_v);
-  const uint8x8_t p2q2 = LoadHi4(dst + 2 * stride, p2_v);
-  const uint8x8_t p3q3 = LoadHi4(dst + 3 * stride, p3_v);
+  const uint8x8_t p3_v = Load4(dst - 4 * stride);
+  const uint8x8_t p2_v = Load4(dst - 3 * stride);
+  const uint8x8_t p1_v = Load4(dst - 2 * stride);
+  const uint8x8_t p0_v = Load4(dst - stride);
+  const uint8x8_t p0q0 = Load4<1>(dst, p0_v);
+  const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v);
+  const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v);
+  const uint8x8_t p3q3 = Load4<1>(dst + 3 * stride, p3_v);
 
   uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask;
   Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
@@ -625,6 +614,7 @@
 #if defined(__aarch64__)
   if (vaddv_u8(is_flat4_mask) == 0) {
     // Filter8() does not apply.
+    const uint8x8_t zero = vdup_n_u8(0);
     f8_p2q2 = zero;
     f8_p1q1 = zero;
     f8_p0q0 = zero;
@@ -743,103 +733,95 @@
   vst1_u8(dst + 3 * stride, p0q3_output);
 }
 
-inline void Filter14(const uint8x8_t p6q6_u8, const uint8x8_t p5q5_u8,
-                     const uint8x8_t p4q4_u8, const uint8x8_t p3q3_u8,
-                     const uint8x8_t p2q2_u8, const uint8x8_t p1q1_u8,
-                     const uint8x8_t p0q0_u8, uint8x8_t* const p5q5_output,
+inline void Filter14(const uint8x8_t p6q6, const uint8x8_t p5q5,
+                     const uint8x8_t p4q4, const uint8x8_t p3q3,
+                     const uint8x8_t p2q2, const uint8x8_t p1q1,
+                     const uint8x8_t p0q0, uint8x8_t* const p5q5_output,
                      uint8x8_t* const p4q4_output, uint8x8_t* const p3q3_output,
                      uint8x8_t* const p2q2_output, uint8x8_t* const p1q1_output,
                      uint8x8_t* const p0q0_output) {
-  const uint16x8_t p6q6 = vmovl_u8(p6q6_u8);
-  const uint16x8_t p5q5 = vmovl_u8(p5q5_u8);
-  const uint16x8_t p4q4 = vmovl_u8(p4q4_u8);
-  const uint16x8_t p3q3 = vmovl_u8(p3q3_u8);
-  const uint16x8_t p2q2 = vmovl_u8(p2q2_u8);
-  const uint16x8_t p1q1 = vmovl_u8(p1q1_u8);
-  const uint16x8_t p0q0 = vmovl_u8(p0q0_u8);
-
   // Sum p5 and q5 output from opposite directions
   // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
   //      ^^^^^^^^
   // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
   //                                                     ^^^^^^^^
-  uint16x8_t sum = vsubq_u16(vshlq_n_u16(p6q6, 3), p6q6);
+  uint16x8_t sum = vsubw_u8(vshll_n_u8(p6q6, 3), p6q6);
 
   // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
   //                 ^^^^^^^^
   // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
   //                                          ^^^^^^^^
-  sum = vaddq_u16(vaddq_u16(p5q5, p5q5), sum);
+  sum = vaddq_u16(vaddl_u8(p5q5, p5q5), sum);
 
   // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
   //                            ^^^^^^^^
   // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
   //                               ^^^^^^^^
-  sum = vaddq_u16(vaddq_u16(p4q4, p4q4), sum);
+  sum = vaddq_u16(vaddl_u8(p4q4, p4q4), sum);
 
   // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
   //                                       ^^^^^^^
   // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
   //                     ^^^^^^^
-  sum = vaddq_u16(vaddq_u16(p3q3, p2q2), sum);
+  sum = vaddq_u16(vaddl_u8(p3q3, p2q2), sum);
 
   // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
   //                                                 ^^^^^^^
   // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
   //           ^^^^^^^
-  sum = vaddq_u16(vaddq_u16(p1q1, p0q0), sum);
+  sum = vaddq_u16(vaddl_u8(p1q1, p0q0), sum);
 
   // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
   //                                                           ^^
   // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
   //      ^^
-  const uint16x8_t q0p0 = vcombine_u16(vget_high_u16(p0q0), vget_low_u16(p0q0));
-  sum = vaddq_u16(q0p0, sum);
+  const uint8x8_t q0p0 = Transpose32(p0q0);
+  sum = vaddw_u8(sum, q0p0);
 
   *p5q5_output = vrshrn_n_u16(sum, 4);
 
   // Convert to p4 and q4 output:
   // p4 = p5 - (2 * p6) + p3 + q1
   // q4 = q5 - (2 * q6) + q3 + p1
-  sum = vsubq_u16(sum, vaddq_u16(p6q6, p6q6));
-  const uint16x8_t q1p1 = vcombine_u16(vget_high_u16(p1q1), vget_low_u16(p1q1));
-  sum = vaddq_u16(vaddq_u16(p3q3, q1p1), sum);
+  sum = vsubq_u16(sum, vaddl_u8(p6q6, p6q6));
+  const uint8x8_t q1p1 = Transpose32(p1q1);
+  sum = vaddq_u16(vaddl_u8(p3q3, q1p1), sum);
 
   *p4q4_output = vrshrn_n_u16(sum, 4);
 
   // Convert to p3 and q3 output:
   // p3 = p4 - p6 - p5 + p2 + q2
   // q3 = q4 - q6 - q5 + q2 + p2
-  sum = vsubq_u16(sum, vaddq_u16(p6q6, p5q5));
-  const uint16x8_t q2p2 = vcombine_u16(vget_high_u16(p2q2), vget_low_u16(p2q2));
-  sum = vaddq_u16(vaddq_u16(p2q2, q2p2), sum);
+  sum = vsubq_u16(sum, vaddl_u8(p6q6, p5q5));
+  const uint8x8_t q2p2 = Transpose32(p2q2);
+  sum = vaddq_u16(vaddl_u8(p2q2, q2p2), sum);
 
   *p3q3_output = vrshrn_n_u16(sum, 4);
 
   // Convert to p2 and q2 output:
   // p2 = p3 - p6 - p4 + p1 + q3
   // q2 = q3 - q6 - q4 + q1 + p3
-  sum = vsubq_u16(sum, vaddq_u16(p6q6, p4q4));
-  const uint16x8_t q3p3 = vcombine_u16(vget_high_u16(p3q3), vget_low_u16(p3q3));
-  sum = vaddq_u16(vaddq_u16(p1q1, q3p3), sum);
+  sum = vsubq_u16(sum, vaddl_u8(p6q6, p4q4));
+  const uint8x8_t q3p3 = Transpose32(p3q3);
+  sum = vaddq_u16(vaddl_u8(p1q1, q3p3), sum);
 
   *p2q2_output = vrshrn_n_u16(sum, 4);
 
   // Convert to p1 and q1 output:
   // p1 = p2 - p6 - p3 + p0 + q4
   // q1 = q2 - q6 - q3 + q0 + p4
-  sum = vsubq_u16(sum, vaddq_u16(p6q6, p3q3));
-  const uint16x8_t q4p4 = vcombine_u16(vget_high_u16(p4q4), vget_low_u16(p4q4));
-  sum = vaddq_u16(vaddq_u16(p0q0, q4p4), sum);
+  sum = vsubq_u16(sum, vaddl_u8(p6q6, p3q3));
+  const uint8x8_t q4p4 = Transpose32(p4q4);
+  sum = vaddq_u16(vaddl_u8(p0q0, q4p4), sum);
 
   *p1q1_output = vrshrn_n_u16(sum, 4);
 
   // Convert to p0 and q0 output:
   // p0 = p1 - p6 - p2 + q0 + q5
   // q0 = q1 - q6 - q2 + p0 + p5
-  sum = vsubq_u16(sum, vaddq_u16(p6q6, p2q2));
-  const uint16x8_t q5p5 = vcombine_u16(vget_high_u16(p5q5), vget_low_u16(p5q5));
-  sum = vaddq_u16(vaddq_u16(q0p0, q5p5), sum);
+  sum = vsubq_u16(sum, vaddl_u8(p6q6, p2q2));
+  const uint8x8_t q5p5 = Transpose32(p5q5);
+  sum = vaddq_u16(vaddl_u8(q0p0, q5p5), sum);
 
   *p0q0_output = vrshrn_n_u16(sum, 4);
 }
@@ -849,21 +831,20 @@
                        const int hev_thresh) {
   auto* dst = static_cast<uint8_t*>(dest);
 
-  const uint8x8_t zero = vdup_n_u8(0);
-  const uint8x8_t p6_v = LoadLo4(dst - 7 * stride, zero);
-  const uint8x8_t p5_v = LoadLo4(dst - 6 * stride, zero);
-  const uint8x8_t p4_v = LoadLo4(dst - 5 * stride, zero);
-  const uint8x8_t p3_v = LoadLo4(dst - 4 * stride, zero);
-  const uint8x8_t p2_v = LoadLo4(dst - 3 * stride, zero);
-  const uint8x8_t p1_v = LoadLo4(dst - 2 * stride, zero);
-  const uint8x8_t p0_v = LoadLo4(dst - stride, zero);
-  const uint8x8_t p0q0 = LoadHi4(dst, p0_v);
-  const uint8x8_t p1q1 = LoadHi4(dst + stride, p1_v);
-  const uint8x8_t p2q2 = LoadHi4(dst + 2 * stride, p2_v);
-  const uint8x8_t p3q3 = LoadHi4(dst + 3 * stride, p3_v);
-  const uint8x8_t p4q4 = LoadHi4(dst + 4 * stride, p4_v);
-  const uint8x8_t p5q5 = LoadHi4(dst + 5 * stride, p5_v);
-  const uint8x8_t p6q6 = LoadHi4(dst + 6 * stride, p6_v);
+  const uint8x8_t p6_v = Load4(dst - 7 * stride);
+  const uint8x8_t p5_v = Load4(dst - 6 * stride);
+  const uint8x8_t p4_v = Load4(dst - 5 * stride);
+  const uint8x8_t p3_v = Load4(dst - 4 * stride);
+  const uint8x8_t p2_v = Load4(dst - 3 * stride);
+  const uint8x8_t p1_v = Load4(dst - 2 * stride);
+  const uint8x8_t p0_v = Load4(dst - stride);
+  const uint8x8_t p0q0 = Load4<1>(dst, p0_v);
+  const uint8x8_t p1q1 = Load4<1>(dst + stride, p1_v);
+  const uint8x8_t p2q2 = Load4<1>(dst + 2 * stride, p2_v);
+  const uint8x8_t p3q3 = Load4<1>(dst + 3 * stride, p3_v);
+  const uint8x8_t p4q4 = Load4<1>(dst + 4 * stride, p4_v);
+  const uint8x8_t p5q5 = Load4<1>(dst + 5 * stride, p5_v);
+  const uint8x8_t p6q6 = Load4<1>(dst + 6 * stride, p6_v);
 
   uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask;
   Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
@@ -904,6 +885,7 @@
 #if defined(__aarch64__)
   if (vaddv_u8(is_flat4_mask) == 0) {
     // Filter8() and Filter14() do not apply.
+    const uint8x8_t zero = vdup_n_u8(0);
     f8_p1q1 = zero;
     f8_p0q0 = zero;
     f14_p1q1 = zero;
@@ -916,6 +898,7 @@
 #if defined(__aarch64__)
     if (vaddv_u8(is_flat_outer4_mask) == 0) {
       // Filter14() does not apply.
+      const uint8x8_t zero = vdup_n_u8(0);
       f14_p2q2 = zero;
       f14_p1q1 = zero;
       f14_p0q0 = zero;
@@ -968,48 +951,68 @@
                      const int outer_thresh, const int inner_thresh,
                      const int hev_thresh) {
   auto* dst = static_cast<uint8_t*>(dest);
-  // Move |dst| to the left side of the filter window.
-  dst -= 7;
+  dst -= 8;
+  // input
+  // p7 p6 p5 p4 p3 p2 p1 p0  q0 q1 q2 q3 q4 q5 q6 q7
+  const uint8x16_t x0 = vld1q_u8(dst);
+  dst += stride;
+  const uint8x16_t x1 = vld1q_u8(dst);
+  dst += stride;
+  const uint8x16_t x2 = vld1q_u8(dst);
+  dst += stride;
+  const uint8x16_t x3 = vld1q_u8(dst);
+  dst -= (stride * 3);
 
-  // The filter only need 14 bytes. Over read 2 bytes.
-  const uint8x16_t input_0 = vld1q_u8(dst);
-  const uint8x16_t input_1 = vld1q_u8(dst + stride);
-  const uint8x16_t input_2 = vld1q_u8(dst + 2 * stride);
-  const uint8x16_t input_3 = vld1q_u8(dst + 3 * stride);
+  // re-order input
+#if defined(__aarch64__)
+  const uint8x8_t index_qp3toqp0 = vcreate_u8(0x0b0a090804050607);
+  const uint8x8_t index_qp7toqp4 = vcreate_u8(0x0f0e0d0c00010203);
+  const uint8x16_t index_qp7toqp0 = vcombine_u8(index_qp3toqp0, index_qp7toqp4);
 
-  // Transpose 16x4. Just like an 8x4 transpose but with Q registers.
+  uint8x16_t input_0 = vqtbl1q_u8(x0, index_qp7toqp0);
+  uint8x16_t input_1 = vqtbl1q_u8(x1, index_qp7toqp0);
+  uint8x16_t input_2 = vqtbl1q_u8(x2, index_qp7toqp0);
+  uint8x16_t input_3 = vqtbl1q_u8(x3, index_qp7toqp0);
+#else
+  const uint8x8_t index_qp3toqp0 = vcreate_u8(0x0b0a090804050607);
+  const uint8x8_t index_qp7toqp4 = vcreate_u8(0x0f0e0d0c00010203);
+
+  const uint8x8_t x0_qp3qp0 = VQTbl1U8(x0, index_qp3toqp0);
+  const uint8x8_t x1_qp3qp0 = VQTbl1U8(x1, index_qp3toqp0);
+  const uint8x8_t x2_qp3qp0 = VQTbl1U8(x2, index_qp3toqp0);
+  const uint8x8_t x3_qp3qp0 = VQTbl1U8(x3, index_qp3toqp0);
+
+  const uint8x8_t x0_qp7qp4 = VQTbl1U8(x0, index_qp7toqp4);
+  const uint8x8_t x1_qp7qp4 = VQTbl1U8(x1, index_qp7toqp4);
+  const uint8x8_t x2_qp7qp4 = VQTbl1U8(x2, index_qp7toqp4);
+  const uint8x8_t x3_qp7qp4 = VQTbl1U8(x3, index_qp7toqp4);
+
+  const uint8x16_t input_0 = vcombine_u8(x0_qp3qp0, x0_qp7qp4);
+  const uint8x16_t input_1 = vcombine_u8(x1_qp3qp0, x1_qp7qp4);
+  const uint8x16_t input_2 = vcombine_u8(x2_qp3qp0, x2_qp7qp4);
+  const uint8x16_t input_3 = vcombine_u8(x3_qp3qp0, x3_qp7qp4);
+#endif
+  // input after re-order
+  // p0 p1 p2 p3 q0 q1 q2 q3  p4 p5 p6 p7 q4 q5 q6 q7
+
   const uint8x16x2_t in01 = vtrnq_u8(input_0, input_1);
   const uint8x16x2_t in23 = vtrnq_u8(input_2, input_3);
-
   const uint16x8x2_t in02 = vtrnq_u16(vreinterpretq_u16_u8(in01.val[0]),
                                       vreinterpretq_u16_u8(in23.val[0]));
   const uint16x8x2_t in13 = vtrnq_u16(vreinterpretq_u16_u8(in01.val[1]),
                                       vreinterpretq_u16_u8(in23.val[1]));
 
-  // This a very verbose way of renaming the registers.
-  const uint8x8_t p6p2 = vget_low_u8(vreinterpretq_u8_u16(in02.val[0]));
-  const uint8x8_t q1q5 = vget_high_u8(vreinterpretq_u8_u16(in02.val[0]));
+  const uint8x8_t p0q0 = vget_low_u8(vreinterpretq_u8_u16(in02.val[0]));
+  const uint8x8_t p1q1 = vget_low_u8(vreinterpretq_u8_u16(in13.val[0]));
 
-  const uint8x8_t p4p0 = vget_low_u8(vreinterpretq_u8_u16(in02.val[1]));
-  const uint8x8_t q3x0 = vget_high_u8(vreinterpretq_u8_u16(in02.val[1]));
+  const uint8x8_t p2q2 = vget_low_u8(vreinterpretq_u8_u16(in02.val[1]));
+  const uint8x8_t p3q3 = vget_low_u8(vreinterpretq_u8_u16(in13.val[1]));
 
-  const uint8x8_t p5p1 = vget_low_u8(vreinterpretq_u8_u16(in13.val[0]));
-  const uint8x8_t q2q6 = vget_high_u8(vreinterpretq_u8_u16(in13.val[0]));
+  const uint8x8_t p4q4 = vget_high_u8(vreinterpretq_u8_u16(in02.val[0]));
+  const uint8x8_t p5q5 = vget_high_u8(vreinterpretq_u8_u16(in13.val[0]));
 
-  const uint8x8_t p3q0 = vget_low_u8(vreinterpretq_u8_u16(in13.val[1]));
-  const uint8x8_t q4x1 = vget_high_u8(vreinterpretq_u8_u16(in13.val[1]));
-
-  const uint8x8x2_t p6q6xp2q2 = Interleave32(p6p2, Transpose32(q2q6));
-  const uint8x8_t p6q6 = p6q6xp2q2.val[0];
-  const uint8x8_t p2q2 = p6q6xp2q2.val[1];
-
-  const uint8x8x2_t p5q5xp1q1 = Interleave32(p5p1, Transpose32(q1q5));
-  const uint8x8_t p5q5 = p5q5xp1q1.val[0];
-  const uint8x8_t p1q1 = p5q5xp1q1.val[1];
-
-  const uint8x8_t p4q4 = InterleaveLow32(p4p0, q4x1);
-  const uint8x8_t p3q3 = InterleaveLow32(p3q0, q3x0);
-  const uint8x8_t p0q0 = InterleaveHigh32(p4p0, p3q0);
+  const uint8x8_t p6q6 = vget_high_u8(vreinterpretq_u8_u16(in02.val[1]));
+  const uint8x8_t p7q7 = vget_high_u8(vreinterpretq_u8_u16(in13.val[1]));
 
   uint8x8_t needs_filter8_mask, is_flat4_mask, hev_mask;
   Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_thresh, inner_thresh,
@@ -1038,28 +1041,30 @@
   is_flat_outer4_mask =
       InterleaveLow32(is_flat_outer4_mask, is_flat_outer4_mask);
 
-  uint8x8_t f_p1q1;
-  uint8x8_t f_p0q0;
+  uint8x8_t f_p0q0, f_p1q1;
   const uint8x8x2_t q0p1xp0q1 = Interleave32(Transpose32(p0q0), p1q1);
   Filter4(q0p1xp0q1.val[0], q0p1xp0q1.val[1], hev_mask, &f_p1q1, &f_p0q0);
   // Reset the outer values if only a Hev() mask was required.
   f_p1q1 = vbsl_u8(hev_mask, p1q1, f_p1q1);
 
-  // Input is 14 taps but output is only 12. Move |dst| from p6 to p5.
-  dst += 1;
-
   uint8x8_t p1q1_output, p0q0_output;
+  uint8x8_t p5q5_output, p4q4_output, p3q3_output, p2q2_output;
+
 #if defined(__aarch64__)
   if (vaddv_u8(is_flat4_mask) == 0) {
     // Filter8() and Filter14() do not apply.
     p1q1_output = p1q1;
     p0q0_output = p0q0;
+
+    p5q5_output = p5q5;
+    p4q4_output = p4q4;
+    p3q3_output = p3q3;
+    p2q2_output = p2q2;
   } else {
 #endif  // defined(__aarch64__)
     uint8x8_t f8_p2q2, f8_p1q1, f8_p0q0;
     Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0);
 
-    uint8x8_t p5q5_output, p4q4_output, p3q3_output, p2q2_output;
 #if defined(__aarch64__)
     if (vaddv_u8(is_flat_outer4_mask) == 0) {
       // Filter14() does not apply.
@@ -1085,27 +1090,6 @@
     }
 #endif  // defined(__aarch64__)
     p2q2_output = vbsl_u8(is_flat4_mask, p2q2_output, p2q2);
-
-    // Transpose and write the Filter14/Filter8 exclusive values.
-    const uint8x8x2_t p5q2xq5p2_output =
-        Interleave32(p5q5_output, Transpose32(p2q2_output));
-    uint8x8_t p5q2_output = p5q2xq5p2_output.val[0];
-    uint8x8_t p2q5_output = Transpose32(p5q2xq5p2_output.val[1]);
-    const uint8x8x2_t p4q3xq4p3_output =
-        Interleave32(p4q4_output, Transpose32(p3q3_output));
-    uint8x8_t p4q3_output = p4q3xq4p3_output.val[0];
-    uint8x8_t p3q4_output = Transpose32(p4q3xq4p3_output.val[1]);
-
-    Transpose8x4(&p5q2_output, &p4q3_output, &p3q4_output, &p2q5_output);
-
-    StoreLo4(dst, p5q2_output);
-    StoreHi4((dst + 8), p5q2_output);
-    StoreLo4(dst + stride, p4q3_output);
-    StoreHi4((dst + 8) + stride, p4q3_output);
-    StoreLo4(dst + 2 * stride, p3q4_output);
-    StoreHi4((dst + 8) + 2 * stride, p3q4_output);
-    StoreLo4(dst + 3 * stride, p2q5_output);
-    StoreHi4((dst + 8) + 3 * stride, p2q5_output);
 #if defined(__aarch64__)
   }
 #endif  // defined(__aarch64__)
@@ -1115,23 +1099,60 @@
   p0q0_output = vbsl_u8(is_flat4_mask, p0q0_output, f_p0q0);
   p0q0_output = vbsl_u8(needs_filter8_mask, p0q0_output, p0q0);
 
-  // Transpose the middle (p1-q1) 4x4 block and write it out.
-  const uint8x8x2_t p1p0xq1q0 = Interleave32(p1q1_output, p0q0_output);
-  uint8x8_t output_0 = p1p0xq1q0.val[0];
-  uint8x8_t output_1 = Transpose32(p1p0xq1q0.val[1]);
+  const uint8x16_t p0q0_p4q4 = vcombine_u8(p0q0_output, p4q4_output);
+  const uint8x16_t p2q2_p6q6 = vcombine_u8(p2q2_output, p6q6);
+  const uint8x16_t p1q1_p5q5 = vcombine_u8(p1q1_output, p5q5_output);
+  const uint8x16_t p3q3_p7q7 = vcombine_u8(p3q3_output, p7q7);
 
-  Transpose4x4(&output_0, &output_1);
+  const uint16x8x2_t out02 = vtrnq_u16(vreinterpretq_u16_u8(p0q0_p4q4),
+                                       vreinterpretq_u16_u8(p2q2_p6q6));
+  const uint16x8x2_t out13 = vtrnq_u16(vreinterpretq_u16_u8(p1q1_p5q5),
+                                       vreinterpretq_u16_u8(p3q3_p7q7));
+  const uint8x16x2_t out01 = vtrnq_u8(vreinterpretq_u8_u16(out02.val[0]),
+                                      vreinterpretq_u8_u16(out13.val[0]));
+  const uint8x16x2_t out23 = vtrnq_u8(vreinterpretq_u8_u16(out02.val[1]),
+                                      vreinterpretq_u8_u16(out13.val[1]));
 
-  // Move |dst| from p5 to p1.
-  dst += 4;
-  StoreLo4(dst, output_0);
-  StoreLo4(dst + stride, output_1);
-  StoreHi4(dst + 2 * stride, output_0);
-  StoreHi4(dst + 3 * stride, output_1);
+#if defined(__aarch64__)
+  const uint8x8_t index_p7top0 = vcreate_u8(0x0001020308090a0b);
+  const uint8x8_t index_q7toq0 = vcreate_u8(0x0f0e0d0c07060504);
+  const uint8x16_t index_p7toq7 = vcombine_u8(index_p7top0, index_q7toq0);
+
+  const uint8x16_t output_0 = vqtbl1q_u8(out01.val[0], index_p7toq7);
+  const uint8x16_t output_1 = vqtbl1q_u8(out01.val[1], index_p7toq7);
+  const uint8x16_t output_2 = vqtbl1q_u8(out23.val[0], index_p7toq7);
+  const uint8x16_t output_3 = vqtbl1q_u8(out23.val[1], index_p7toq7);
+#else
+  const uint8x8_t index_p7top0 = vcreate_u8(0x0001020308090a0b);
+  const uint8x8_t index_q7toq0 = vcreate_u8(0x0f0e0d0c07060504);
+
+  const uint8x8_t x0_p7p0 = VQTbl1U8(out01.val[0], index_p7top0);
+  const uint8x8_t x1_p7p0 = VQTbl1U8(out01.val[1], index_p7top0);
+  const uint8x8_t x2_p7p0 = VQTbl1U8(out23.val[0], index_p7top0);
+  const uint8x8_t x3_p7p0 = VQTbl1U8(out23.val[1], index_p7top0);
+
+  const uint8x8_t x0_q7q0 = VQTbl1U8(out01.val[0], index_q7toq0);
+  const uint8x8_t x1_q7q0 = VQTbl1U8(out01.val[1], index_q7toq0);
+  const uint8x8_t x2_q7q0 = VQTbl1U8(out23.val[0], index_q7toq0);
+  const uint8x8_t x3_q7q0 = VQTbl1U8(out23.val[1], index_q7toq0);
+
+  const uint8x16_t output_0 = vcombine_u8(x0_p7p0, x0_q7q0);
+  const uint8x16_t output_1 = vcombine_u8(x1_p7p0, x1_q7q0);
+  const uint8x16_t output_2 = vcombine_u8(x2_p7p0, x2_q7q0);
+  const uint8x16_t output_3 = vcombine_u8(x3_p7p0, x3_q7q0);
+#endif
+
+  vst1q_u8(dst, output_0);
+  dst += stride;
+  vst1q_u8(dst, output_1);
+  dst += stride;
+  vst1q_u8(dst, output_2);
+  dst += stride;
+  vst1q_u8(dst, output_3);
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->loop_filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] =
       Horizontal4_NEON;
@@ -1158,7 +1179,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/loop_filter_neon.h b/libgav1/src/dsp/arm/loop_filter_neon.h
index 70289ac..5f79200 100644
--- a/libgav1/src/dsp/arm/loop_filter_neon.h
+++ b/libgav1/src/dsp/arm/loop_filter_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -33,20 +33,20 @@
 #if LIBGAV1_ENABLE_NEON
 
 #define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical LIBGAV1_CPU_NEON
 
 #define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical LIBGAV1_CPU_NEON
 
 #define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical LIBGAV1_CPU_NEON
 
 #define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical LIBGAV1_DSP_NEON
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical LIBGAV1_CPU_NEON
 
 #endif  // LIBGAV1_ENABLE_NEON
 
diff --git a/libgav1/src/dsp/arm/loop_restoration_neon.cc b/libgav1/src/dsp/arm/loop_restoration_neon.cc
index be45673..1e8dfb2 100644
--- a/libgav1/src/dsp/arm/loop_restoration_neon.cc
+++ b/libgav1/src/dsp/arm/loop_restoration_neon.cc
@@ -12,18 +12,21 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/loop_restoration.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
-
 #include <arm_neon.h>
 
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <cstring>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
 #include "src/utils/constants.h"
 
 namespace libgav1 {
@@ -31,352 +34,532 @@
 namespace low_bitdepth {
 namespace {
 
+template <int bytes>
+inline uint8x8_t VshrU128(const uint8x8x2_t src) {
+  return vext_u8(src.val[0], src.val[1], bytes);
+}
+
+template <int bytes>
+inline uint16x8_t VshrU128(const uint16x8x2_t src) {
+  return vextq_u16(src.val[0], src.val[1], bytes / 2);
+}
+
 // Wiener
+
+// Must make a local copy of coefficients to help compiler know that they have
+// no overlap with other buffers. Using 'const' keyword is not enough. Actually
+// compiler doesn't make a copy, since there is enough registers in this case.
 inline void PopulateWienerCoefficients(
     const RestorationUnitInfo& restoration_info, const int direction,
-    int16_t* const filter) {
+    int16_t filter[4]) {
   // In order to keep the horizontal pass intermediate values within 16 bits we
-  // initialize |filter[3]| to 0 instead of 128.
-  filter[3] = 0;
-  for (int i = 0; i < 3; ++i) {
-    const int16_t coeff = restoration_info.wiener_info.filter[direction][i];
-    filter[i] = coeff;
-    filter[6 - i] = coeff;
-    filter[3] -= coeff * 2;
+  // offset |filter[3]| by 128. The 128 offset will be added back in the loop.
+  for (int i = 0; i < 4; ++i) {
+    filter[i] = restoration_info.wiener_info.filter[direction][i];
+  }
+  if (direction == WienerInfo::kHorizontal) {
+    filter[3] -= 128;
   }
 }
 
-inline int16x8_t HorizontalSum(const uint8x8_t a[7], int16_t filter[7]) {
-  int16x8_t sum = vmulq_n_s16(vreinterpretq_s16_u16(vmovl_u8(a[0])), filter[0]);
-  sum = vaddq_s16(
-      sum, vmulq_n_s16(vreinterpretq_s16_u16(vmovl_u8(a[1])), filter[1]));
-  sum = vaddq_s16(
-      sum, vmulq_n_s16(vreinterpretq_s16_u16(vmovl_u8(a[2])), filter[2]));
-  sum = vaddq_s16(
-      sum, vmulq_n_s16(vreinterpretq_s16_u16(vmovl_u8(a[3])), filter[3]));
-  sum = vaddq_s16(
-      sum, vmulq_n_s16(vreinterpretq_s16_u16(vmovl_u8(a[4])), filter[4]));
-  sum = vaddq_s16(
-      sum, vmulq_n_s16(vreinterpretq_s16_u16(vmovl_u8(a[5])), filter[5]));
-  sum = vaddq_s16(
-      sum, vmulq_n_s16(vreinterpretq_s16_u16(vmovl_u8(a[6])), filter[6]));
-
-  sum = vrshrq_n_s16(sum, kInterRoundBitsHorizontal);
-
-  // Delaying |horizontal_rounding| until after dowshifting allows the sum to
-  // stay in 16 bits.
-  // |horizontal_rounding| = 1 << (bitdepth + kWienerFilterBits - 1)
-  //                         1 << (       8 +                 7 - 1)
-  // Plus |kInterRoundBitsHorizontal| and it works out to 1 << 11.
-  sum = vaddq_s16(sum, vdupq_n_s16(1 << 11));
-
-  // Just like |horizontal_rounding|, adding |filter[3]| at this point allows
-  // the sum to stay in 16 bits.
-  // But wait! We *did* calculate |filter[3]| and used it in the sum! But it was
-  // offset by 128. Fix that here:
-  // |src[3]| * 128 >> 3 == |src[3]| << 4
-  sum = vaddq_s16(sum, vreinterpretq_s16_u16(vshll_n_u8(a[3], 4)));
-
-  // Saturate to
-  // [0,
-  // (1 << (bitdepth + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1)]
-  // (1 << (       8 + 1 +                 7 -                         3)) - 1)
-  sum = vminq_s16(sum, vdupq_n_s16((1 << 13) - 1));
-  sum = vmaxq_s16(sum, vdupq_n_s16(0));
-  return sum;
+inline int16x8_t WienerHorizontal2(const uint8x8_t s0, const uint8x8_t s1,
+                                   const int16_t filter, const int16x8_t sum) {
+  const int16x8_t ss = vreinterpretq_s16_u16(vaddl_u8(s0, s1));
+  return vmlaq_n_s16(sum, ss, filter);
 }
 
-template <int min_width>
-inline void VerticalSum(const int16_t* src_base, const ptrdiff_t src_stride,
-                        uint8_t* dst_base, const ptrdiff_t dst_stride,
-                        const int16x4_t filter[7], const int width,
-                        const int height) {
-  static_assert(min_width == 4 || min_width == 8, "");
-  // -(1 << (bitdepth + kInterRoundBitsVertical - 1))
-  // -(1 << (       8 +                      11 - 1))
-  constexpr int vertical_rounding = -(1 << 18);
-  if (min_width == 8) {
-    int x = 0;
+inline int16x8x2_t WienerHorizontal2(const uint8x16_t s0, const uint8x16_t s1,
+                                     const int16_t filter,
+                                     const int16x8x2_t sum) {
+  int16x8x2_t d;
+  d.val[0] =
+      WienerHorizontal2(vget_low_u8(s0), vget_low_u8(s1), filter, sum.val[0]);
+  d.val[1] =
+      WienerHorizontal2(vget_high_u8(s0), vget_high_u8(s1), filter, sum.val[1]);
+  return d;
+}
+
+inline void WienerHorizontalSum(const uint8x8_t s[3], const int16_t filter[4],
+                                int16x8_t sum, int16_t* const wiener_buffer) {
+  constexpr int offset =
+      1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
+  constexpr int limit = (offset << 2) - 1;
+  const int16x8_t s_0_2 = vreinterpretq_s16_u16(vaddl_u8(s[0], s[2]));
+  const int16x8_t s_1 = ZeroExtend(s[1]);
+  sum = vmlaq_n_s16(sum, s_0_2, filter[2]);
+  sum = vmlaq_n_s16(sum, s_1, filter[3]);
+  // Calculate scaled down offset correction, and add to sum here to prevent
+  // signed 16 bit outranging.
+  sum = vrsraq_n_s16(vshlq_n_s16(s_1, 7 - kInterRoundBitsHorizontal), sum,
+                     kInterRoundBitsHorizontal);
+  sum = vmaxq_s16(sum, vdupq_n_s16(-offset));
+  sum = vminq_s16(sum, vdupq_n_s16(limit - offset));
+  vst1q_s16(wiener_buffer, sum);
+}
+
+inline void WienerHorizontalSum(const uint8x16_t src[3],
+                                const int16_t filter[4], int16x8x2_t sum,
+                                int16_t* const wiener_buffer) {
+  uint8x8_t s[3];
+  s[0] = vget_low_u8(src[0]);
+  s[1] = vget_low_u8(src[1]);
+  s[2] = vget_low_u8(src[2]);
+  WienerHorizontalSum(s, filter, sum.val[0], wiener_buffer);
+  s[0] = vget_high_u8(src[0]);
+  s[1] = vget_high_u8(src[1]);
+  s[2] = vget_high_u8(src[2]);
+  WienerHorizontalSum(s, filter, sum.val[1], wiener_buffer + 8);
+}
+
+inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 const int16_t filter[4],
+                                 int16_t** const wiener_buffer) {
+  int y = height;
+  do {
+    const uint8_t* src_ptr = src;
+    uint8x16_t s[8];
+    s[0] = vld1q_u8(src_ptr);
+    ptrdiff_t x = width;
     do {
-      const int16_t* src = src_base + x;
-      uint8_t* dst = dst_base + x;
+      src_ptr += 16;
+      s[7] = vld1q_u8(src_ptr);
+      s[1] = vextq_u8(s[0], s[7], 1);
+      s[2] = vextq_u8(s[0], s[7], 2);
+      s[3] = vextq_u8(s[0], s[7], 3);
+      s[4] = vextq_u8(s[0], s[7], 4);
+      s[5] = vextq_u8(s[0], s[7], 5);
+      s[6] = vextq_u8(s[0], s[7], 6);
+      int16x8x2_t sum;
+      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
+      sum = WienerHorizontal2(s[0], s[6], filter[0], sum);
+      sum = WienerHorizontal2(s[1], s[5], filter[1], sum);
+      WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer);
+      s[0] = s[7];
+      *wiener_buffer += 16;
+      x -= 16;
+    } while (x != 0);
+    src += src_stride;
+  } while (--y != 0);
+}
+
+inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 const int16_t filter[4],
+                                 int16_t** const wiener_buffer) {
+  int y = height;
+  do {
+    const uint8_t* src_ptr = src;
+    uint8x16_t s[6];
+    s[0] = vld1q_u8(src_ptr);
+    ptrdiff_t x = width;
+    do {
+      src_ptr += 16;
+      s[5] = vld1q_u8(src_ptr);
+      s[1] = vextq_u8(s[0], s[5], 1);
+      s[2] = vextq_u8(s[0], s[5], 2);
+      s[3] = vextq_u8(s[0], s[5], 3);
+      s[4] = vextq_u8(s[0], s[5], 4);
+      int16x8x2_t sum;
+      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
+      sum = WienerHorizontal2(s[0], s[4], filter[1], sum);
+      WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer);
+      s[0] = s[5];
+      *wiener_buffer += 16;
+      x -= 16;
+    } while (x != 0);
+    src += src_stride;
+  } while (--y != 0);
+}
+
+inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 const int16_t filter[4],
+                                 int16_t** const wiener_buffer) {
+  int y = height;
+  do {
+    const uint8_t* src_ptr = src;
+    uint8x16_t s[4];
+    s[0] = vld1q_u8(src_ptr);
+    ptrdiff_t x = width;
+    do {
+      src_ptr += 16;
+      s[3] = vld1q_u8(src_ptr);
+      s[1] = vextq_u8(s[0], s[3], 1);
+      s[2] = vextq_u8(s[0], s[3], 2);
+      int16x8x2_t sum;
+      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
+      WienerHorizontalSum(s, filter, sum, *wiener_buffer);
+      s[0] = s[3];
+      *wiener_buffer += 16;
+      x -= 16;
+    } while (x != 0);
+    src += src_stride;
+  } while (--y != 0);
+}
+
+inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 int16_t** const wiener_buffer) {
+  int y = height;
+  do {
+    const uint8_t* src_ptr = src;
+    ptrdiff_t x = width;
+    do {
+      const uint8x16_t s = vld1q_u8(src_ptr);
+      const uint8x8_t s0 = vget_low_u8(s);
+      const uint8x8_t s1 = vget_high_u8(s);
+      const int16x8_t d0 = vreinterpretq_s16_u16(vshll_n_u8(s0, 4));
+      const int16x8_t d1 = vreinterpretq_s16_u16(vshll_n_u8(s1, 4));
+      vst1q_s16(*wiener_buffer + 0, d0);
+      vst1q_s16(*wiener_buffer + 8, d1);
+      src_ptr += 16;
+      *wiener_buffer += 16;
+      x -= 16;
+    } while (x != 0);
+    src += src_stride;
+  } while (--y != 0);
+}
+
+inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1,
+                                   const int16_t filter,
+                                   const int32x4x2_t sum) {
+  const int16x8_t a = vaddq_s16(a0, a1);
+  int32x4x2_t d;
+  d.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(a), filter);
+  d.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(a), filter);
+  return d;
+}
+
+inline uint8x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4],
+                                const int32x4x2_t sum) {
+  int32x4x2_t d = WienerVertical2(a[0], a[2], filter[2], sum);
+  d.val[0] = vmlal_n_s16(d.val[0], vget_low_s16(a[1]), filter[3]);
+  d.val[1] = vmlal_n_s16(d.val[1], vget_high_s16(a[1]), filter[3]);
+  const uint16x4_t sum_lo_16 = vqrshrun_n_s32(d.val[0], 11);
+  const uint16x4_t sum_hi_16 = vqrshrun_n_s32(d.val[1], 11);
+  return vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16));
+}
+
+inline uint8x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer,
+                                          const ptrdiff_t wiener_stride,
+                                          const int16_t filter[4],
+                                          int16x8_t a[7]) {
+  int32x4x2_t sum;
+  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
+  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
+  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
+  a[6] = vld1q_s16(wiener_buffer + 6 * wiener_stride);
+  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+  sum = WienerVertical2(a[0], a[6], filter[0], sum);
+  sum = WienerVertical2(a[1], a[5], filter[1], sum);
+  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
+  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
+  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
+  return WienerVertical(a + 2, filter, sum);
+}
+
+inline uint8x8x2_t WienerVerticalTap7Kernel2(const int16_t* const wiener_buffer,
+                                             const ptrdiff_t wiener_stride,
+                                             const int16_t filter[4]) {
+  int16x8_t a[8];
+  int32x4x2_t sum;
+  uint8x8x2_t d;
+  d.val[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
+  a[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride);
+  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+  sum = WienerVertical2(a[1], a[7], filter[0], sum);
+  sum = WienerVertical2(a[2], a[6], filter[1], sum);
+  d.val[1] = WienerVertical(a + 3, filter, sum);
+  return d;
+}
+
+inline void WienerVerticalTap7(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               const int16_t filter[4], uint8_t* dst,
+                               const ptrdiff_t dst_stride) {
+  for (int y = height >> 1; y != 0; --y) {
+    uint8_t* dst_ptr = dst;
+    ptrdiff_t x = width;
+    do {
+      uint8x8x2_t d[2];
+      d[0] = WienerVerticalTap7Kernel2(wiener_buffer + 0, width, filter);
+      d[1] = WienerVerticalTap7Kernel2(wiener_buffer + 8, width, filter);
+      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
+      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
+      wiener_buffer += 16;
+      dst_ptr += 16;
+      x -= 16;
+    } while (x != 0);
+    wiener_buffer += width;
+    dst += 2 * dst_stride;
+  }
+
+  if ((height & 1) != 0) {
+    ptrdiff_t x = width;
+    do {
       int16x8_t a[7];
-      a[0] = vld1q_s16(src);
-      src += src_stride;
-      a[1] = vld1q_s16(src);
-      src += src_stride;
-      a[2] = vld1q_s16(src);
-      src += src_stride;
-      a[3] = vld1q_s16(src);
-      src += src_stride;
-      a[4] = vld1q_s16(src);
-      src += src_stride;
-      a[5] = vld1q_s16(src);
-      src += src_stride;
-
-      int y = 0;
-      do {
-        a[6] = vld1q_s16(src);
-        src += src_stride;
-
-        int32x4_t sum_lo = vdupq_n_s32(vertical_rounding);
-        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[0]), filter[0]);
-        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[1]), filter[1]);
-        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[2]), filter[2]);
-        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[3]), filter[3]);
-        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[4]), filter[4]);
-        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[5]), filter[5]);
-        sum_lo = vmlal_s16(sum_lo, vget_low_s16(a[6]), filter[6]);
-        uint16x4_t sum_lo_16 = vqrshrun_n_s32(sum_lo, 11);
-
-        int32x4_t sum_hi = vdupq_n_s32(vertical_rounding);
-        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[0]), filter[0]);
-        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[1]), filter[1]);
-        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[2]), filter[2]);
-        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[3]), filter[3]);
-        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[4]), filter[4]);
-        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[5]), filter[5]);
-        sum_hi = vmlal_s16(sum_hi, vget_high_s16(a[6]), filter[6]);
-        uint16x4_t sum_hi_16 = vqrshrun_n_s32(sum_hi, 11);
-
-        vst1_u8(dst, vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16)));
-        dst += dst_stride;
-
-        a[0] = a[1];
-        a[1] = a[2];
-        a[2] = a[3];
-        a[3] = a[4];
-        a[4] = a[5];
-        a[5] = a[6];
-      } while (++y < height);
-      x += 8;
-    } while (x < width);
-  } else if (min_width == 4) {
-    const int16_t* src = src_base;
-    uint8_t* dst = dst_base;
-    int16x4_t a[7];
-    a[0] = vld1_s16(src);
-    src += src_stride;
-    a[1] = vld1_s16(src);
-    src += src_stride;
-    a[2] = vld1_s16(src);
-    src += src_stride;
-    a[3] = vld1_s16(src);
-    src += src_stride;
-    a[4] = vld1_s16(src);
-    src += src_stride;
-    a[5] = vld1_s16(src);
-    src += src_stride;
-
-    int y = 0;
-    do {
-      a[6] = vld1_s16(src);
-      src += src_stride;
-
-      int32x4_t sum = vdupq_n_s32(vertical_rounding);
-      sum = vmlal_s16(sum, a[0], filter[0]);
-      sum = vmlal_s16(sum, a[1], filter[1]);
-      sum = vmlal_s16(sum, a[2], filter[2]);
-      sum = vmlal_s16(sum, a[3], filter[3]);
-      sum = vmlal_s16(sum, a[4], filter[4]);
-      sum = vmlal_s16(sum, a[5], filter[5]);
-      sum = vmlal_s16(sum, a[6], filter[6]);
-      uint16x4_t sum_16 = vqrshrun_n_s32(sum, 11);
-
-      StoreLo4(dst, vqmovn_u16(vcombine_u16(sum_16, sum_16)));
-      dst += dst_stride;
-
-      a[0] = a[1];
-      a[1] = a[2];
-      a[2] = a[3];
-      a[3] = a[4];
-      a[4] = a[5];
-      a[5] = a[6];
-    } while (++y < height);
+      const uint8x8_t d0 =
+          WienerVerticalTap7Kernel(wiener_buffer + 0, width, filter, a);
+      const uint8x8_t d1 =
+          WienerVerticalTap7Kernel(wiener_buffer + 8, width, filter, a);
+      vst1q_u8(dst, vcombine_u8(d0, d1));
+      wiener_buffer += 16;
+      dst += 16;
+      x -= 16;
+    } while (x != 0);
   }
 }
 
+inline uint8x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer,
+                                          const ptrdiff_t wiener_stride,
+                                          const int16_t filter[4],
+                                          int16x8_t a[5]) {
+  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
+  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
+  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
+  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
+  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
+  int32x4x2_t sum;
+  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+  sum = WienerVertical2(a[0], a[4], filter[1], sum);
+  return WienerVertical(a + 1, filter, sum);
+}
+
+inline uint8x8x2_t WienerVerticalTap5Kernel2(const int16_t* const wiener_buffer,
+                                             const ptrdiff_t wiener_stride,
+                                             const int16_t filter[4]) {
+  int16x8_t a[6];
+  int32x4x2_t sum;
+  uint8x8x2_t d;
+  d.val[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
+  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
+  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+  sum = WienerVertical2(a[1], a[5], filter[1], sum);
+  d.val[1] = WienerVertical(a + 2, filter, sum);
+  return d;
+}
+
+inline void WienerVerticalTap5(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               const int16_t filter[4], uint8_t* dst,
+                               const ptrdiff_t dst_stride) {
+  for (int y = height >> 1; y != 0; --y) {
+    uint8_t* dst_ptr = dst;
+    ptrdiff_t x = width;
+    do {
+      uint8x8x2_t d[2];
+      d[0] = WienerVerticalTap5Kernel2(wiener_buffer + 0, width, filter);
+      d[1] = WienerVerticalTap5Kernel2(wiener_buffer + 8, width, filter);
+      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
+      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
+      wiener_buffer += 16;
+      dst_ptr += 16;
+      x -= 16;
+    } while (x != 0);
+    wiener_buffer += width;
+    dst += 2 * dst_stride;
+  }
+
+  if ((height & 1) != 0) {
+    ptrdiff_t x = width;
+    do {
+      int16x8_t a[5];
+      const uint8x8_t d0 =
+          WienerVerticalTap5Kernel(wiener_buffer + 0, width, filter, a);
+      const uint8x8_t d1 =
+          WienerVerticalTap5Kernel(wiener_buffer + 8, width, filter, a);
+      vst1q_u8(dst, vcombine_u8(d0, d1));
+      wiener_buffer += 16;
+      dst += 16;
+      x -= 16;
+    } while (x != 0);
+  }
+}
+
+inline uint8x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer,
+                                          const ptrdiff_t wiener_stride,
+                                          const int16_t filter[4],
+                                          int16x8_t a[3]) {
+  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
+  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
+  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
+  int32x4x2_t sum;
+  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+  return WienerVertical(a, filter, sum);
+}
+
+inline uint8x8x2_t WienerVerticalTap3Kernel2(const int16_t* const wiener_buffer,
+                                             const ptrdiff_t wiener_stride,
+                                             const int16_t filter[4]) {
+  int16x8_t a[4];
+  int32x4x2_t sum;
+  uint8x8x2_t d;
+  d.val[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
+  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
+  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
+  d.val[1] = WienerVertical(a + 1, filter, sum);
+  return d;
+}
+
+inline void WienerVerticalTap3(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               const int16_t filter[4], uint8_t* dst,
+                               const ptrdiff_t dst_stride) {
+  for (int y = height >> 1; y != 0; --y) {
+    uint8_t* dst_ptr = dst;
+    ptrdiff_t x = width;
+    do {
+      uint8x8x2_t d[2];
+      d[0] = WienerVerticalTap3Kernel2(wiener_buffer + 0, width, filter);
+      d[1] = WienerVerticalTap3Kernel2(wiener_buffer + 8, width, filter);
+      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
+      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
+      wiener_buffer += 16;
+      dst_ptr += 16;
+      x -= 16;
+    } while (x != 0);
+    wiener_buffer += width;
+    dst += 2 * dst_stride;
+  }
+
+  if ((height & 1) != 0) {
+    ptrdiff_t x = width;
+    do {
+      int16x8_t a[3];
+      const uint8x8_t d0 =
+          WienerVerticalTap3Kernel(wiener_buffer + 0, width, filter, a);
+      const uint8x8_t d1 =
+          WienerVerticalTap3Kernel(wiener_buffer + 8, width, filter, a);
+      vst1q_u8(dst, vcombine_u8(d0, d1));
+      wiener_buffer += 16;
+      dst += 16;
+      x -= 16;
+    } while (x != 0);
+  }
+}
+
+inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
+                                     uint8_t* const dst) {
+  const int16x8_t a0 = vld1q_s16(wiener_buffer + 0);
+  const int16x8_t a1 = vld1q_s16(wiener_buffer + 8);
+  const uint8x8_t d0 = vqrshrun_n_s16(a0, 4);
+  const uint8x8_t d1 = vqrshrun_n_s16(a1, 4);
+  vst1q_u8(dst, vcombine_u8(d0, d1));
+}
+
+inline void WienerVerticalTap1(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               uint8_t* dst, const ptrdiff_t dst_stride) {
+  for (int y = height >> 1; y != 0; --y) {
+    uint8_t* dst_ptr = dst;
+    ptrdiff_t x = width;
+    do {
+      WienerVerticalTap1Kernel(wiener_buffer, dst_ptr);
+      WienerVerticalTap1Kernel(wiener_buffer + width, dst_ptr + dst_stride);
+      wiener_buffer += 16;
+      dst_ptr += 16;
+      x -= 16;
+    } while (x != 0);
+    wiener_buffer += width;
+    dst += 2 * dst_stride;
+  }
+
+  if ((height & 1) != 0) {
+    ptrdiff_t x = width;
+    do {
+      WienerVerticalTap1Kernel(wiener_buffer, dst);
+      wiener_buffer += 16;
+      dst += 16;
+      x -= 16;
+    } while (x != 0);
+  }
+}
+
+// For width 16 and up, store the horizontal results, and then do the vertical
+// filter row by row. This is faster than doing it column by column when
+// considering cache issues.
 void WienerFilter_NEON(const void* const source, void* const dest,
                        const RestorationUnitInfo& restoration_info,
                        const ptrdiff_t source_stride,
                        const ptrdiff_t dest_stride, const int width,
                        const int height, RestorationBuffer* const buffer) {
-  int16_t filter[kSubPixelTaps - 1];
-  const auto* src = static_cast<const uint8_t*>(source);
-  auto* dst = static_cast<uint8_t*>(dest);
-  // It should be possible to set this to |width|.
-  ptrdiff_t buffer_stride = buffer->wiener_buffer_stride;
-  // Casting once here saves a lot of vreinterpret() calls. The values are
-  // saturated to 13 bits before storing.
-  int16_t* wiener_buffer = reinterpret_cast<int16_t*>(buffer->wiener_buffer);
+  constexpr int kCenterTap = kWienerFilterTaps / 2;
+  const int16_t* const number_leading_zero_coefficients =
+      restoration_info.wiener_info.number_leading_zero_coefficients;
+  const int number_rows_to_skip = std::max(
+      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
+      1);
+  const ptrdiff_t wiener_stride = Align(width, 16);
+  int16_t* const wiener_buffer_vertical = buffer->wiener_buffer;
+  // The values are saturated to 13 bits before storing.
+  int16_t* wiener_buffer_horizontal =
+      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;
+  int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2];
+  int16_t filter_vertical[(kWienerFilterTaps + 1) / 2];
+  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal,
+                             filter_horizontal);
+  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical,
+                             filter_vertical);
 
-  // Horizontal filtering.
-  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal, filter);
-  // The taps have a radius of 3. Adjust |src| so we start reading with the top
-  // left value.
-  const int center_tap = 3;
-  src -= center_tap * source_stride + center_tap;
-  int y = 0;
-  do {
-    int x = 0;
-    do {
-      // This is just as fast as an 8x8 transpose but avoids over-reading extra
-      // rows. It always over-reads by at least 1 value. On small widths (4xH)
-      // it over-reads by 9 values.
-      const uint8x16_t src_v = vld1q_u8(src + x);
-      uint8x8_t b[7];
-      b[0] = vget_low_u8(src_v);
-      b[1] = vget_low_u8(vextq_u8(src_v, src_v, 1));
-      b[2] = vget_low_u8(vextq_u8(src_v, src_v, 2));
-      b[3] = vget_low_u8(vextq_u8(src_v, src_v, 3));
-      b[4] = vget_low_u8(vextq_u8(src_v, src_v, 4));
-      b[5] = vget_low_u8(vextq_u8(src_v, src_v, 5));
-      b[6] = vget_low_u8(vextq_u8(src_v, src_v, 6));
-
-      int16x8_t sum = HorizontalSum(b, filter);
-
-      vst1q_s16(wiener_buffer + x, sum);
-      x += 8;
-    } while (x < width);
-    src += source_stride;
-    wiener_buffer += buffer_stride;
-  } while (++y < height + kSubPixelTaps - 2);
-
-  // Vertical filtering.
-  wiener_buffer = reinterpret_cast<int16_t*>(buffer->wiener_buffer);
-  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical, filter);
-  // Add 128 to |filter[3]| to fix the adjustment for the horizontal filtering.
-  // This pass starts with 13 bits so there is no chance of keeping it in 16.
-  const int16x4_t filter_v[7] = {
-      vdup_n_s16(filter[0]),       vdup_n_s16(filter[1]), vdup_n_s16(filter[2]),
-      vdup_n_s16(filter[3] + 128), vdup_n_s16(filter[4]), vdup_n_s16(filter[5]),
-      vdup_n_s16(filter[6])};
-
-  if (width == 4) {
-    VerticalSum<4>(wiener_buffer, buffer_stride, dst, dest_stride, filter_v,
-                   width, height);
+  // horizontal filtering.
+  // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
+  const int height_horizontal =
+      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
+  const auto* const src = static_cast<const uint8_t*>(source) -
+                          (kCenterTap - number_rows_to_skip) * source_stride;
+  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
+    WienerHorizontalTap7(src - 3, source_stride, wiener_stride,
+                         height_horizontal, filter_horizontal,
+                         &wiener_buffer_horizontal);
+  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
+    WienerHorizontalTap5(src - 2, source_stride, wiener_stride,
+                         height_horizontal, filter_horizontal,
+                         &wiener_buffer_horizontal);
+  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
+    // The maximum over-reads happen here.
+    WienerHorizontalTap3(src - 1, source_stride, wiener_stride,
+                         height_horizontal, filter_horizontal,
+                         &wiener_buffer_horizontal);
   } else {
-    VerticalSum<8>(wiener_buffer, buffer_stride, dst, dest_stride, filter_v,
-                   width, height);
+    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
+    WienerHorizontalTap1(src, source_stride, wiener_stride, height_horizontal,
+                         &wiener_buffer_horizontal);
+  }
+
+  // vertical filtering.
+  // Over-writes up to 15 values.
+  auto* dst = static_cast<uint8_t*>(dest);
+  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
+    // Because the top row of |source| is a duplicate of the second row, and the
+    // bottom row of |source| is a duplicate of its above row, we can duplicate
+    // the top and bottom row of |wiener_buffer| accordingly.
+    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
+           sizeof(*wiener_buffer_horizontal) * wiener_stride);
+    memcpy(buffer->wiener_buffer, buffer->wiener_buffer + wiener_stride,
+           sizeof(*buffer->wiener_buffer) * wiener_stride);
+    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
+                       filter_vertical, dst, dest_stride);
+  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
+    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
+                       height, filter_vertical, dst, dest_stride);
+  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
+    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
+                       wiener_stride, height, filter_vertical, dst,
+                       dest_stride);
+  } else {
+    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
+    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
+                       wiener_stride, height, dst, dest_stride);
   }
 }
 
+//------------------------------------------------------------------------------
 // SGR
-inline uint16x4_t Sum3Horizontal(const uint16x8_t a) {
-  const uint16x4_t sum =
-      vadd_u16(vget_low_u16(a), vext_u16(vget_low_u16(a), vget_high_u16(a), 1));
-  return vadd_u16(sum, vext_u16(vget_low_u16(a), vget_high_u16(a), 2));
-}
-
-inline uint32x4_t Sum3Horizontal(const uint32x4x2_t a) {
-  const uint32x4_t sum = vaddq_u32(a.val[0], vextq_u32(a.val[0], a.val[1], 1));
-  return vaddq_u32(sum, vextq_u32(a.val[0], a.val[1], 2));
-}
-
-inline uint16x8_t Sum3Vertical(const uint8x8_t a[3]) {
-  const uint16x8_t sum = vaddl_u8(a[0], a[1]);
-  return vaddw_u8(sum, a[2]);
-}
-
-inline uint32x4x2_t Sum3Vertical(const uint16x8_t a[3]) {
-  uint32x4_t sum_a = vaddl_u16(vget_low_u16(a[0]), vget_low_u16(a[1]));
-  sum_a = vaddw_u16(sum_a, vget_low_u16(a[2]));
-  uint32x4_t sum_b = vaddl_u16(vget_high_u16(a[0]), vget_high_u16(a[1]));
-  sum_b = vaddw_u16(sum_b, vget_high_u16(a[2]));
-  return {sum_a, sum_b};
-}
-
-inline uint16x4_t Sum5Horizontal(const uint16x8_t a) {
-  uint16x4_t sum =
-      vadd_u16(vget_low_u16(a), vext_u16(vget_low_u16(a), vget_high_u16(a), 1));
-  sum = vadd_u16(sum, vext_u16(vget_low_u16(a), vget_high_u16(a), 2));
-  sum = vadd_u16(sum, vext_u16(vget_low_u16(a), vget_high_u16(a), 3));
-  return vadd_u16(sum, vget_high_u16(a));
-}
-
-inline uint32x4_t Sum5Horizontal(const uint32x4x2_t a) {
-  uint32x4_t sum = vaddq_u32(a.val[0], vextq_u32(a.val[0], a.val[1], 1));
-  sum = vaddq_u32(sum, vextq_u32(a.val[0], a.val[1], 2));
-  sum = vaddq_u32(sum, vextq_u32(a.val[0], a.val[1], 3));
-  return vaddq_u32(sum, a.val[1]);
-}
-
-inline uint16x8_t Sum5Vertical(const uint8x8_t a[5]) {
-  uint16x8_t sum = vaddl_u8(a[0], a[1]);
-  sum = vaddq_u16(sum, vaddl_u8(a[2], a[3]));
-  return vaddw_u8(sum, a[4]);
-}
-
-inline uint32x4x2_t Sum5Vertical(const uint16x8_t a[5]) {
-  uint32x4_t sum_a = vaddl_u16(vget_low_u16(a[0]), vget_low_u16(a[1]));
-  sum_a = vaddq_u32(sum_a, vaddl_u16(vget_low_u16(a[2]), vget_low_u16(a[3])));
-  sum_a = vaddw_u16(sum_a, vget_low_u16(a[4]));
-  uint32x4_t sum_b = vaddl_u16(vget_high_u16(a[0]), vget_high_u16(a[1]));
-  sum_b = vaddq_u32(sum_b, vaddl_u16(vget_high_u16(a[2]), vget_high_u16(a[3])));
-  sum_b = vaddw_u16(sum_b, vget_high_u16(a[4]));
-  return {sum_a, sum_b};
-}
-
-constexpr int kSgrProjScaleBits = 20;
-constexpr int kSgrProjRestoreBits = 4;
-constexpr int kSgrProjSgrBits = 8;
-constexpr int kSgrProjReciprocalBits = 12;
-
-constexpr int kIntermediateStride = 68;
-constexpr int kIntermediateHeight = 66;
-
-// a2 = ((z << kSgrProjSgrBits) + (z >> 1)) / (z + 1);
-constexpr uint16_t kA2Lookup[256] = {
-    1,   128, 171, 192, 205, 213, 219, 224, 228, 230, 233, 235, 236, 238, 239,
-    240, 241, 242, 243, 243, 244, 244, 245, 245, 246, 246, 247, 247, 247, 247,
-    248, 248, 248, 248, 249, 249, 249, 249, 249, 250, 250, 250, 250, 250, 250,
-    250, 251, 251, 251, 251, 251, 251, 251, 251, 251, 251, 252, 252, 252, 252,
-    252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 253, 253,
-    253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253,
-    253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    256};
-
-inline uint16x4_t Sum565(const uint16x8_t a) {
-  // Multiply everything by 4.
-  const uint16x8_t x4 = vshlq_n_u16(a, 2);
-  // Everything * 5
-  const uint16x8_t x5 = vaddq_u16(x4, a);
-  // The middle elements get added once more
-  const uint16x4_t middle = vext_u16(vget_low_u16(a), vget_high_u16(a), 1);
-  return vadd_u16(middle, Sum3Horizontal(x5));
-}
-
-inline uint32x4_t Sum3HorizontalW(const uint32x4_t a, const uint32x4_t b) {
-  const uint32x4_t sum = vaddq_u32(a, vextq_u32(a, b, 1));
-  return vaddq_u32(sum, vcombine_u32(vget_high_u32(a), vget_low_u32(b)));
-}
-
-inline uint32x4_t Sum565W(const uint16x8_t a) {
-  // Multiply everything by 4. |b2| values can be up to 65088 which means we
-  // have to step up to 32 bits immediately.
-  const uint32x4_t x4_lo = vshll_n_u16(vget_low_u16(a), 2);
-  const uint32x4_t x4_hi = vshll_n_u16(vget_high_u16(a), 2);
-  // Everything * 5
-  const uint32x4_t x5_lo = vaddw_u16(x4_lo, vget_low_u16(a));
-  const uint32x4_t x5_hi = vaddw_u16(x4_hi, vget_high_u16(a));
-  // The middle elements get added once more
-  const uint16x4_t middle = vext_u16(vget_low_u16(a), vget_high_u16(a), 1);
-  return vaddw_u16(Sum3HorizontalW(x5_lo, x5_hi), middle);
-}
 
 template <int n>
-inline uint16x4_t CalculateA2(const uint32x4_t sum_sq, const uint16x4_t sum,
-                              const uint32_t s, const uint16x4_t v_255) {
+inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq,
+                              const uint32_t scale) {
   // a = |sum_sq|
   // d = |sum|
   // p = (a * n < d * d) ? 0 : a * n - d * d;
@@ -384,665 +567,1254 @@
   const uint32x4_t axn = vmulq_n_u32(sum_sq, n);
   // Ensure |p| does not underflow by using saturating subtraction.
   const uint32x4_t p = vqsubq_u32(axn, dxd);
-
-  // z = RightShiftWithRounding(p * s, kSgrProjScaleBits);
-  const uint32x4_t pxs = vmulq_n_u32(p, s);
-  // For some reason vrshrn_n_u32() (narrowing shift) can only shift by 16
-  // and kSgrProjScaleBits is 20.
+  // z = RightShiftWithRounding(p * scale, kSgrProjScaleBits);
+  const uint32x4_t pxs = vmulq_n_u32(p, scale);
+  // vrshrn_n_u32() (narrowing shift) can only shift by 16 and kSgrProjScaleBits
+  // is 20.
   const uint32x4_t shifted = vrshrq_n_u32(pxs, kSgrProjScaleBits);
-  // Narrow |shifted| so we can use a D register for v_255.
-  const uint16x4_t z = vmin_u16(v_255, vmovn_u32(shifted));
-
-  // a2 = ((z << kSgrProjSgrBits) + (z >> 1)) / (z + 1);
-  const uint16_t lookup[4] = {
-      kA2Lookup[vget_lane_u16(z, 0)], kA2Lookup[vget_lane_u16(z, 1)],
-      kA2Lookup[vget_lane_u16(z, 2)], kA2Lookup[vget_lane_u16(z, 3)]};
-  return vld1_u16(lookup);
+  return vmovn_u32(shifted);
 }
 
-inline uint16x4_t CalculateB2Shifted(const uint16x4_t a2, const uint16x4_t sum,
-                                     const uint32_t one_over_n) {
-  // b2 = ((1 << kSgrProjSgrBits) - a2) * b * one_over_n
-  // 1 << kSgrProjSgrBits = 256
-  // |a2| = [1, 256]
-  // |sgrMa2| max value = 255
-  const uint16x4_t sgrMa2 = vsub_u16(vdup_n_u16(1 << kSgrProjSgrBits), a2);
-  // |sum| is a box sum with radius 1 or 2.
-  // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
-  // For the second pass radius is 1. Maxmimum value is 3x3x255 = 2295.
-  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
-  // When radius is 2 |n| is 25. |one_over_n| is 164.
-  // When radius is 1 |n| is 9. |one_over_n| is 455.
-  const uint32x4_t b2 = vmulq_n_u32(vmull_u16(sgrMa2, sum), one_over_n);
-  // static_cast<int>(RightShiftWithRounding(b2, kSgrProjReciprocalBits));
-  // |kSgrProjReciprocalBits| is 12.
-  // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
-  // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
-  return vrshrn_n_u32(b2, kSgrProjReciprocalBits);
+inline void Prepare3_8(const uint8x8x2_t src, uint8x8_t dst[3]) {
+  dst[0] = VshrU128<0>(src);
+  dst[1] = VshrU128<1>(src);
+  dst[2] = VshrU128<2>(src);
 }
 
-// RightShiftWithRounding(
-//   (a * src_ptr[x] + b), kSgrProjSgrBits + shift - kSgrProjRestoreBits);
+inline void Prepare3_16(const uint16x8x2_t src, uint16x4_t low[3],
+                        uint16x4_t high[3]) {
+  uint16x8_t s[3];
+  s[0] = VshrU128<0>(src);
+  s[1] = VshrU128<2>(src);
+  s[2] = VshrU128<4>(src);
+  low[0] = vget_low_u16(s[0]);
+  low[1] = vget_low_u16(s[1]);
+  low[2] = vget_low_u16(s[2]);
+  high[0] = vget_high_u16(s[0]);
+  high[1] = vget_high_u16(s[1]);
+  high[2] = vget_high_u16(s[2]);
+}
+
+inline void Prepare5_8(const uint8x8x2_t src, uint8x8_t dst[5]) {
+  dst[0] = VshrU128<0>(src);
+  dst[1] = VshrU128<1>(src);
+  dst[2] = VshrU128<2>(src);
+  dst[3] = VshrU128<3>(src);
+  dst[4] = VshrU128<4>(src);
+}
+
+inline void Prepare5_16(const uint16x8x2_t src, uint16x4_t low[5],
+                        uint16x4_t high[5]) {
+  Prepare3_16(src, low, high);
+  const uint16x8_t s3 = VshrU128<6>(src);
+  const uint16x8_t s4 = VshrU128<8>(src);
+  low[3] = vget_low_u16(s3);
+  low[4] = vget_low_u16(s4);
+  high[3] = vget_high_u16(s3);
+  high[4] = vget_high_u16(s4);
+}
+
+inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1,
+                          const uint16x8_t src2) {
+  const uint16x8_t sum = vaddq_u16(src0, src1);
+  return vaddq_u16(sum, src2);
+}
+
+inline uint16x8_t Sum3_16(const uint16x8_t src[3]) {
+  return Sum3_16(src[0], src[1], src[2]);
+}
+
+inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1,
+                          const uint32x4_t src2) {
+  const uint32x4_t sum = vaddq_u32(src0, src1);
+  return vaddq_u32(sum, src2);
+}
+
+inline uint32x4x2_t Sum3_32(const uint32x4x2_t src[3]) {
+  uint32x4x2_t d;
+  d.val[0] = Sum3_32(src[0].val[0], src[1].val[0], src[2].val[0]);
+  d.val[1] = Sum3_32(src[0].val[1], src[1].val[1], src[2].val[1]);
+  return d;
+}
+
+inline uint16x8_t Sum3W_16(const uint8x8_t src[3]) {
+  const uint16x8_t sum = vaddl_u8(src[0], src[1]);
+  return vaddw_u8(sum, src[2]);
+}
+
+inline uint32x4_t Sum3W_32(const uint16x4_t src[3]) {
+  const uint32x4_t sum = vaddl_u16(src[0], src[1]);
+  return vaddw_u16(sum, src[2]);
+}
+
+inline uint16x8_t Sum5_16(const uint16x8_t src[5]) {
+  const uint16x8_t sum01 = vaddq_u16(src[0], src[1]);
+  const uint16x8_t sum23 = vaddq_u16(src[2], src[3]);
+  const uint16x8_t sum = vaddq_u16(sum01, sum23);
+  return vaddq_u16(sum, src[4]);
+}
+
+inline uint32x4_t Sum5_32(const uint32x4_t src0, const uint32x4_t src1,
+                          const uint32x4_t src2, const uint32x4_t src3,
+                          const uint32x4_t src4) {
+  const uint32x4_t sum01 = vaddq_u32(src0, src1);
+  const uint32x4_t sum23 = vaddq_u32(src2, src3);
+  const uint32x4_t sum = vaddq_u32(sum01, sum23);
+  return vaddq_u32(sum, src4);
+}
+
+inline uint32x4x2_t Sum5_32(const uint32x4x2_t src[5]) {
+  uint32x4x2_t d;
+  d.val[0] = Sum5_32(src[0].val[0], src[1].val[0], src[2].val[0], src[3].val[0],
+                     src[4].val[0]);
+  d.val[1] = Sum5_32(src[0].val[1], src[1].val[1], src[2].val[1], src[3].val[1],
+                     src[4].val[1]);
+  return d;
+}
+
+inline uint32x4_t Sum5W_32(const uint16x4_t src[5]) {
+  const uint32x4_t sum01 = vaddl_u16(src[0], src[1]);
+  const uint32x4_t sum23 = vaddl_u16(src[2], src[3]);
+  const uint32x4_t sum0123 = vaddq_u32(sum01, sum23);
+  return vaddw_u16(sum0123, src[4]);
+}
+
+inline uint16x8_t Sum3Horizontal(const uint8x8x2_t src) {
+  uint8x8_t s[3];
+  Prepare3_8(src, s);
+  return Sum3W_16(s);
+}
+
+inline uint32x4x2_t Sum3WHorizontal(const uint16x8x2_t src) {
+  uint16x4_t low[3], high[3];
+  uint32x4x2_t sum;
+  Prepare3_16(src, low, high);
+  sum.val[0] = Sum3W_32(low);
+  sum.val[1] = Sum3W_32(high);
+  return sum;
+}
+
+inline uint16x8_t Sum5Horizontal(const uint8x8x2_t src) {
+  uint8x8_t s[5];
+  Prepare5_8(src, s);
+  const uint16x8_t sum01 = vaddl_u8(s[0], s[1]);
+  const uint16x8_t sum23 = vaddl_u8(s[2], s[3]);
+  const uint16x8_t sum0123 = vaddq_u16(sum01, sum23);
+  return vaddw_u8(sum0123, s[4]);
+}
+
+inline uint32x4x2_t Sum5WHorizontal(const uint16x8x2_t src) {
+  uint16x4_t low[5], high[5];
+  Prepare5_16(src, low, high);
+  uint32x4x2_t sum;
+  sum.val[0] = Sum5W_32(low);
+  sum.val[1] = Sum5W_32(high);
+  return sum;
+}
+
+void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3,
+                   uint32x4_t* const row_sq5) {
+  const uint32x4_t sum04 = vaddl_u16(src[0], src[4]);
+  const uint32x4_t sum12 = vaddl_u16(src[1], src[2]);
+  *row_sq3 = vaddw_u16(sum12, src[3]);
+  *row_sq5 = vaddq_u32(sum04, *row_sq3);
+}
+
+void SumHorizontal(const uint8x8x2_t src, const uint16x8x2_t sq,
+                   uint16x8_t* const row3, uint16x8_t* const row5,
+                   uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) {
+  uint8x8_t s[5];
+  Prepare5_8(src, s);
+  const uint16x8_t sum04 = vaddl_u8(s[0], s[4]);
+  const uint16x8_t sum12 = vaddl_u8(s[1], s[2]);
+  *row3 = vaddw_u8(sum12, s[3]);
+  *row5 = vaddq_u16(sum04, *row3);
+  uint16x4_t low[5], high[5];
+  Prepare5_16(sq, low, high);
+  SumHorizontal(low, &row_sq3->val[0], &row_sq5->val[0]);
+  SumHorizontal(high, &row_sq3->val[1], &row_sq5->val[1]);
+}
+
+inline uint16x8_t Sum343(const uint8x8x2_t src) {
+  uint8x8_t s[3];
+  Prepare3_8(src, s);
+  const uint16x8_t sum = Sum3W_16(s);
+  const uint16x8_t sum3 = Sum3_16(sum, sum, sum);
+  return vaddw_u8(sum3, s[1]);
+}
+
+inline uint32x4_t Sum343W(const uint16x4_t src[3]) {
+  const uint32x4_t sum = Sum3W_32(src);
+  const uint32x4_t sum3 = Sum3_32(sum, sum, sum);
+  return vaddw_u16(sum3, src[1]);
+}
+
+inline uint32x4x2_t Sum343W(const uint16x8x2_t src) {
+  uint16x4_t low[3], high[3];
+  uint32x4x2_t d;
+  Prepare3_16(src, low, high);
+  d.val[0] = Sum343W(low);
+  d.val[1] = Sum343W(high);
+  return d;
+}
+
+inline uint16x8_t Sum565(const uint8x8x2_t src) {
+  uint8x8_t s[3];
+  Prepare3_8(src, s);
+  const uint16x8_t sum = Sum3W_16(s);
+  const uint16x8_t sum4 = vshlq_n_u16(sum, 2);
+  const uint16x8_t sum5 = vaddq_u16(sum4, sum);
+  return vaddw_u8(sum5, s[1]);
+}
+
+inline uint32x4_t Sum565W(const uint16x4_t src[3]) {
+  const uint32x4_t sum = Sum3W_32(src);
+  const uint32x4_t sum4 = vshlq_n_u32(sum, 2);
+  const uint32x4_t sum5 = vaddq_u32(sum4, sum);
+  return vaddw_u16(sum5, src[1]);
+}
+
+inline uint32x4x2_t Sum565W(const uint16x8x2_t src) {
+  uint16x4_t low[3], high[3];
+  uint32x4x2_t d;
+  Prepare3_16(src, low, high);
+  d.val[0] = Sum565W(low);
+  d.val[1] = Sum565W(high);
+  return d;
+}
+
+inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
+                         const ptrdiff_t x, uint16x8_t* const sum_ma343,
+                         uint16x8_t* const sum_ma444,
+                         uint32x4x2_t* const sum_b343,
+                         uint32x4x2_t* const sum_b444, uint16_t* const ma343,
+                         uint16_t* const ma444, uint32_t* const b343,
+                         uint32_t* const b444) {
+  uint8x8_t s[3];
+  Prepare3_8(ma3, s);
+  const uint16x8_t sum_ma111 = Sum3W_16(s);
+  *sum_ma444 = vshlq_n_u16(sum_ma111, 2);
+  const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
+  *sum_ma343 = vaddw_u8(sum333, s[1]);
+  uint16x4_t low[3], high[3];
+  uint32x4x2_t sum_b111;
+  Prepare3_16(b3, low, high);
+  sum_b111.val[0] = Sum3W_32(low);
+  sum_b111.val[1] = Sum3W_32(high);
+  sum_b444->val[0] = vshlq_n_u32(sum_b111.val[0], 2);
+  sum_b444->val[1] = vshlq_n_u32(sum_b111.val[1], 2);
+  sum_b343->val[0] = vsubq_u32(sum_b444->val[0], sum_b111.val[0]);
+  sum_b343->val[1] = vsubq_u32(sum_b444->val[1], sum_b111.val[1]);
+  sum_b343->val[0] = vaddw_u16(sum_b343->val[0], low[1]);
+  sum_b343->val[1] = vaddw_u16(sum_b343->val[1], high[1]);
+  vst1q_u16(ma343 + x, *sum_ma343);
+  vst1q_u16(ma444 + x, *sum_ma444);
+  vst1q_u32(b343 + x + 0, (*sum_b343).val[0]);
+  vst1q_u32(b343 + x + 4, (*sum_b343).val[1]);
+  vst1q_u32(b444 + x + 0, (*sum_b444).val[0]);
+  vst1q_u32(b444 + x + 4, (*sum_b444).val[1]);
+}
+
+inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
+                         const ptrdiff_t x, uint16x8_t* const sum_ma343,
+                         uint32x4x2_t* const sum_b343, uint16_t* const ma343,
+                         uint16_t* const ma444, uint32_t* const b343,
+                         uint32_t* const b444) {
+  uint16x8_t sum_ma444;
+  uint32x4x2_t sum_b444;
+  Store343_444(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, &sum_b444, ma343,
+               ma444, b343, b444);
+}
+
+inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
+                         const ptrdiff_t x, uint16_t* const ma343,
+                         uint16_t* const ma444, uint32_t* const b343,
+                         uint32_t* const b444) {
+  uint16x8_t sum_ma343;
+  uint32x4x2_t sum_b343;
+  Store343_444(ma3, b3, x, &sum_ma343, &sum_b343, ma343, ma444, b343, b444);
+}
+
 template <int shift>
-inline uint16x4_t CalculateFilteredOutput(const uint16x4_t a,
-                                          const uint32x4_t b,
-                                          const uint16x4_t src) {
-  // a: 256 * 32 = 8192 (14 bits)
+inline int16x4_t FilterOutput(const uint16x4_t src, const uint16x4_t ma,
+                              const uint32x4_t b) {
+  // ma: 255 * 32 = 8160 (13 bits)
   // b: 65088 * 32 = 2082816 (21 bits)
-  const uint32x4_t axsrc = vmull_u16(a, src);
-  // v: 8192 * 255 + 2082816 = 4171876 (22 bits)
-  const uint32x4_t v = vaddq_u32(axsrc, b);
-
+  // v: b - ma * 255 (22 bits)
+  const int32x4_t v = vreinterpretq_s32_u32(vmlsl_u16(b, ma, src));
   // kSgrProjSgrBits = 8
   // kSgrProjRestoreBits = 4
   // shift = 4 or 5
-  // v >> 8 or 9
-  // 22 bits >> 8 = 14 bits
-  return vrshrn_n_u32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
+  // v >> 8 or 9 (13 bits)
+  return vrshrn_n_s32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
 }
 
-inline void BoxFilterProcess_FirstPass(const uint8_t* const src,
-                                       const ptrdiff_t stride, const int width,
-                                       const int height, const uint32_t s,
-                                       uint16_t* const buf) {
-  // Number of elements in the box being summed.
-  const uint32_t n = 25;
-  const uint32_t one_over_n = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
+template <int shift>
+inline int16x8_t CalculateFilteredOutput(const uint8x8_t src,
+                                         const uint16x8_t ma,
+                                         const uint32x4x2_t b) {
+  const uint16x8_t src_u16 = vmovl_u8(src);
+  const int16x4_t dst_lo =
+      FilterOutput<shift>(vget_low_u16(src_u16), vget_low_u16(ma), b.val[0]);
+  const int16x4_t dst_hi =
+      FilterOutput<shift>(vget_high_u16(src_u16), vget_high_u16(ma), b.val[1]);
+  return vcombine_s16(dst_lo, dst_hi);  // 13 bits
+}
 
-  const uint16x4_t v_255 = vdup_n_u16(255);
+inline int16x8_t CalculateFilteredOutputPass1(const uint8x8_t s,
+                                              uint16x8_t ma[2],
+                                              uint32x4x2_t b[2]) {
+  const uint16x8_t ma_sum = vaddq_u16(ma[0], ma[1]);
+  uint32x4x2_t b_sum;
+  b_sum.val[0] = vaddq_u32(b[0].val[0], b[1].val[0]);
+  b_sum.val[1] = vaddq_u32(b[0].val[1], b[1].val[1]);
+  return CalculateFilteredOutput<5>(s, ma_sum, b_sum);
+}
 
-  // We have combined PreProcess and Process for the first pass by storing
-  // intermediate values in the |a2| region. The values stored are one vertical
-  // column of interleaved |a2| and |b2| values and consume 4 * |height| values.
-  // This is |height| and not |height| * 2 because PreProcess only generates
-  // output for every other row. When processing the next column we write the
-  // new scratch values right after reading the previously saved ones.
-  uint16_t* const temp = buf + kIntermediateStride * kIntermediateHeight;
+inline int16x8_t CalculateFilteredOutputPass2(const uint8x8_t s,
+                                              uint16x8_t ma[3],
+                                              uint32x4x2_t b[3]) {
+  const uint16x8_t ma_sum = Sum3_16(ma);
+  const uint32x4x2_t b_sum = Sum3_32(b);
+  return CalculateFilteredOutput<5>(s, ma_sum, b_sum);
+}
 
-  // The PreProcess phase calculates a 5x5 box sum for every other row
-  //
-  // PreProcess and Process have been combined into the same step. We need 8
-  // input values to generate 4 output values for PreProcess:
-  // 0 1 2 3 4 5 6 7
-  // 2 = 0 + 1 + 2 + 3 + 4
-  // 3 = 1 + 2 + 3 + 4 + 5
-  // 4 = 2 + 3 + 4 + 5 + 6
-  // 5 = 3 + 4 + 5 + 6 + 7
-  //
-  // and then we need 6 input values to generate 4 output values for Process:
-  // 0 1 2 3 4 5
-  // 1 = 0 + 1 + 2
-  // 2 = 1 + 2 + 3
-  // 3 = 2 + 3 + 4
-  // 4 = 3 + 4 + 5
-  //
-  // To avoid re-calculating PreProcess values over and over again we will do a
-  // single column of 4 output values and store them interleaved in |temp|. Next
-  // we will start the second column. When 2 rows have been calculated we can
-  // calculate Process and output those into the top of |buf|.
+inline void SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2],
+                            uint8_t* const dst) {
+  const int16x4_t v_lo =
+      vrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
+  const int16x4_t v_hi =
+      vrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
+  const int16x8_t vv = vcombine_s16(v_lo, v_hi);
+  const int16x8_t s = ZeroExtend(src);
+  const int16x8_t d = vaddq_s16(s, vv);
+  vst1_u8(dst, vqmovun_s16(d));
+}
 
-  // The first phase needs a radius of 2 context values. The second phase needs
-  // a context of radius 1 values. This means we start at (-3, -3).
-  const uint8_t* const src_pre_process = src - 3 - 3 * stride;
+inline void SelfGuidedDoubleMultiplier(const uint8x8_t src,
+                                       const int16x8_t filter[2], const int w0,
+                                       const int w2, uint8_t* const dst) {
+  int32x4_t v[2];
+  v[0] = vmull_n_s16(vget_low_s16(filter[0]), w0);
+  v[1] = vmull_n_s16(vget_high_s16(filter[0]), w0);
+  v[0] = vmlal_n_s16(v[0], vget_low_s16(filter[1]), w2);
+  v[1] = vmlal_n_s16(v[1], vget_high_s16(filter[1]), w2);
+  SelfGuidedFinal(src, v, dst);
+}
 
-  // Calculate and store a single column. Scope so we can re-use the variable
-  // names for the next step.
-  {
-    const uint8_t* column = src_pre_process;
-    uint16_t* temp_column = temp;
+inline void SelfGuidedSingleMultiplier(const uint8x8_t src,
+                                       const int16x8_t filter, const int w0,
+                                       uint8_t* const dst) {
+  // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
+  int32x4_t v[2];
+  v[0] = vmull_n_s16(vget_low_s16(filter), w0);
+  v[1] = vmull_n_s16(vget_high_s16(filter), w0);
+  SelfGuidedFinal(src, v, dst);
+}
 
-    uint8x8_t row[5];
-    row[0] = vld1_u8(column);
-    column += stride;
-    row[1] = vld1_u8(column);
-    column += stride;
-    row[2] = vld1_u8(column);
-    column += stride;
-
-    uint16x8_t row_sq[5];
-    row_sq[0] = vmull_u8(row[0], row[0]);
-    row_sq[1] = vmull_u8(row[1], row[1]);
-    row_sq[2] = vmull_u8(row[2], row[2]);
-
-    int y = -1;
+inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
+                   const int height, const ptrdiff_t width, uint16_t* sum3,
+                   uint16_t* sum5, uint32_t* square_sum3,
+                   uint32_t* square_sum5) {
+  int y = height;
+  do {
+    uint8x8x2_t s;
+    uint16x8x2_t sq;
+    s.val[0] = vld1_u8(src);
+    sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+    ptrdiff_t x = 0;
     do {
-      row[3] = vld1_u8(column);
-      column += stride;
-      row[4] = vld1_u8(column);
-      column += stride;
+      uint16x8_t row3, row5;
+      uint32x4x2_t row_sq3, row_sq5;
+      s.val[1] = vld1_u8(src + x + 8);
+      sq.val[1] = vmull_u8(s.val[1], s.val[1]);
+      SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5);
+      vst1q_u16(sum3, row3);
+      vst1q_u16(sum5, row5);
+      vst1q_u32(square_sum3 + 0, row_sq3.val[0]);
+      vst1q_u32(square_sum3 + 4, row_sq3.val[1]);
+      vst1q_u32(square_sum5 + 0, row_sq5.val[0]);
+      vst1q_u32(square_sum5 + 4, row_sq5.val[1]);
+      s.val[0] = s.val[1];
+      sq.val[0] = sq.val[1];
+      sum3 += 8;
+      sum5 += 8;
+      square_sum3 += 8;
+      square_sum5 += 8;
+      x += 8;
+    } while (x < width);
+    src += src_stride;
+  } while (--y != 0);
+}
 
-      row_sq[3] = vmull_u8(row[3], row[3]);
-      row_sq[4] = vmull_u8(row[4], row[4]);
+template <int size>
+inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
+                   const int height, const ptrdiff_t width, uint16_t* sums,
+                   uint32_t* square_sums) {
+  static_assert(size == 3 || size == 5, "");
+  int y = height;
+  do {
+    uint8x8x2_t s;
+    uint16x8x2_t sq;
+    s.val[0] = vld1_u8(src);
+    sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+    ptrdiff_t x = 0;
+    do {
+      uint16x8_t row;
+      uint32x4x2_t row_sq;
+      s.val[1] = vld1_u8(src + x + 8);
+      sq.val[1] = vmull_u8(s.val[1], s.val[1]);
+      if (size == 3) {
+        row = Sum3Horizontal(s);
+        row_sq = Sum3WHorizontal(sq);
+      } else {
+        row = Sum5Horizontal(s);
+        row_sq = Sum5WHorizontal(sq);
+      }
+      vst1q_u16(sums, row);
+      vst1q_u32(square_sums + 0, row_sq.val[0]);
+      vst1q_u32(square_sums + 4, row_sq.val[1]);
+      s.val[0] = s.val[1];
+      sq.val[0] = sq.val[1];
+      sums += 8;
+      square_sums += 8;
+      x += 8;
+    } while (x < width);
+    src += src_stride;
+  } while (--y != 0);
+}
 
-      const uint16x4_t sum = Sum5Horizontal(Sum5Vertical(row));
-      const uint32x4_t sum_sq = Sum5Horizontal(Sum5Vertical(row_sq));
+template <int n>
+inline void CalculateIntermediate(const uint16x8_t sum,
+                                  const uint32x4x2_t sum_sq,
+                                  const uint32_t scale, uint8x8_t* const ma,
+                                  uint16x8_t* const b) {
+  constexpr uint32_t one_over_n =
+      ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
+  const uint16x4_t z0 = CalculateMa<n>(vget_low_u16(sum), sum_sq.val[0], scale);
+  const uint16x4_t z1 =
+      CalculateMa<n>(vget_high_u16(sum), sum_sq.val[1], scale);
+  const uint16x8_t z01 = vcombine_u16(z0, z1);
+  // Using vqmovn_u16() needs an extra sign extension instruction.
+  const uint16x8_t z = vminq_u16(z01, vdupq_n_u16(255));
+  // Using vgetq_lane_s16() can save the sign extension instruction.
+  const uint8_t lookup[8] = {
+      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 0)],
+      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 1)],
+      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 2)],
+      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 3)],
+      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 4)],
+      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 5)],
+      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 6)],
+      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 7)]};
+  *ma = vld1_u8(lookup);
+  // b = ma * b * one_over_n
+  // |ma| = [0, 255]
+  // |sum| is a box sum with radius 1 or 2.
+  // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
+  // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
+  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
+  // When radius is 2 |n| is 25. |one_over_n| is 164.
+  // When radius is 1 |n| is 9. |one_over_n| is 455.
+  // |kSgrProjReciprocalBits| is 12.
+  // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
+  // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
+  const uint16x8_t maq = vmovl_u8(*ma);
+  const uint32x4_t m0 = vmull_u16(vget_low_u16(maq), vget_low_u16(sum));
+  const uint32x4_t m1 = vmull_u16(vget_high_u16(maq), vget_high_u16(sum));
+  const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n);
+  const uint32x4_t m3 = vmulq_n_u32(m1, one_over_n);
+  const uint16x4_t b_lo = vrshrn_n_u32(m2, kSgrProjReciprocalBits);
+  const uint16x4_t b_hi = vrshrn_n_u32(m3, kSgrProjReciprocalBits);
+  *b = vcombine_u16(b_lo, b_hi);
+}
 
-      const uint16x4_t a2 = CalculateA2<n>(sum_sq, sum, s, v_255);
-      const uint16x4_t b2 = CalculateB2Shifted(a2, sum, one_over_n);
+inline void CalculateIntermediate5(const uint16x8_t s5[5],
+                                   const uint32x4x2_t sq5[5],
+                                   const uint32_t scale, uint8x8_t* const ma,
+                                   uint16x8_t* const b) {
+  const uint16x8_t sum = Sum5_16(s5);
+  const uint32x4x2_t sum_sq = Sum5_32(sq5);
+  CalculateIntermediate<25>(sum, sum_sq, scale, ma, b);
+}
 
-      vst1q_u16(temp_column, vcombine_u16(a2, b2));
-      temp_column += 8;
+inline void CalculateIntermediate3(const uint16x8_t s3[3],
+                                   const uint32x4x2_t sq3[3],
+                                   const uint32_t scale, uint8x8_t* const ma,
+                                   uint16x8_t* const b) {
+  const uint16x8_t sum = Sum3_16(s3);
+  const uint32x4x2_t sum_sq = Sum3_32(sq3);
+  CalculateIntermediate<9>(sum, sum_sq, scale, ma, b);
+}
 
-      row[0] = row[2];
-      row[1] = row[3];
-      row[2] = row[4];
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
+    const uint8_t* const src, const ptrdiff_t src_stride, const ptrdiff_t x,
+    const uint32_t scale, uint8x8x2_t s[2], uint16x8x2_t sq[2],
+    uint16_t* const sum5[5], uint32_t* const square_sum5[5],
+    uint8x8_t* const ma, uint16x8_t* const b) {
+  uint16x8_t s5[5];
+  uint32x4x2_t sq5[5];
+  s[0].val[1] = vld1_u8(src + x + 8);
+  s[1].val[1] = vld1_u8(src + src_stride + x + 8);
+  sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]);
+  sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]);
+  s5[3] = Sum5Horizontal(s[0]);
+  s5[4] = Sum5Horizontal(s[1]);
+  sq5[3] = Sum5WHorizontal(sq[0]);
+  sq5[4] = Sum5WHorizontal(sq[1]);
+  vst1q_u16(sum5[3] + x, s5[3]);
+  vst1q_u16(sum5[4] + x, s5[4]);
+  vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
+  vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
+  vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
+  vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
+  s5[0] = vld1q_u16(sum5[0] + x);
+  s5[1] = vld1q_u16(sum5[1] + x);
+  s5[2] = vld1q_u16(sum5[2] + x);
+  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
+  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
+  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
+  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
+  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
+  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+  CalculateIntermediate5(s5, sq5, scale, ma, b);
+}
 
-      row_sq[0] = row_sq[2];
-      row_sq[1] = row_sq[3];
-      row_sq[2] = row_sq[4];
-      y += 2;
-    } while (y < height + 1);
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
+    const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
+    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint16_t* const sum5[5],
+    uint32_t* const square_sum5[5], uint8x8_t* const ma, uint16x8_t* const b) {
+  uint16x8_t s5[5];
+  uint32x4x2_t sq5[5];
+  s->val[1] = vld1_u8(src + x + 8);
+  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
+  s5[3] = s5[4] = Sum5Horizontal(*s);
+  sq5[3] = sq5[4] = Sum5WHorizontal(*sq);
+  s5[0] = vld1q_u16(sum5[0] + x);
+  s5[1] = vld1q_u16(sum5[1] + x);
+  s5[2] = vld1q_u16(sum5[2] + x);
+  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
+  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
+  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
+  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
+  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
+  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+  CalculateIntermediate5(s5, sq5, scale, ma, b);
+}
+
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
+    const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
+    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint16_t* const sum3[3],
+    uint32_t* const square_sum3[3], uint8x8_t* const ma, uint16x8_t* const b) {
+  uint16x8_t s3[3];
+  uint32x4x2_t sq3[3];
+  s->val[1] = vld1_u8(src + x + 8);
+  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
+  s3[2] = Sum3Horizontal(*s);
+  sq3[2] = Sum3WHorizontal(*sq);
+  vst1q_u16(sum3[2] + x, s3[2]);
+  vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
+  vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
+  s3[0] = vld1q_u16(sum3[0] + x);
+  s3[1] = vld1q_u16(sum3[1] + x);
+  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
+  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
+  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
+  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
+  CalculateIntermediate3(s3, sq3, scale, ma, b);
+}
+
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
+    const uint8_t* const src, const ptrdiff_t src_stride, const ptrdiff_t x,
+    const uint16_t scales[2], uint8x8x2_t s[2], uint16x8x2_t sq[2],
+    uint16_t* const sum3[4], uint16_t* const sum5[5],
+    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+    uint8x8_t* const ma3_0, uint8x8_t* const ma3_1, uint16x8_t* const b3_0,
+    uint16x8_t* const b3_1, uint8x8_t* const ma5, uint16x8_t* const b5) {
+  uint16x8_t s3[4], s5[5];
+  uint32x4x2_t sq3[4], sq5[5];
+  s[0].val[1] = vld1_u8(src + x + 8);
+  s[1].val[1] = vld1_u8(src + src_stride + x + 8);
+  sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]);
+  sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]);
+  SumHorizontal(s[0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]);
+  SumHorizontal(s[1], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]);
+  vst1q_u16(sum3[2] + x, s3[2]);
+  vst1q_u16(sum3[3] + x, s3[3]);
+  vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
+  vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
+  vst1q_u32(square_sum3[3] + x + 0, sq3[3].val[0]);
+  vst1q_u32(square_sum3[3] + x + 4, sq3[3].val[1]);
+  vst1q_u16(sum5[3] + x, s5[3]);
+  vst1q_u16(sum5[4] + x, s5[4]);
+  vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
+  vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
+  vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
+  vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
+  s3[0] = vld1q_u16(sum3[0] + x);
+  s3[1] = vld1q_u16(sum3[1] + x);
+  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
+  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
+  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
+  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
+  s5[0] = vld1q_u16(sum5[0] + x);
+  s5[1] = vld1q_u16(sum5[1] + x);
+  s5[2] = vld1q_u16(sum5[2] + x);
+  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
+  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
+  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
+  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
+  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
+  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+  CalculateIntermediate3(s3, sq3, scales[1], ma3_0, b3_0);
+  CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], ma3_1, b3_1);
+  CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
+}
+
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
+    const uint8_t* const src, const ptrdiff_t x, const uint16_t scales[2],
+    const uint16_t* const sum3[4], const uint16_t* const sum5[5],
+    const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
+    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma3,
+    uint8x8_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) {
+  uint16x8_t s3[3], s5[5];
+  uint32x4x2_t sq3[3], sq5[5];
+  s->val[1] = vld1_u8(src + x + 8);
+  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
+  SumHorizontal(*s, *sq, &s3[2], &s5[3], &sq3[2], &sq5[3]);
+  s5[0] = vld1q_u16(sum5[0] + x);
+  s5[1] = vld1q_u16(sum5[1] + x);
+  s5[2] = vld1q_u16(sum5[2] + x);
+  s5[4] = s5[3];
+  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
+  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
+  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
+  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
+  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
+  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+  sq5[4] = sq5[3];
+  CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
+  s3[0] = vld1q_u16(sum3[0] + x);
+  s3[1] = vld1q_u16(sum3[1] + x);
+  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
+  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
+  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
+  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
+  CalculateIntermediate3(s3, sq3, scales[1], ma3, b3);
+}
+
+inline void BoxSumFilterPreProcess5(const uint8_t* const src,
+                                    const ptrdiff_t src_stride, const int width,
+                                    const uint32_t scale,
+                                    uint16_t* const sum5[5],
+                                    uint32_t* const square_sum5[5],
+                                    uint16_t* ma565, uint32_t* b565) {
+  uint8x8x2_t s[2], mas;
+  uint16x8x2_t sq[2], bs;
+  s[0].val[0] = vld1_u8(src);
+  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
+  s[1].val[0] = vld1_u8(src + src_stride);
+  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
+  BoxFilterPreProcess5(src, src_stride, 0, scale, s, sq, sum5, square_sum5,
+                       &mas.val[0], &bs.val[0]);
+
+  int x = 0;
+  do {
+    s[0].val[0] = s[0].val[1];
+    s[1].val[0] = s[1].val[1];
+    sq[0].val[0] = sq[0].val[1];
+    sq[1].val[0] = sq[1].val[1];
+    BoxFilterPreProcess5(src, src_stride, x + 8, scale, s, sq, sum5,
+                         square_sum5, &mas.val[1], &bs.val[1]);
+    const uint16x8_t ma = Sum565(mas);
+    const uint32x4x2_t b = Sum565W(bs);
+    vst1q_u16(ma565, ma);
+    vst1q_u32(b565 + 0, b.val[0]);
+    vst1q_u32(b565 + 4, b.val[1]);
+    mas.val[0] = mas.val[1];
+    bs.val[0] = bs.val[1];
+    ma565 += 8;
+    b565 += 8;
+    x += 8;
+  } while (x < width);
+}
+
+template <bool calculate444>
+LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
+    const uint8_t* const src, const int width, const uint32_t scale,
+    uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343,
+    uint16_t* ma444, uint32_t* b343, uint32_t* b444) {
+  uint8x8x2_t s, mas;
+  uint16x8x2_t sq, bs;
+  s.val[0] = vld1_u8(src);
+  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+  BoxFilterPreProcess3(src, 0, scale, &s, &sq, sum3, square_sum3, &mas.val[0],
+                       &bs.val[0]);
+
+  int x = 0;
+  do {
+    s.val[0] = s.val[1];
+    sq.val[0] = sq.val[1];
+    BoxFilterPreProcess3(src, x + 8, scale, &s, &sq, sum3, square_sum3,
+                         &mas.val[1], &bs.val[1]);
+    if (calculate444) {
+      Store343_444(mas, bs, 0, ma343, ma444, b343, b444);
+      ma444 += 8;
+      b444 += 8;
+    } else {
+      const uint16x8_t ma = Sum343(mas);
+      const uint32x4x2_t b = Sum343W(bs);
+      vst1q_u16(ma343, ma);
+      vst1q_u32(b343 + 0, b.val[0]);
+      vst1q_u32(b343 + 4, b.val[1]);
+    }
+    mas.val[0] = mas.val[1];
+    bs.val[0] = bs.val[1];
+    ma343 += 8;
+    b343 += 8;
+    x += 8;
+  } while (x < width);
+}
+
+inline void BoxSumFilterPreProcess(
+    const uint8_t* const src, const ptrdiff_t src_stride, const int width,
+    const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
+    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+    uint16_t* const ma343[4], uint16_t* const ma444[2], uint16_t* ma565,
+    uint32_t* const b343[4], uint32_t* const b444[2], uint32_t* b565) {
+  uint8x8x2_t s[2];
+  uint8x8x2_t ma3[2], ma5;
+  uint16x8x2_t sq[2], b3[2], b5;
+  s[0].val[0] = vld1_u8(src + 0);
+  s[1].val[0] = vld1_u8(src + src_stride + 0);
+  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
+  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
+  BoxFilterPreProcess(src, src_stride, 0, scales, s, sq, sum3, sum5,
+                      square_sum3, square_sum5, &ma3[0].val[0], &ma3[1].val[0],
+                      &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);
+
+  int x = 0;
+  do {
+    s[0].val[0] = s[0].val[1];
+    s[1].val[0] = s[1].val[1];
+    sq[0].val[0] = sq[0].val[1];
+    sq[1].val[0] = sq[1].val[1];
+    BoxFilterPreProcess(src, src_stride, x + 8, scales, s, sq, sum3, sum5,
+                        square_sum3, square_sum5, &ma3[0].val[1],
+                        &ma3[1].val[1], &b3[0].val[1], &b3[1].val[1],
+                        &ma5.val[1], &b5.val[1]);
+    uint16x8_t ma = Sum343(ma3[0]);
+    uint32x4x2_t b = Sum343W(b3[0]);
+    vst1q_u16(ma343[0] + x, ma);
+    vst1q_u32(b343[0] + x, b.val[0]);
+    vst1q_u32(b343[0] + x + 4, b.val[1]);
+    Store343_444(ma3[1], b3[1], x, ma343[1], ma444[0], b343[1], b444[0]);
+    ma = Sum565(ma5);
+    b = Sum565W(b5);
+    vst1q_u16(ma565, ma);
+    vst1q_u32(b565 + 0, b.val[0]);
+    vst1q_u32(b565 + 4, b.val[1]);
+    ma3[0].val[0] = ma3[0].val[1];
+    ma3[1].val[0] = ma3[1].val[1];
+    b3[0].val[0] = b3[0].val[1];
+    b3[1].val[0] = b3[1].val[1];
+    ma5.val[0] = ma5.val[1];
+    b5.val[0] = b5.val[1];
+    ma565 += 8;
+    b565 += 8;
+    x += 8;
+  } while (x < width);
+}
+
+inline void BoxFilterPass1(const uint8_t* const src0, const uint8_t* const src,
+                           const ptrdiff_t src_stride, uint16_t* const sum5[5],
+                           uint32_t* const square_sum5[5], const int width,
+                           const uint32_t scale, const int16_t w0,
+                           uint16_t* const ma565[2], uint32_t* const b565[2],
+                           uint8_t* const dst, const ptrdiff_t dst_stride) {
+  uint8x8x2_t s[2], mas;
+  uint16x8x2_t sq[2], bs;
+  s[0].val[0] = vld1_u8(src);
+  s[1].val[0] = vld1_u8(src + src_stride);
+  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
+  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
+  BoxFilterPreProcess5(src, src_stride, 0, scale, s, sq, sum5, square_sum5,
+                       &mas.val[0], &bs.val[0]);
+
+  int x = 0;
+  do {
+    s[0].val[0] = s[0].val[1];
+    s[1].val[0] = s[1].val[1];
+    sq[0].val[0] = sq[0].val[1];
+    sq[1].val[0] = sq[1].val[1];
+    BoxFilterPreProcess5(src, src_stride, x + 8, scale, s, sq, sum5,
+                         square_sum5, &mas.val[1], &bs.val[1]);
+    uint16x8_t ma[2];
+    uint32x4x2_t b[2];
+    ma[1] = Sum565(mas);
+    b[1] = Sum565W(bs);
+    vst1q_u16(ma565[1] + x, ma[1]);
+    vst1q_u32(b565[1] + x + 0, b[1].val[0]);
+    vst1q_u32(b565[1] + x + 4, b[1].val[1]);
+    const uint8x8_t s0 = vld1_u8(src0 + x);
+    const uint8x8_t s1 = vld1_u8(src0 + src_stride + x);
+    int16x8_t p0, p1;
+    ma[0] = vld1q_u16(ma565[0] + x);
+    b[0].val[0] = vld1q_u32(b565[0] + x + 0);
+    b[0].val[1] = vld1q_u32(b565[0] + x + 4);
+    p0 = CalculateFilteredOutputPass1(s0, ma, b);
+    p1 = CalculateFilteredOutput<4>(s1, ma[1], b[1]);
+    SelfGuidedSingleMultiplier(s0, p0, w0, dst + x);
+    SelfGuidedSingleMultiplier(s1, p1, w0, dst + dst_stride + x);
+    mas.val[0] = mas.val[1];
+    bs.val[0] = bs.val[1];
+    x += 8;
+  } while (x < width);
+}
+
+inline void BoxFilterPass1LastRow(const uint8_t* const src0,
+                                  const uint8_t* const src, const int width,
+                                  const uint32_t scale, const int16_t w0,
+                                  uint16_t* const sum5[5],
+                                  uint32_t* const square_sum5[5],
+                                  uint16_t* ma565, uint32_t* b565,
+                                  uint8_t* const dst) {
+  uint8x8x2_t s, mas;
+  uint16x8x2_t sq, bs;
+  s.val[0] = vld1_u8(src);
+  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+  BoxFilterPreProcess5LastRow(src, 0, scale, &s, &sq, sum5, square_sum5,
+                              &mas.val[0], &bs.val[0]);
+
+  int x = 0;
+  do {
+    s.val[0] = s.val[1];
+    sq.val[0] = sq.val[1];
+    BoxFilterPreProcess5LastRow(src, x + 8, scale, &s, &sq, sum5, square_sum5,
+                                &mas.val[1], &bs.val[1]);
+    uint16x8_t ma[2];
+    uint32x4x2_t b[2];
+    ma[1] = Sum565(mas);
+    b[1] = Sum565W(bs);
+    mas.val[0] = mas.val[1];
+    bs.val[0] = bs.val[1];
+    ma[0] = vld1q_u16(ma565);
+    b[0].val[0] = vld1q_u32(b565 + 0);
+    b[0].val[1] = vld1q_u32(b565 + 4);
+    const uint8x8_t s = vld1_u8(src0 + x);
+    const int16x8_t p = CalculateFilteredOutputPass1(s, ma, b);
+    SelfGuidedSingleMultiplier(s, p, w0, dst + x);
+    ma565 += 8;
+    b565 += 8;
+    x += 8;
+  } while (x < width);
+}
+
+inline void BoxFilterPass2(const uint8_t* const src0, const uint8_t* const src,
+                           const int width, const uint32_t scale,
+                           const int16_t w0, uint16_t* const sum3[3],
+                           uint32_t* const square_sum3[3],
+                           uint16_t* const ma343[3], uint16_t* const ma444[2],
+                           uint32_t* const b343[3], uint32_t* const b444[2],
+                           uint8_t* const dst) {
+  uint8x8x2_t s, mas;
+  uint16x8x2_t sq, bs;
+  s.val[0] = vld1_u8(src);
+  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+  BoxFilterPreProcess3(src, 0, scale, &s, &sq, sum3, square_sum3, &mas.val[0],
+                       &bs.val[0]);
+
+  int x = 0;
+  do {
+    s.val[0] = s.val[1];
+    sq.val[0] = sq.val[1];
+    BoxFilterPreProcess3(src, x + 8, scale, &s, &sq, sum3, square_sum3,
+                         &mas.val[1], &bs.val[1]);
+    uint16x8_t ma[3];
+    uint32x4x2_t b[3];
+    Store343_444(mas, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2],
+                 b444[1]);
+    const uint8x8_t s0 = vld1_u8(src0 + x);
+    ma[0] = vld1q_u16(ma343[0] + x);
+    ma[1] = vld1q_u16(ma444[0] + x);
+    b[0].val[0] = vld1q_u32(b343[0] + x + 0);
+    b[0].val[1] = vld1q_u32(b343[0] + x + 4);
+    b[1].val[0] = vld1q_u32(b444[0] + x + 0);
+    b[1].val[1] = vld1q_u32(b444[0] + x + 4);
+    const int16x8_t p = CalculateFilteredOutputPass2(s0, ma, b);
+    SelfGuidedSingleMultiplier(s0, p, w0, dst + x);
+    mas.val[0] = mas.val[1];
+    bs.val[0] = bs.val[1];
+    x += 8;
+  } while (x < width);
+}
+
+inline void BoxFilter(const uint8_t* const src0, const uint8_t* const src,
+                      const ptrdiff_t src_stride, const int width,
+                      const uint16_t scales[2], const int16_t w0,
+                      const int16_t w2, uint16_t* const sum3[4],
+                      uint16_t* const sum5[5], uint32_t* const square_sum3[4],
+                      uint32_t* const square_sum5[5], uint16_t* const ma343[4],
+                      uint16_t* const ma444[3], uint16_t* const ma565[2],
+                      uint32_t* const b343[4], uint32_t* const b444[3],
+                      uint32_t* const b565[2], uint8_t* const dst,
+                      const ptrdiff_t dst_stride) {
+  uint8x8x2_t s[2], ma3[2], ma5;
+  uint16x8x2_t sq[2], b3[2], b5;
+  s[0].val[0] = vld1_u8(src);
+  s[1].val[0] = vld1_u8(src + src_stride);
+  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
+  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
+  BoxFilterPreProcess(src, src_stride, 0, scales, s, sq, sum3, sum5,
+                      square_sum3, square_sum5, &ma3[0].val[0], &ma3[1].val[0],
+                      &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);
+
+  int x = 0;
+  do {
+    s[0].val[0] = s[0].val[1];
+    s[1].val[0] = s[1].val[1];
+    sq[0].val[0] = sq[0].val[1];
+    sq[1].val[0] = sq[1].val[1];
+    BoxFilterPreProcess(src, src_stride, x + 8, scales, s, sq, sum3, sum5,
+                        square_sum3, square_sum5, &ma3[0].val[1],
+                        &ma3[1].val[1], &b3[0].val[1], &b3[1].val[1],
+                        &ma5.val[1], &b5.val[1]);
+    uint16x8_t ma[3][3];
+    uint32x4x2_t b[3][3];
+    Store343_444(ma3[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1],
+                 ma343[2], ma444[1], b343[2], b444[1]);
+    Store343_444(ma3[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2],
+                 b343[3], b444[2]);
+    ma[0][1] = Sum565(ma5);
+    b[0][1] = Sum565W(b5);
+    vst1q_u16(ma565[1] + x, ma[0][1]);
+    vst1q_u32(b565[1] + x, b[0][1].val[0]);
+    vst1q_u32(b565[1] + x + 4, b[0][1].val[1]);
+    s[0].val[0] = s[0].val[1];
+    s[1].val[0] = s[1].val[1];
+    sq[0].val[0] = sq[0].val[1];
+    sq[1].val[0] = sq[1].val[1];
+    ma3[0].val[0] = ma3[0].val[1];
+    ma3[1].val[0] = ma3[1].val[1];
+    b3[0].val[0] = b3[0].val[1];
+    b3[1].val[0] = b3[1].val[1];
+    ma5.val[0] = ma5.val[1];
+    b5.val[0] = b5.val[1];
+    int16x8_t p[2][2];
+    const uint8x8_t s0 = vld1_u8(src0 + x);
+    const uint8x8_t s1 = vld1_u8(src0 + src_stride + x);
+    ma[0][0] = vld1q_u16(ma565[0] + x);
+    b[0][0].val[0] = vld1q_u32(b565[0] + x);
+    b[0][0].val[1] = vld1q_u32(b565[0] + x + 4);
+    p[0][0] = CalculateFilteredOutputPass1(s0, ma[0], b[0]);
+    p[1][0] = CalculateFilteredOutput<4>(s1, ma[0][1], b[0][1]);
+    ma[1][0] = vld1q_u16(ma343[0] + x);
+    ma[1][1] = vld1q_u16(ma444[0] + x);
+    b[1][0].val[0] = vld1q_u32(b343[0] + x);
+    b[1][0].val[1] = vld1q_u32(b343[0] + x + 4);
+    b[1][1].val[0] = vld1q_u32(b444[0] + x);
+    b[1][1].val[1] = vld1q_u32(b444[0] + x + 4);
+    p[0][1] = CalculateFilteredOutputPass2(s0, ma[1], b[1]);
+    ma[2][0] = vld1q_u16(ma343[1] + x);
+    b[2][0].val[0] = vld1q_u32(b343[1] + x);
+    b[2][0].val[1] = vld1q_u32(b343[1] + x + 4);
+    p[1][1] = CalculateFilteredOutputPass2(s1, ma[2], b[2]);
+    SelfGuidedDoubleMultiplier(s0, p[0], w0, w2, dst + x);
+    SelfGuidedDoubleMultiplier(s1, p[1], w0, w2, dst + dst_stride + x);
+    x += 8;
+  } while (x < width);
+}
+
+inline void BoxFilterLastRow(
+    const uint8_t* const src0, const uint8_t* const src, const int width,
+    const uint16_t scales[2], const int16_t w0, const int16_t w2,
+    uint16_t* const sum3[4], uint16_t* const sum5[5],
+    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+    uint16_t* const ma343[4], uint16_t* const ma444[3],
+    uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
+    uint32_t* const b565[2], uint8_t* const dst) {
+  uint8x8x2_t s, ma3, ma5;
+  uint16x8x2_t sq, b3, b5;
+  uint16x8_t ma[3];
+  uint32x4x2_t b[3];
+  s.val[0] = vld1_u8(src);
+  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
+  BoxFilterPreProcessLastRow(src, 0, scales, sum3, sum5, square_sum3,
+                             square_sum5, &s, &sq, &ma3.val[0], &ma5.val[0],
+                             &b3.val[0], &b5.val[0]);
+
+  int x = 0;
+  do {
+    s.val[0] = s.val[1];
+    sq.val[0] = sq.val[1];
+    BoxFilterPreProcessLastRow(src, x + 8, scales, sum3, sum5, square_sum3,
+                               square_sum5, &s, &sq, &ma3.val[1], &ma5.val[1],
+                               &b3.val[1], &b5.val[1]);
+    ma[1] = Sum565(ma5);
+    b[1] = Sum565W(b5);
+    ma5.val[0] = ma5.val[1];
+    b5.val[0] = b5.val[1];
+    ma[2] = Sum343(ma3);
+    b[2] = Sum343W(b3);
+    ma3.val[0] = ma3.val[1];
+    b3.val[0] = b3.val[1];
+    const uint8x8_t s0 = vld1_u8(src0 + x);
+    int16x8_t p[2];
+    ma[0] = vld1q_u16(ma565[0] + x);
+    b[0].val[0] = vld1q_u32(b565[0] + x + 0);
+    b[0].val[1] = vld1q_u32(b565[0] + x + 4);
+    p[0] = CalculateFilteredOutputPass1(s0, ma, b);
+    ma[0] = vld1q_u16(ma343[0] + x);
+    ma[1] = vld1q_u16(ma444[0] + x);
+    b[0].val[0] = vld1q_u32(b343[0] + x + 0);
+    b[0].val[1] = vld1q_u32(b343[0] + x + 4);
+    b[1].val[0] = vld1q_u32(b444[0] + x + 0);
+    b[1].val[1] = vld1q_u32(b444[0] + x + 4);
+    p[1] = CalculateFilteredOutputPass2(s0, ma, b);
+    SelfGuidedDoubleMultiplier(s0, p, w0, w2, dst + x);
+    x += 8;
+  } while (x < width);
+}
+
+template <typename T>
+void Circulate3PointersBy1(T* p[3]) {
+  T* const p0 = p[0];
+  p[0] = p[1];
+  p[1] = p[2];
+  p[2] = p0;
+}
+
+template <typename T>
+void Circulate4PointersBy2(T* p[4]) {
+  std::swap(p[0], p[2]);
+  std::swap(p[1], p[3]);
+}
+
+template <typename T>
+void Circulate5PointersBy2(T* p[5]) {
+  T* const p0 = p[0];
+  T* const p1 = p[1];
+  p[0] = p[2];
+  p[1] = p[3];
+  p[2] = p[4];
+  p[3] = p0;
+  p[4] = p1;
+}
+
+inline void BoxFilterProcess(const RestorationUnitInfo& restoration_info,
+                             const uint8_t* src, const ptrdiff_t src_stride,
+                             const int width, const int height,
+                             SgrBuffer* const sgr_buffer, uint8_t* dst,
+                             const ptrdiff_t dst_stride) {
+  const auto temp_stride = Align<ptrdiff_t>(width, 8);
+  const ptrdiff_t sum_stride = temp_stride + 8;
+  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+  const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
+  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
+  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
+  const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
+  uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
+  uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
+  sum3[0] = sgr_buffer->sum3;
+  square_sum3[0] = sgr_buffer->square_sum3;
+  ma343[0] = sgr_buffer->ma343;
+  b343[0] = sgr_buffer->b343;
+  for (int i = 1; i <= 3; ++i) {
+    sum3[i] = sum3[i - 1] + sum_stride;
+    square_sum3[i] = square_sum3[i - 1] + sum_stride;
+    ma343[i] = ma343[i - 1] + temp_stride;
+    b343[i] = b343[i - 1] + temp_stride;
   }
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+  for (int i = 1; i <= 4; ++i) {
+    sum5[i] = sum5[i - 1] + sum_stride;
+    square_sum5[i] = square_sum5[i - 1] + sum_stride;
+  }
+  ma444[0] = sgr_buffer->ma444;
+  b444[0] = sgr_buffer->b444;
+  for (int i = 1; i <= 2; ++i) {
+    ma444[i] = ma444[i - 1] + temp_stride;
+    b444[i] = b444[i - 1] + temp_stride;
+  }
+  ma565[0] = sgr_buffer->ma565;
+  ma565[1] = ma565[0] + temp_stride;
+  b565[0] = sgr_buffer->b565;
+  b565[1] = b565[0] + temp_stride;
+  assert(scales[0] != 0);
+  assert(scales[1] != 0);
+  BoxSum(src - 2 * src_stride - 3, src_stride, 2, sum_stride, sum3[0], sum5[1],
+         square_sum3[0], square_sum5[1]);
+  sum5[0] = sum5[1];
+  square_sum5[0] = square_sum5[1];
+  BoxSumFilterPreProcess(src - 3, src_stride, width, scales, sum3, sum5,
+                         square_sum3, square_sum5, ma343, ma444, ma565[0], b343,
+                         b444, b565[0]);
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+  for (int y = height >> 1; y != 0; --y) {
+    Circulate4PointersBy2<uint16_t>(sum3);
+    Circulate4PointersBy2<uint32_t>(square_sum3);
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxFilter(src, src + 2 * src_stride - 3, src_stride, width, scales, w0, w2,
+              sum3, sum5, square_sum3, square_sum5, ma343, ma444, ma565, b343,
+              b444, b565, dst, dst_stride);
+    src += 2 * src_stride;
+    dst += 2 * dst_stride;
+    Circulate4PointersBy2<uint16_t>(ma343);
+    Circulate4PointersBy2<uint32_t>(b343);
+    std::swap(ma444[0], ma444[2]);
+    std::swap(b444[0], b444[2]);
+    std::swap(ma565[0], ma565[1]);
+    std::swap(b565[0], b565[1]);
+  }
+  if ((height & 1) != 0) {
+    Circulate4PointersBy2<uint16_t>(sum3);
+    Circulate4PointersBy2<uint32_t>(square_sum3);
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxFilterLastRow(src, src + 2 * src_stride - 3, width, scales, w0, w2, sum3,
+                     sum5, square_sum3, square_sum5, ma343, ma444, ma565, b343,
+                     b444, b565, dst);
+  }
+}
 
-  int x = 0;
+inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
+                                  const uint8_t* src,
+                                  const ptrdiff_t src_stride, const int width,
+                                  const int height, SgrBuffer* const sgr_buffer,
+                                  uint8_t* dst, const ptrdiff_t dst_stride) {
+  const auto temp_stride = Align<ptrdiff_t>(width, 8);
+  const ptrdiff_t sum_stride = temp_stride + 8;
+  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
+  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
+  uint16_t *sum5[5], *ma565[2];
+  uint32_t *square_sum5[5], *b565[2];
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+  for (int i = 1; i <= 4; ++i) {
+    sum5[i] = sum5[i - 1] + sum_stride;
+    square_sum5[i] = square_sum5[i - 1] + sum_stride;
+  }
+  ma565[0] = sgr_buffer->ma565;
+  ma565[1] = ma565[0] + temp_stride;
+  b565[0] = sgr_buffer->b565;
+  b565[1] = b565[0] + temp_stride;
+  assert(scale != 0);
+  BoxSum<5>(src - 2 * src_stride - 3, src_stride, 2, sum_stride, sum5[1],
+            square_sum5[1]);
+  sum5[0] = sum5[1];
+  square_sum5[0] = square_sum5[1];
+  BoxSumFilterPreProcess5(src - 3, src_stride, width, scale, sum5, square_sum5,
+                          ma565[0], b565[0]);
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+  for (int y = height >> 1; y != 0; --y) {
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxFilterPass1(src, src + 2 * src_stride - 3, src_stride, sum5, square_sum5,
+                   width, scale, w0, ma565, b565, dst, dst_stride);
+    src += 2 * src_stride;
+    dst += 2 * dst_stride;
+    std::swap(ma565[0], ma565[1]);
+    std::swap(b565[0], b565[1]);
+  }
+  if ((height & 1) != 0) {
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxFilterPass1LastRow(src, src + 2 * src_stride - 3, width, scale, w0, sum5,
+                          square_sum5, ma565[0], b565[0], dst);
+  }
+}
+
+inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
+                                  const uint8_t* src,
+                                  const ptrdiff_t src_stride, const int width,
+                                  const int height, SgrBuffer* const sgr_buffer,
+                                  uint8_t* dst, const ptrdiff_t dst_stride) {
+  assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
+  const auto temp_stride = Align<ptrdiff_t>(width, 8);
+  const ptrdiff_t sum_stride = temp_stride + 8;
+  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
+  const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
+  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
+  uint16_t *sum3[3], *ma343[3], *ma444[2];
+  uint32_t *square_sum3[3], *b343[3], *b444[2];
+  sum3[0] = sgr_buffer->sum3;
+  square_sum3[0] = sgr_buffer->square_sum3;
+  ma343[0] = sgr_buffer->ma343;
+  b343[0] = sgr_buffer->b343;
+  for (int i = 1; i <= 2; ++i) {
+    sum3[i] = sum3[i - 1] + sum_stride;
+    square_sum3[i] = square_sum3[i - 1] + sum_stride;
+    ma343[i] = ma343[i - 1] + temp_stride;
+    b343[i] = b343[i - 1] + temp_stride;
+  }
+  ma444[0] = sgr_buffer->ma444;
+  ma444[1] = ma444[0] + temp_stride;
+  b444[0] = sgr_buffer->b444;
+  b444[1] = b444[0] + temp_stride;
+  assert(scale != 0);
+  BoxSum<3>(src - 2 * src_stride - 2, src_stride, 2, sum_stride, sum3[0],
+            square_sum3[0]);
+  BoxSumFilterPreProcess3<false>(src - 2, width, scale, sum3, square_sum3,
+                                 ma343[0], nullptr, b343[0], nullptr);
+  Circulate3PointersBy1<uint16_t>(sum3);
+  Circulate3PointersBy1<uint32_t>(square_sum3);
+  BoxSumFilterPreProcess3<true>(src + src_stride - 2, width, scale, sum3,
+                                square_sum3, ma343[1], ma444[0], b343[1],
+                                b444[0]);
+
+  int y = height;
   do {
-    // |src_pre_process| is X but we already processed the first column of 4
-    // values so we want to start at Y and increment from there.
-    // X s s s Y s s
-    // s s s s s s s
-    // s s i i i i i
-    // s s i o o o o
-    // s s i o o o o
-    const uint8_t* column = src_pre_process + x + 4;
-
-    uint8x8_t row[5];
-    row[0] = vld1_u8(column);
-    column += stride;
-    row[1] = vld1_u8(column);
-    column += stride;
-    row[2] = vld1_u8(column);
-    column += stride;
-    row[3] = vld1_u8(column);
-    column += stride;
-    row[4] = vld1_u8(column);
-    column += stride;
-
-    uint16x8_t row_sq[5];
-    row_sq[0] = vmull_u8(row[0], row[0]);
-    row_sq[1] = vmull_u8(row[1], row[1]);
-    row_sq[2] = vmull_u8(row[2], row[2]);
-    row_sq[3] = vmull_u8(row[3], row[3]);
-    row_sq[4] = vmull_u8(row[4], row[4]);
-
-    // Seed the loop with one line of output. Then, inside the loop, for each
-    // iteration we can output one even row and one odd row and carry the new
-    // line to the next iteration. In the diagram below 'i' values are
-    // intermediary values from the first step and '-' values are empty.
-    // iiii
-    // ---- > even row
-    // iiii - odd row
-    // ---- > even row
-    // iiii
-    uint16_t* temp_column = temp;
-    uint16x4_t sum565_a0;
-    uint32x4_t sum565_b0;
-    {
-      const uint16x4_t sum = Sum5Horizontal(Sum5Vertical(row));
-      const uint32x4_t sum_sq = Sum5Horizontal(Sum5Vertical(row_sq));
-
-      const uint16x4_t a2 = CalculateA2<n>(sum_sq, sum, s, v_255);
-      const uint16x4_t b2 = CalculateB2Shifted(a2, sum, one_over_n);
-
-      // Exchange the previously calculated |a2| and |b2| values.
-      const uint16x8_t a2_b2 = vld1q_u16(temp_column);
-      vst1q_u16(temp_column, vcombine_u16(a2, b2));
-      temp_column += 8;
-
-      // Pass 1 Process. These are the only values we need to propagate between
-      // rows.
-      sum565_a0 = Sum565(vcombine_u16(vget_low_u16(a2_b2), a2));
-      sum565_b0 = Sum565W(vcombine_u16(vget_high_u16(a2_b2), b2));
-    }
-
-    row[0] = row[2];
-    row[1] = row[3];
-    row[2] = row[4];
-
-    row_sq[0] = row_sq[2];
-    row_sq[1] = row_sq[3];
-    row_sq[2] = row_sq[4];
-
-    const uint8_t* src_ptr = src + x;
-    uint16_t* out_buf = buf + x;
-
-    // Calculate one output line. Add in the line from the previous pass and
-    // output one even row. Sum the new line and output the odd row. Carry the
-    // new row into the next pass.
-    int y = 0;
-    do {
-      row[3] = vld1_u8(column);
-      column += stride;
-      row[4] = vld1_u8(column);
-      column += stride;
-
-      row_sq[3] = vmull_u8(row[3], row[3]);
-      row_sq[4] = vmull_u8(row[4], row[4]);
-
-      const uint16x4_t sum = Sum5Horizontal(Sum5Vertical(row));
-      const uint32x4_t sum_sq = Sum5Horizontal(Sum5Vertical(row_sq));
-
-      const uint16x4_t a2 = CalculateA2<n>(sum_sq, sum, s, v_255);
-      const uint16x4_t b2 = CalculateB2Shifted(a2, sum, one_over_n);
-
-      const uint16x8_t a2_b2 = vld1q_u16(temp_column);
-      vst1q_u16(temp_column, vcombine_u16(a2, b2));
-      temp_column += 8;
-
-      uint16x4_t sum565_a1 = Sum565(vcombine_u16(vget_low_u16(a2_b2), a2));
-      uint32x4_t sum565_b1 = Sum565W(vcombine_u16(vget_high_u16(a2_b2), b2));
-
-      const uint8x8_t src_u8 = vld1_u8(src_ptr);
-      src_ptr += stride;
-      const uint16x4_t src_u16 = vget_low_u16(vmovl_u8(src_u8));
-
-      const uint16x4_t output =
-          CalculateFilteredOutput<5>(vadd_u16(sum565_a0, sum565_a1),
-                                     vaddq_u32(sum565_b0, sum565_b1), src_u16);
-
-      vst1_u16(out_buf, output);
-      out_buf += kIntermediateStride;
-
-      const uint8x8_t src0_u8 = vld1_u8(src_ptr);
-      src_ptr += stride;
-      const uint16x4_t src0_u16 = vget_low_u16(vmovl_u8(src0_u8));
-
-      const uint16x4_t output1 =
-          CalculateFilteredOutput<4>(sum565_a1, sum565_b1, src0_u16);
-      vst1_u16(out_buf, output1);
-      out_buf += kIntermediateStride;
-
-      row[0] = row[2];
-      row[1] = row[3];
-      row[2] = row[4];
-
-      row_sq[0] = row_sq[2];
-      row_sq[1] = row_sq[3];
-      row_sq[2] = row_sq[4];
-
-      sum565_a0 = sum565_a1;
-      sum565_b0 = sum565_b1;
-      y += 2;
-    } while (y < height);
-    x += 4;
-  } while (x < width);
-}
-
-inline void BoxFilterPreProcess_SecondPass(const uint8_t* const src,
-                                           const ptrdiff_t stride,
-                                           const int width, const int height,
-                                           const uint32_t s,
-                                           uint16_t* const a2) {
-  // Number of elements in the box being summed.
-  const uint32_t n = 9;
-  const uint32_t one_over_n = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
-
-  const uint16x4_t v_255 = vdup_n_u16(255);
-
-  // Calculate intermediate results, including one-pixel border, for example,
-  // if unit size is 64x64, we calculate 66x66 pixels.
-  // Because of the vectors this calculates in blocks of 4 so we actually
-  // get 68 values. This doesn't appear to be causing problems yet but it
-  // might.
-  const uint8_t* const src_top_left_corner = src - 1 - 2 * stride;
-  int x = -1;
-  do {
-    const uint8_t* column = src_top_left_corner + x;
-    uint16_t* a2_column = a2 + (x + 1);
-    uint8x8_t row[3];
-    row[0] = vld1_u8(column);
-    column += stride;
-    row[1] = vld1_u8(column);
-    column += stride;
-
-    uint16x8_t row_sq[3];
-    row_sq[0] = vmull_u8(row[0], row[0]);
-    row_sq[1] = vmull_u8(row[1], row[1]);
-
-    int y = -1;
-    do {
-      row[2] = vld1_u8(column);
-      column += stride;
-
-      row_sq[2] = vmull_u8(row[2], row[2]);
-
-      const uint16x4_t sum = Sum3Horizontal(Sum3Vertical(row));
-      const uint32x4_t sum_sq = Sum3Horizontal(Sum3Vertical(row_sq));
-
-      const uint16x4_t a2_v = CalculateA2<n>(sum_sq, sum, s, v_255);
-
-      vst1_u16(a2_column, a2_v);
-      a2_column += kIntermediateStride;
-
-      vst1_u16(a2_column, CalculateB2Shifted(a2_v, sum, one_over_n));
-      a2_column += kIntermediateStride;
-
-      row[0] = row[1];
-      row[1] = row[2];
-
-      row_sq[0] = row_sq[1];
-      row_sq[1] = row_sq[2];
-    } while (++y < height + 1);
-    x += 4;
-  } while (x < width + 1);
-}
-
-inline uint16x4_t Sum444(const uint16x8_t a) {
-  return Sum3Horizontal(vshlq_n_u16(a, 2));
-}
-
-inline uint32x4_t Sum444W(const uint16x8_t a) {
-  return Sum3HorizontalW(vshll_n_u16(vget_low_u16(a), 2),
-                         vshll_n_u16(vget_high_u16(a), 2));
-}
-
-inline uint16x4_t Sum343(const uint16x8_t a) {
-  const uint16x4_t middle = vext_u16(vget_low_u16(a), vget_high_u16(a), 1);
-  const uint16x4_t sum = Sum3Horizontal(a);
-  return vadd_u16(vadd_u16(vadd_u16(sum, sum), sum), middle);
-}
-
-inline uint32x4_t Sum343W(const uint16x8_t a) {
-  const uint16x4_t middle = vext_u16(vget_low_u16(a), vget_high_u16(a), 1);
-  const uint32x4_t sum =
-      Sum3HorizontalW(vmovl_u16(vget_low_u16(a)), vmovl_u16(vget_high_u16(a)));
-  return vaddw_u16(vaddq_u32(vaddq_u32(sum, sum), sum), middle);
-}
-
-inline void BoxFilterProcess_SecondPass(const uint8_t* src,
-                                        const ptrdiff_t stride, const int width,
-                                        const int height, const uint32_t s,
-                                        uint16_t* const intermediate_buffer) {
-  uint16_t* const a2 =
-      intermediate_buffer + kIntermediateStride * kIntermediateHeight;
-
-  BoxFilterPreProcess_SecondPass(src, stride, width, height, s, a2);
-
-  int x = 0;
-  do {
-    uint16_t* a2_ptr = a2 + x;
-    const uint8_t* src_ptr = src + x;
-    // |filtered_output| must match how |a2| values are read since they are
-    // written out over the |a2| values which have already been used.
-    uint16_t* filtered_output = a2_ptr;
-
-    uint16x4_t sum343_a[2], sum444_a;
-    uint32x4_t sum343_b[2], sum444_b;
-
-    sum343_a[0] = Sum343(vld1q_u16(a2_ptr));
-    a2_ptr += kIntermediateStride;
-
-    sum343_b[0] = Sum343W(vld1q_u16(a2_ptr));
-    a2_ptr += kIntermediateStride;
-
-    const uint16x8_t a_1 = vld1q_u16(a2_ptr);
-    a2_ptr += kIntermediateStride;
-    sum343_a[1] = Sum343(a_1);
-    sum444_a = Sum444(a_1);
-
-    const uint16x8_t b_1 = vld1q_u16(a2_ptr);
-    a2_ptr += kIntermediateStride;
-    sum343_b[1] = Sum343W(b_1);
-    sum444_b = Sum444W(b_1);
-
-    int y = 0;
-    do {
-      const uint16x8_t a_2 = vld1q_u16(a2_ptr);
-      a2_ptr += kIntermediateStride;
-
-      const uint16x4_t sum343_a2 = Sum343(a_2);
-      const uint16x4_t a_v =
-          vadd_u16(vadd_u16(sum343_a[0], sum444_a), sum343_a2);
-      sum444_a = Sum444(a_2);
-      sum343_a[0] = sum343_a[1];
-      sum343_a[1] = sum343_a2;
-
-      const uint16x8_t b_2 = vld1q_u16(a2_ptr);
-      a2_ptr += kIntermediateStride;
-
-      const uint32x4_t sum343_b2 = Sum343W(b_2);
-      const uint32x4_t b_v =
-          vaddq_u32(vaddq_u32(sum343_b[0], sum444_b), sum343_b2);
-      sum444_b = Sum444W(b_2);
-      sum343_b[0] = sum343_b[1];
-      sum343_b[1] = sum343_b2;
-
-      // Load 8 values and discard 4.
-      const uint8x8_t src_u8 = vld1_u8(src_ptr);
-      const uint16x4_t src_u16 = vget_low_u16(vmovl_u8(src_u8));
-
-      vst1_u16(filtered_output, CalculateFilteredOutput<5>(a_v, b_v, src_u16));
-
-      src_ptr += stride;
-      filtered_output += kIntermediateStride;
-    } while (++y < height);
-    x += 4;
-  } while (x < width);
-}
-
-template <int min_width>
-inline void SelfGuidedSingleMultiplier(const uint8_t* src,
-                                       const ptrdiff_t src_stride,
-                                       uint16_t* box_filter_process_output,
-                                       uint8_t* dst, const ptrdiff_t dst_stride,
-                                       const int width, const int height,
-                                       const int16_t w_combo,
-                                       const int16x4_t w_single) {
-  static_assert(min_width == 4 || min_width == 8, "");
-
-  int y = 0;
-  do {
-    if (min_width == 8) {
-      int x = 0;
-      do {
-        const int16x8_t u = vreinterpretq_s16_u16(
-            vshll_n_u8(vld1_u8(src + x), kSgrProjRestoreBits));
-        const int16x8_t p =
-            vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output + x));
-
-        // u * w1 + u * wN == u * (w1 + wN)
-        int32x4_t v_lo = vmull_n_s16(vget_low_s16(u), w_combo);
-        v_lo = vmlal_s16(v_lo, vget_low_s16(p), w_single);
-
-        int32x4_t v_hi = vmull_n_s16(vget_high_s16(u), w_combo);
-        v_hi = vmlal_s16(v_hi, vget_high_s16(p), w_single);
-
-        const int16x4_t s_lo =
-            vrshrn_n_s32(v_lo, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-        const int16x4_t s_hi =
-            vrshrn_n_s32(v_hi, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-        vst1_u8(dst + x, vqmovun_s16(vcombine_s16(s_lo, s_hi)));
-        x += 8;
-      } while (x < width);
-    } else if (min_width == 4) {
-      const int16x8_t u =
-          vreinterpretq_s16_u16(vshll_n_u8(vld1_u8(src), kSgrProjRestoreBits));
-      const int16x8_t p =
-          vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output));
-
-      // u * w1 + u * wN == u * (w1 + wN)
-      int32x4_t v_lo = vmull_n_s16(vget_low_s16(u), w_combo);
-      v_lo = vmlal_s16(v_lo, vget_low_s16(p), w_single);
-
-      int32x4_t v_hi = vmull_n_s16(vget_high_s16(u), w_combo);
-      v_hi = vmlal_s16(v_hi, vget_high_s16(p), w_single);
-
-      const int16x4_t s_lo =
-          vrshrn_n_s32(v_lo, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-      const int16x4_t s_hi =
-          vrshrn_n_s32(v_hi, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-      StoreLo4(dst, vqmovun_s16(vcombine_s16(s_lo, s_hi)));
-    }
+    Circulate3PointersBy1<uint16_t>(sum3);
+    Circulate3PointersBy1<uint32_t>(square_sum3);
+    BoxFilterPass2(src, src + 2 * src_stride - 2, width, scale, w0, sum3,
+                   square_sum3, ma343, ma444, b343, b444, dst);
     src += src_stride;
     dst += dst_stride;
-    box_filter_process_output += kIntermediateStride;
-  } while (++y < height);
+    Circulate3PointersBy1<uint16_t>(ma343);
+    Circulate3PointersBy1<uint32_t>(b343);
+    std::swap(ma444[0], ma444[1]);
+    std::swap(b444[0], b444[1]);
+  } while (--y != 0);
 }
 
-template <int min_width>
-inline void SelfGuidedDoubleMultiplier(const uint8_t* src,
-                                       const ptrdiff_t src_stride,
-                                       uint16_t* box_filter_process_output[2],
-                                       uint8_t* dst, const ptrdiff_t dst_stride,
-                                       const int width, const int height,
-                                       const int16x4_t w0, const int w1,
-                                       const int16x4_t w2) {
-  static_assert(min_width == 4 || min_width == 8, "");
-
-  int y = 0;
-  do {
-    if (min_width == 8) {
-      int x = 0;
-      do {
-        // |wN| values are signed. |src| values can be treated as int16_t.
-        const int16x8_t u = vreinterpretq_s16_u16(
-            vshll_n_u8(vld1_u8(src + x), kSgrProjRestoreBits));
-        // |box_filter_process_output| is 14 bits, also safe to treat as
-        // int16_t.
-        const int16x8_t p0 =
-            vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output[0] + x));
-        const int16x8_t p1 =
-            vreinterpretq_s16_u16(vld1q_u16(box_filter_process_output[1] + x));
-
-        int32x4_t v_lo = vmull_n_s16(vget_low_s16(u), w1);
-        v_lo = vmlal_s16(v_lo, vget_low_s16(p0), w0);
-        v_lo = vmlal_s16(v_lo, vget_low_s16(p1), w2);
-
-        int32x4_t v_hi = vmull_n_s16(vget_high_s16(u), w1);
-        v_hi = vmlal_s16(v_hi, vget_high_s16(p0), w0);
-        v_hi = vmlal_s16(v_hi, vget_high_s16(p1), w2);
-
-        // |s| is saturated to uint8_t.
-        const int16x4_t s_lo =
-            vrshrn_n_s32(v_lo, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-        const int16x4_t s_hi =
-            vrshrn_n_s32(v_hi, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-        vst1_u8(dst + x, vqmovun_s16(vcombine_s16(s_lo, s_hi)));
-        x += 8;
-      } while (x < width);
-    } else if (min_width == 4) {
-      // |wN| values are signed. |src| values can be treated as int16_t.
-      // Load 8 values but ignore 4.
-      const int16x4_t u = vget_low_s16(
-          vreinterpretq_s16_u16(vshll_n_u8(vld1_u8(src), kSgrProjRestoreBits)));
-      // |box_filter_process_output| is 14 bits, also safe to treat as
-      // int16_t.
-      const int16x4_t p0 =
-          vreinterpret_s16_u16(vld1_u16(box_filter_process_output[0]));
-      const int16x4_t p1 =
-          vreinterpret_s16_u16(vld1_u16(box_filter_process_output[1]));
-
-      int32x4_t v = vmull_n_s16(u, w1);
-      v = vmlal_s16(v, p0, w0);
-      v = vmlal_s16(v, p1, w2);
-
-      // |s| is saturated to uint8_t.
-      const int16x4_t s =
-          vrshrn_n_s32(v, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-      StoreLo4(dst, vqmovun_s16(vcombine_s16(s, s)));
-    }
-    src += src_stride;
-    dst += dst_stride;
-    box_filter_process_output[0] += kIntermediateStride;
-    box_filter_process_output[1] += kIntermediateStride;
-  } while (++y < height);
-}
-
-// Assume box_filter_process_output[2] are allocated before calling
-// this function. Their sizes are width * height, stride equals width.
+// If |width| is non-multiple of 8, up to 7 more pixels are written to |dest| in
+// the end of each row. It is safe to overwrite the output as it will not be
+// part of the visible frame.
 void SelfGuidedFilter_NEON(const void* const source, void* const dest,
                            const RestorationUnitInfo& restoration_info,
-                           ptrdiff_t source_stride, ptrdiff_t dest_stride,
-                           const int width, const int height,
-                           RestorationBuffer* const /*buffer*/) {
-  // The output frame is broken into blocks of 64x64 (32x32 if U/V are
-  // subsampled). If either dimension is less than 32/64 it indicates it is at
-  // the right or bottom edge of the frame. It is safe to overwrite the output
-  // as it will not be part of the visible frame. This saves us from having to
-  // handle non-multiple-of-8 widths.
-  // We could round here, but the for loop with += 8 does the same thing.
-
-  // width = (width + 7) & ~0x7;
-
-  // -96 to 96 (Sgrproj_Xqd_Min/Max)
-  const int8_t w0 = restoration_info.sgr_proj_info.multiplier[0];
-  const int8_t w1 = restoration_info.sgr_proj_info.multiplier[1];
-  const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
+                           const ptrdiff_t source_stride,
+                           const ptrdiff_t dest_stride, const int width,
+                           const int height,
+                           RestorationBuffer* const restoration_buffer) {
   const int index = restoration_info.sgr_proj_info.index;
-  const uint8_t radius_pass_0 = kSgrProjParams[index][0];
-  const uint8_t radius_pass_1 = kSgrProjParams[index][2];
-  const auto* src = static_cast<const uint8_t*>(source);
-  auto* dst = static_cast<uint8_t*>(dest);
-
-  // |intermediate_buffer| is broken down into three distinct regions, each with
-  // size |kIntermediateStride| * |kIntermediateHeight|.
-  // The first is for final output of the first pass of PreProcess/Process. It
-  // can be stored in |width| * |height| (at most 64x64).
-  // The second and third are scratch space for |a2| and |b2| values from
-  // PreProcess.
-  //
-  // At the end of BoxFilterProcess_SecondPass() the output is stored over |a2|.
-
-  uint16_t intermediate_buffer[3 * kIntermediateStride * kIntermediateHeight];
-  uint16_t* box_filter_process_output[2] = {
-      intermediate_buffer,
-      intermediate_buffer + kIntermediateStride * kIntermediateHeight};
-
-  // If |radius| is 0 then there is nothing to do. If |radius| is not 0, it is
-  // always 2 for the first pass and 1 for the second pass.
-  if (radius_pass_0 != 0) {
-    BoxFilterProcess_FirstPass(src, source_stride, width, height,
-                               kSgrScaleParameter[index][0],
-                               intermediate_buffer);
-  }
-
-  if (radius_pass_1 != 0) {
-    BoxFilterProcess_SecondPass(src, source_stride, width, height,
-                                kSgrScaleParameter[index][1],
-                                intermediate_buffer);
-  }
-
-  // Put |w[02]| in vectors because we can use vmull_n_s16() for |w1| but there
-  // is no vmlal_n_s16().
-  const int16x4_t w0_v = vdup_n_s16(w0);
-  const int16x4_t w2_v = vdup_n_s16(w2);
-  if (radius_pass_0 != 0 && radius_pass_1 != 0) {
-    if (width > 4) {
-      SelfGuidedDoubleMultiplier<8>(src, source_stride,
-                                    box_filter_process_output, dst, dest_stride,
-                                    width, height, w0_v, w1, w2_v);
-    } else /* if (width == 4) */ {
-      SelfGuidedDoubleMultiplier<4>(src, source_stride,
-                                    box_filter_process_output, dst, dest_stride,
-                                    width, height, w0_v, w1, w2_v);
-    }
+  const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
+  const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
+  const auto* const src = static_cast<const uint8_t*>(source);
+  auto* const dst = static_cast<uint8_t*>(dest);
+  SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
+  if (radius_pass_1 == 0) {
+    // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
+    // following assertion.
+    assert(radius_pass_0 != 0);
+    BoxFilterProcessPass1(restoration_info, src, source_stride, width, height,
+                          sgr_buffer, dst, dest_stride);
+  } else if (radius_pass_0 == 0) {
+    BoxFilterProcessPass2(restoration_info, src, source_stride, width, height,
+                          sgr_buffer, dst, dest_stride);
   } else {
-    int16_t w_combo;
-    int16x4_t w_single;
-    uint16_t* box_filter_process_output_n;
-    if (radius_pass_0 != 0) {
-      w_combo = w1 + w2;
-      w_single = w0_v;
-      box_filter_process_output_n = box_filter_process_output[0];
-    } else /* if (radius_pass_1 != 0) */ {
-      w_combo = w1 + w0;
-      w_single = w2_v;
-      box_filter_process_output_n = box_filter_process_output[1];
-    }
-
-    if (width > 4) {
-      SelfGuidedSingleMultiplier<8>(
-          src, source_stride, box_filter_process_output_n, dst, dest_stride,
-          width, height, w_combo, w_single);
-    } else /* if (width == 4) */ {
-      SelfGuidedSingleMultiplier<4>(
-          src, source_stride, box_filter_process_output_n, dst, dest_stride,
-          width, height, w_combo, w_single);
-    }
+    BoxFilterProcess(restoration_info, src, source_stride, width, height,
+                     sgr_buffer, dst, dest_stride);
   }
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->loop_restorations[0] = WienerFilter_NEON;
   dsp->loop_restorations[1] = SelfGuidedFilter_NEON;
@@ -1056,7 +1828,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/loop_restoration_neon.h b/libgav1/src/dsp/arm/loop_restoration_neon.h
index a6ea74b..b551610 100644
--- a/libgav1/src/dsp/arm/loop_restoration_neon.h
+++ b/libgav1/src/dsp/arm/loop_restoration_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -32,8 +32,8 @@
 
 #if LIBGAV1_ENABLE_NEON
 
-#define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_CPU_NEON
 
 #endif  // LIBGAV1_ENABLE_NEON
 
diff --git a/libgav1/src/dsp/arm/mask_blend_neon.cc b/libgav1/src/dsp/arm/mask_blend_neon.cc
index 03ac791..21f3fb1 100644
--- a/libgav1/src/dsp/arm/mask_blend_neon.cc
+++ b/libgav1/src/dsp/arm/mask_blend_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/mask_blend.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -24,6 +24,8 @@
 #include <cstdint>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"
 
 namespace libgav1 {
@@ -31,286 +33,396 @@
 namespace low_bitdepth {
 namespace {
 
-constexpr int kBitdepth8 = 8;
-
+// TODO(b/150461164): Consider combining with GetInterIntraMask4x2().
+// Compound predictors use int16_t values and need to multiply long because the
+// Convolve range * 64 is 20 bits. Unfortunately there is no multiply int16_t by
+// int8_t and accumulate into int32_t instruction.
 template <int subsampling_x, int subsampling_y>
-inline uint16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
+inline int16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
   if (subsampling_x == 1) {
-    const uint16x4_t mask_val0 = vpaddl_u8(vld1_u8(mask));
-    const uint16x4_t mask_val1 =
-        vpaddl_u8(vld1_u8(mask + (mask_stride << subsampling_y)));
-    uint16x8_t final_val;
+    const int16x4_t mask_val0 = vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask)));
+    const int16x4_t mask_val1 = vreinterpret_s16_u16(
+        vpaddl_u8(vld1_u8(mask + (mask_stride << subsampling_y))));
+    int16x8_t final_val;
     if (subsampling_y == 1) {
-      const uint16x4_t next_mask_val0 = vpaddl_u8(vld1_u8(mask + mask_stride));
-      const uint16x4_t next_mask_val1 =
-          vpaddl_u8(vld1_u8(mask + mask_stride * 3));
-      final_val = vaddq_u16(vcombine_u16(mask_val0, mask_val1),
-                            vcombine_u16(next_mask_val0, next_mask_val1));
+      const int16x4_t next_mask_val0 =
+          vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride)));
+      const int16x4_t next_mask_val1 =
+          vreinterpret_s16_u16(vpaddl_u8(vld1_u8(mask + mask_stride * 3)));
+      final_val = vaddq_s16(vcombine_s16(mask_val0, mask_val1),
+                            vcombine_s16(next_mask_val0, next_mask_val1));
     } else {
-      final_val = vpaddlq_u8(vcombine_u8(mask_val0, mask_val1));
+      final_val = vreinterpretq_s16_u16(
+          vpaddlq_u8(vreinterpretq_u8_s16(vcombine_s16(mask_val0, mask_val1))));
     }
-    return vrshrq_n_u16(final_val, subsampling_y + 1);
+    return vrshrq_n_s16(final_val, subsampling_y + 1);
   }
   assert(subsampling_y == 0 && subsampling_x == 0);
-  const uint8x8_t mask_val0 = LoadLo4(mask, vdup_n_u8(0));
-  const uint8x8_t mask_val =
-      LoadHi4(mask + (mask_stride << subsampling_y), mask_val0);
-  return vmovl_u8(mask_val);
+  const uint8x8_t mask_val0 = Load4(mask);
+  const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0);
+  return vreinterpretq_s16_u16(vmovl_u8(mask_val));
 }
 
 template <int subsampling_x, int subsampling_y>
-inline uint16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
+inline int16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
   if (subsampling_x == 1) {
-    uint16x8_t mask_val = vpaddlq_u8(vld1q_u8(mask));
+    int16x8_t mask_val = vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask)));
     if (subsampling_y == 1) {
-      const uint16x8_t next_mask_val = vpaddlq_u8(vld1q_u8(mask + mask_stride));
-      mask_val = vaddq_u16(mask_val, next_mask_val);
+      const int16x8_t next_mask_val =
+          vreinterpretq_s16_u16(vpaddlq_u8(vld1q_u8(mask + mask_stride)));
+      mask_val = vaddq_s16(mask_val, next_mask_val);
     }
-    return vrshrq_n_u16(mask_val, 1 + subsampling_y);
+    return vrshrq_n_s16(mask_val, 1 + subsampling_y);
   }
   assert(subsampling_y == 0 && subsampling_x == 0);
   const uint8x8_t mask_val = vld1_u8(mask);
-  return vmovl_u8(mask_val);
+  return vreinterpretq_s16_u16(vmovl_u8(mask_val));
 }
 
-template <bool is_inter_intra>
-inline void WriteMaskBlendLine4x2(const uint16_t* const pred_0,
-                                  const ptrdiff_t pred_stride_0,
-                                  const uint16_t* const pred_1,
-                                  const ptrdiff_t pred_stride_1,
-                                  const uint16x8_t pred_mask_0,
-                                  const uint16x8_t pred_mask_1, uint8_t* dst,
+inline void WriteMaskBlendLine4x2(const int16_t* const pred_0,
+                                  const int16_t* const pred_1,
+                                  const int16x8_t pred_mask_0,
+                                  const int16x8_t pred_mask_1, uint8_t* dst,
                                   const ptrdiff_t dst_stride) {
-  const uint16x4_t pred_val_0_lo = vld1_u16(pred_0);
-  const uint16x4_t pred_val_0_hi = vld1_u16(pred_0 + pred_stride_0);
-  uint16x4_t pred_val_1_lo = vld1_u16(pred_1);
-  uint16x4_t pred_val_1_hi = vld1_u16(pred_1 + pred_stride_1);
-  uint8x8_t result;
-  if (is_inter_intra) {
-    // An offset to cancel offsets used in compound predictor generation
-    // that make intermediate computations non negative.
-    const uint16x8_t single_round_offset =
-        vdupq_n_u16((1 << kBitdepth8) + (1 << (kBitdepth8 - 1)));
-    // pred_0 and pred_1 are switched at the beginning with is_inter_intra.
-    // Clip3(prediction_0[x] - single_round_offset, 0, (1 << kBitdepth8) - 1)
-    const uint16x8_t pred_val_1 = vmovl_u8(vqmovn_u16(vqsubq_u16(
-        vcombine_u16(pred_val_1_lo, pred_val_1_hi), single_round_offset)));
-
-    const uint16x8_t pred_val_0 = vcombine_u16(pred_val_0_lo, pred_val_0_hi);
-    const uint16x8_t weighted_pred_0 = vmulq_u16(pred_val_0, pred_mask_0);
-    const uint16x8_t weighted_combo =
-        vmlaq_u16(weighted_pred_0, pred_mask_1, pred_val_1);
-    result = vrshrn_n_u16(weighted_combo, 6);
-  } else {
-    // int res = (mask_value * prediction_0[x] +
-    //      (64 - mask_value) * prediction_1[x]) >> 6;
-    const uint32x4_t weighted_pred_0_lo =
-        vmull_u16(vget_low_u16(pred_mask_0), pred_val_0_lo);
-    const uint32x4_t weighted_pred_0_hi =
-        vmull_u16(vget_high_u16(pred_mask_0), pred_val_0_hi);
-    const uint32x4_t weighted_combo_lo =
-        vmlal_u16(weighted_pred_0_lo, vget_low_u16(pred_mask_1), pred_val_1_lo);
-    const uint32x4_t weighted_combo_hi = vmlal_u16(
-        weighted_pred_0_hi, vget_high_u16(pred_mask_1), pred_val_1_hi);
-    // res -= compound_round_offset;
-    // dst[x] = static_cast<Pixel>(
-    //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
-    //         (1 << kBitdepth8) - 1));
-    const int16x8_t compound_round_offset =
-        vdupq_n_s16((1 << (kBitdepth8 + 4)) + (1 << (kBitdepth8 + 3)));
-    result = vqrshrun_n_s16(vsubq_s16(vreinterpretq_s16_u16(vcombine_u16(
-                                          vshrn_n_u32(weighted_combo_lo, 6),
-                                          vshrn_n_u32(weighted_combo_hi, 6))),
-                                      compound_round_offset),
-                            4);
-  }
+  const int16x4_t pred_val_0_lo = vld1_s16(pred_0);
+  const int16x4_t pred_val_0_hi = vld1_s16(pred_0 + 4);
+  const int16x4_t pred_val_1_lo = vld1_s16(pred_1);
+  const int16x4_t pred_val_1_hi = vld1_s16(pred_1 + 4);
+  // int res = (mask_value * prediction_0[x] +
+  //      (64 - mask_value) * prediction_1[x]) >> 6;
+  const int32x4_t weighted_pred_0_lo =
+      vmull_s16(vget_low_s16(pred_mask_0), pred_val_0_lo);
+  const int32x4_t weighted_pred_0_hi =
+      vmull_s16(vget_high_s16(pred_mask_0), pred_val_0_hi);
+  const int32x4_t weighted_combo_lo =
+      vmlal_s16(weighted_pred_0_lo, vget_low_s16(pred_mask_1), pred_val_1_lo);
+  const int32x4_t weighted_combo_hi =
+      vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1), pred_val_1_hi);
+  // dst[x] = static_cast<Pixel>(
+  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+  //         (1 << kBitdepth8) - 1));
+  const uint8x8_t result =
+      vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
+                                  vshrn_n_s32(weighted_combo_hi, 6)),
+                     4);
   StoreLo4(dst, result);
   StoreHi4(dst + dst_stride, result);
 }
 
-template <bool is_inter_intra, int subsampling_x, int subsampling_y>
-inline void MaskBlending4x4_NEON(const uint16_t* pred_0,
-                                 const ptrdiff_t prediction_stride_0,
-                                 const uint16_t* pred_1,
-                                 const ptrdiff_t prediction_stride_1,
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1,
                                  const uint8_t* mask,
                                  const ptrdiff_t mask_stride, uint8_t* dst,
                                  const ptrdiff_t dst_stride) {
-  const uint16x8_t mask_inverter = vdupq_n_u16(64);
-  uint16x8_t pred_mask_0 =
+  const int16x8_t mask_inverter = vdupq_n_s16(64);
+  int16x8_t pred_mask_0 =
       GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
-  uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
-  WriteMaskBlendLine4x2<is_inter_intra>(pred_0, prediction_stride_0, pred_1,
-                                        prediction_stride_1, pred_mask_0,
-                                        pred_mask_1, dst, dst_stride);
-  pred_0 += prediction_stride_0 << 1;
-  pred_1 += prediction_stride_1 << 1;
+  int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                        dst_stride);
+  // TODO(b/150461164): Arm tends to do better with load(val); val += stride
+  // It may be possible to turn this into a loop with a templated height.
+  pred_0 += 4 << 1;
+  pred_1 += 4 << 1;
   mask += mask_stride << (1 + subsampling_y);
   dst += dst_stride << 1;
 
   pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
-  pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
-  WriteMaskBlendLine4x2<is_inter_intra>(pred_0, prediction_stride_0, pred_1,
-                                        prediction_stride_1, pred_mask_0,
-                                        pred_mask_1, dst, dst_stride);
+  pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                        dst_stride);
 }
 
-template <bool is_inter_intra, int subsampling_x, int subsampling_y>
-inline void MaskBlending4xH_NEON(const uint16_t* pred_0,
-                                 const ptrdiff_t pred_stride_0,
-                                 const int height, const uint16_t* pred_1,
-                                 const ptrdiff_t pred_stride_1,
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1,
                                  const uint8_t* const mask_ptr,
-                                 const ptrdiff_t mask_stride, uint8_t* dst,
-                                 const ptrdiff_t dst_stride) {
+                                 const ptrdiff_t mask_stride, const int height,
+                                 uint8_t* dst, const ptrdiff_t dst_stride) {
   const uint8_t* mask = mask_ptr;
   if (height == 4) {
-    MaskBlending4x4_NEON<is_inter_intra, subsampling_x, subsampling_y>(
-        pred_0, pred_stride_0, pred_1, pred_stride_1, mask, mask_stride, dst,
-        dst_stride);
+    MaskBlending4x4_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, mask, mask_stride, dst, dst_stride);
     return;
   }
-  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  const int16x8_t mask_inverter = vdupq_n_s16(64);
   int y = 0;
   do {
-    uint16x8_t pred_mask_0 =
+    int16x8_t pred_mask_0 =
         GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
-    uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+    int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
 
-    WriteMaskBlendLine4x2<is_inter_intra>(pred_0, pred_stride_0, pred_1,
-                                          pred_stride_1, pred_mask_0,
-                                          pred_mask_1, dst, dst_stride);
-    pred_0 += pred_stride_0 << 1;
-    pred_1 += pred_stride_1 << 1;
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
     mask += mask_stride << (1 + subsampling_y);
     dst += dst_stride << 1;
 
     pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
-    pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
-    WriteMaskBlendLine4x2<is_inter_intra>(pred_0, pred_stride_0, pred_1,
-                                          pred_stride_1, pred_mask_0,
-                                          pred_mask_1, dst, dst_stride);
-    pred_0 += pred_stride_0 << 1;
-    pred_1 += pred_stride_1 << 1;
+    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
     mask += mask_stride << (1 + subsampling_y);
     dst += dst_stride << 1;
 
     pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
-    pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
-    WriteMaskBlendLine4x2<is_inter_intra>(pred_0, pred_stride_0, pred_1,
-                                          pred_stride_1, pred_mask_0,
-                                          pred_mask_1, dst, dst_stride);
-    pred_0 += pred_stride_0 << 1;
-    pred_1 += pred_stride_1 << 1;
+    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
     mask += mask_stride << (1 + subsampling_y);
     dst += dst_stride << 1;
 
     pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
-    pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
-    WriteMaskBlendLine4x2<is_inter_intra>(pred_0, pred_stride_0, pred_1,
-                                          pred_stride_1, pred_mask_0,
-                                          pred_mask_1, dst, dst_stride);
-    pred_0 += pred_stride_0 << 1;
-    pred_1 += pred_stride_1 << 1;
+    pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
     mask += mask_stride << (1 + subsampling_y);
     dst += dst_stride << 1;
     y += 8;
   } while (y < height);
 }
 
-template <bool is_inter_intra, int subsampling_x, int subsampling_y>
-inline void MaskBlend_NEON(
-    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
-    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
-    const uint8_t* const mask_ptr, const ptrdiff_t mask_stride, const int width,
-    const int height, void* dest, const ptrdiff_t dst_stride) {
-  uint8_t* dst = reinterpret_cast<uint8_t*>(dest);
-  const uint16_t* pred_0 = is_inter_intra ? prediction_1 : prediction_0;
-  const uint16_t* pred_1 = is_inter_intra ? prediction_0 : prediction_1;
-  const ptrdiff_t pred_stride_0 =
-      is_inter_intra ? prediction_stride_1 : prediction_stride_0;
-  const ptrdiff_t pred_stride_1 =
-      is_inter_intra ? prediction_stride_0 : prediction_stride_1;
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlend_NEON(const void* prediction_0, const void* prediction_1,
+                           const ptrdiff_t /*prediction_stride_1*/,
+                           const uint8_t* const mask_ptr,
+                           const ptrdiff_t mask_stride, const int width,
+                           const int height, void* dest,
+                           const ptrdiff_t dst_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   if (width == 4) {
-    MaskBlending4xH_NEON<is_inter_intra, subsampling_x, subsampling_y>(
-        pred_0, pred_stride_0, height, pred_1, pred_stride_1, mask_ptr,
-        mask_stride, dst, dst_stride);
+    MaskBlending4xH_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, mask_ptr, mask_stride, height, dst, dst_stride);
     return;
   }
   const uint8_t* mask = mask_ptr;
-  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  const int16x8_t mask_inverter = vdupq_n_s16(64);
   int y = 0;
   do {
     int x = 0;
     do {
-      const uint16x8_t pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
+      const int16x8_t pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
           mask + (x << subsampling_x), mask_stride);
       // 64 - mask
-      const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
-      const uint16x8_t pred_val_0 = vld1q_u16(pred_0 + x);
-      uint16x8_t pred_val_1 = vld1q_u16(pred_1 + x);
-      if (is_inter_intra) {
-        // An offset to cancel offsets used in compound predictor generation
-        // that make intermediate computations non negative.
-        const uint16x8_t single_round_offset =
-            vdupq_n_u16((1 << kBitdepth8) + (1 << (kBitdepth8 - 1)));
-        pred_val_1 =
-            vmovl_u8(vqmovn_u16(vqsubq_u16(pred_val_1, single_round_offset)));
-      }
+      const int16x8_t pred_mask_1 = vsubq_s16(mask_inverter, pred_mask_0);
+      const int16x8_t pred_val_0 = vld1q_s16(pred_0 + x);
+      const int16x8_t pred_val_1 = vld1q_s16(pred_1 + x);
       uint8x8_t result;
-      if (is_inter_intra) {
-        const uint16x8_t weighted_pred_0 = vmulq_u16(pred_mask_0, pred_val_0);
-        // weighted_pred0 + weighted_pred1
-        const uint16x8_t weighted_combo =
-            vmlaq_u16(weighted_pred_0, pred_mask_1, pred_val_1);
-        result = vrshrn_n_u16(weighted_combo, 6);
-      } else {
-        // int res = (mask_value * prediction_0[x] +
-        //      (64 - mask_value) * prediction_1[x]) >> 6;
-        const uint32x4_t weighted_pred_0_lo =
-            vmull_u16(vget_low_u16(pred_mask_0), vget_low_u16(pred_val_0));
-        const uint32x4_t weighted_pred_0_hi =
-            vmull_u16(vget_high_u16(pred_mask_0), vget_high_u16(pred_val_0));
-        const uint32x4_t weighted_combo_lo =
-            vmlal_u16(weighted_pred_0_lo, vget_low_u16(pred_mask_1),
-                      vget_low_u16(pred_val_1));
-        const uint32x4_t weighted_combo_hi =
-            vmlal_u16(weighted_pred_0_hi, vget_high_u16(pred_mask_1),
-                      vget_high_u16(pred_val_1));
+      // int res = (mask_value * prediction_0[x] +
+      //      (64 - mask_value) * prediction_1[x]) >> 6;
+      const int32x4_t weighted_pred_0_lo =
+          vmull_s16(vget_low_s16(pred_mask_0), vget_low_s16(pred_val_0));
+      const int32x4_t weighted_pred_0_hi =
+          vmull_s16(vget_high_s16(pred_mask_0), vget_high_s16(pred_val_0));
+      const int32x4_t weighted_combo_lo =
+          vmlal_s16(weighted_pred_0_lo, vget_low_s16(pred_mask_1),
+                    vget_low_s16(pred_val_1));
+      const int32x4_t weighted_combo_hi =
+          vmlal_s16(weighted_pred_0_hi, vget_high_s16(pred_mask_1),
+                    vget_high_s16(pred_val_1));
 
-        const int16x8_t compound_round_offset =
-            vdupq_n_s16((1 << (kBitdepth8 + 4)) + (1 << (kBitdepth8 + 3)));
-        // res -= compound_round_offset;
-        // dst[x] = static_cast<Pixel>(
-        //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
-        //           (1 << kBitdepth8) - 1));
-        result =
-            vqrshrun_n_s16(vsubq_s16(vreinterpretq_s16_u16(vcombine_u16(
-                                         vshrn_n_u32(weighted_combo_lo, 6),
-                                         vshrn_n_u32(weighted_combo_hi, 6))),
-                                     compound_round_offset),
-                           4);
-      }
+      // dst[x] = static_cast<Pixel>(
+      //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+      //           (1 << kBitdepth8) - 1));
+      result = vqrshrun_n_s16(vcombine_s16(vshrn_n_s32(weighted_combo_lo, 6),
+                                           vshrn_n_s32(weighted_combo_hi, 6)),
+                              4);
       vst1_u8(dst + x, result);
 
       x += 8;
     } while (x < width);
     dst += dst_stride;
-    pred_0 += pred_stride_0;
-    pred_1 += pred_stride_1;
+    pred_0 += width;
+    pred_1 += width;
+    mask += mask_stride << subsampling_y;
+  } while (++y < height);
+}
+
+// TODO(b/150461164): This is much faster for inter_intra (input is Pixel
+// values) but regresses compound versions (input is int16_t). Try to
+// consolidate these.
+template <int subsampling_x, int subsampling_y>
+inline uint8x8_t GetInterIntraMask4x2(const uint8_t* mask,
+                                      ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    const uint8x8_t mask_val =
+        vpadd_u8(vld1_u8(mask), vld1_u8(mask + (mask_stride << subsampling_y)));
+    if (subsampling_y == 1) {
+      const uint8x8_t next_mask_val = vpadd_u8(vld1_u8(mask + mask_stride),
+                                               vld1_u8(mask + mask_stride * 3));
+
+      // Use a saturating add to work around the case where all |mask| values
+      // are 64. Together with the rounding shift this ensures the correct
+      // result.
+      const uint8x8_t sum = vqadd_u8(mask_val, next_mask_val);
+      return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
+    }
+
+    return vrshr_n_u8(mask_val, /*subsampling_x=*/1);
+  }
+
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  const uint8x8_t mask_val0 = Load4(mask);
+  // TODO(b/150461164): Investigate the source of |mask| and see if the stride
+  // can be removed.
+  // TODO(b/150461164): The unit tests start at 8x8. Does this get run?
+  return Load4<1>(mask + mask_stride, mask_val0);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline uint8x8_t GetInterIntraMask8(const uint8_t* mask,
+                                    ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    const uint8x16_t mask_val = vld1q_u8(mask);
+    const uint8x8_t mask_paired =
+        vpadd_u8(vget_low_u8(mask_val), vget_high_u8(mask_val));
+    if (subsampling_y == 1) {
+      const uint8x16_t next_mask_val = vld1q_u8(mask + mask_stride);
+      const uint8x8_t next_mask_paired =
+          vpadd_u8(vget_low_u8(next_mask_val), vget_high_u8(next_mask_val));
+
+      // Use a saturating add to work around the case where all |mask| values
+      // are 64. Together with the rounding shift this ensures the correct
+      // result.
+      const uint8x8_t sum = vqadd_u8(mask_paired, next_mask_paired);
+      return vrshr_n_u8(sum, /*subsampling_x=*/1 + subsampling_y);
+    }
+
+    return vrshr_n_u8(mask_paired, /*subsampling_x=*/1);
+  }
+
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  return vld1_u8(mask);
+}
+
+inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0,
+                                                uint8_t* const pred_1,
+                                                const ptrdiff_t pred_stride_1,
+                                                const uint8x8_t pred_mask_0,
+                                                const uint8x8_t pred_mask_1) {
+  const uint8x8_t pred_val_0 = vld1_u8(pred_0);
+  uint8x8_t pred_val_1 = Load4(pred_1);
+  pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1);
+
+  const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+  const uint16x8_t weighted_combo =
+      vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
+  const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
+  StoreLo4(pred_1, result);
+  StoreHi4(pred_1 + pred_stride_1, result);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0,
+                                               uint8_t* pred_1,
+                                               const ptrdiff_t pred_stride_1,
+                                               const uint8_t* mask,
+                                               const ptrdiff_t mask_stride) {
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  uint8x8_t pred_mask_1 =
+      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
+                                      pred_mask_0, pred_mask_1);
+  pred_0 += 4 << 1;
+  pred_1 += pred_stride_1 << 1;
+  mask += mask_stride << (1 + subsampling_y);
+
+  pred_mask_1 =
+      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
+                                      pred_mask_0, pred_mask_1);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp4xH_NEON(
+    const uint8_t* pred_0, uint8_t* pred_1, const ptrdiff_t pred_stride_1,
+    const uint8_t* mask, const ptrdiff_t mask_stride, const int height) {
+  if (height == 4) {
+    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride);
+    return;
+  }
+  int y = 0;
+  do {
+    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride);
+    pred_0 += 4 << 2;
+    pred_1 += pred_stride_1 << 2;
+    mask += mask_stride << (2 + subsampling_y);
+
+    InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride);
+    pred_0 += 4 << 2;
+    pred_1 += pred_stride_1 << 2;
+    mask += mask_stride << (2 + subsampling_y);
+    y += 8;
+  } while (y < height);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlend8bpp_NEON(const uint8_t* prediction_0,
+                                         uint8_t* prediction_1,
+                                         const ptrdiff_t prediction_stride_1,
+                                         const uint8_t* const mask_ptr,
+                                         const ptrdiff_t mask_stride,
+                                         const int width, const int height) {
+  if (width == 4) {
+    InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>(
+        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
+        height);
+    return;
+  }
+  const uint8_t* mask = mask_ptr;
+  const uint8x8_t mask_inverter = vdup_n_u8(64);
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      // TODO(b/150461164): Consider a 16 wide specialization (at least for the
+      // unsampled version) to take advantage of vld1q_u8().
+      const uint8x8_t pred_mask_1 =
+          GetInterIntraMask8<subsampling_x, subsampling_y>(
+              mask + (x << subsampling_x), mask_stride);
+      // 64 - mask
+      const uint8x8_t pred_mask_0 = vsub_u8(mask_inverter, pred_mask_1);
+      const uint8x8_t pred_val_0 = vld1_u8(prediction_0);
+      prediction_0 += 8;
+      const uint8x8_t pred_val_1 = vld1_u8(prediction_1 + x);
+      const uint16x8_t weighted_pred_0 = vmull_u8(pred_mask_0, pred_val_0);
+      // weighted_pred0 + weighted_pred1
+      const uint16x8_t weighted_combo =
+          vmlal_u8(weighted_pred_0, pred_mask_1, pred_val_1);
+      const uint8x8_t result = vrshrn_n_u16(weighted_combo, 6);
+      vst1_u8(prediction_1 + x, result);
+
+      x += 8;
+    } while (x < width);
+    prediction_1 += prediction_stride_1;
     mask += mask_stride << subsampling_y;
   } while (++y < height);
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
-  dsp->mask_blend[0][0] = MaskBlend_NEON<false, 0, 0>;
-  dsp->mask_blend[1][0] = MaskBlend_NEON<false, 1, 0>;
-  dsp->mask_blend[2][0] = MaskBlend_NEON<false, 1, 1>;
-  dsp->mask_blend[0][1] = MaskBlend_NEON<true, 0, 0>;
-  dsp->mask_blend[1][1] = MaskBlend_NEON<true, 1, 0>;
-  dsp->mask_blend[2][1] = MaskBlend_NEON<true, 1, 1>;
+  dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0>;
+  dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0>;
+  dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1>;
+  // The is_inter_intra index of mask_blend[][] is replaced by
+  // inter_intra_mask_blend_8bpp[] in 8-bit.
+  dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_NEON<0, 0>;
+  dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_NEON<1, 0>;
+  dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_NEON<1, 1>;
 }
 
 }  // namespace
@@ -321,7 +433,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 
 namespace libgav1 {
 namespace dsp {
diff --git a/libgav1/src/dsp/arm/mask_blend_neon.h b/libgav1/src/dsp/arm/mask_blend_neon.h
index 3ac7cae..3829274 100644
--- a/libgav1/src/dsp/arm/mask_blend_neon.h
+++ b/libgav1/src/dsp/arm/mask_blend_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -30,12 +30,12 @@
 }  // namespace libgav1
 
 #if LIBGAV1_ENABLE_NEON
-#define LIBGAV1_Dsp8bpp_MaskBlend444 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_MaskBlend422 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_MaskBlend420 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_MaskBlendInterIntra444 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_MaskBlendInterIntra422 LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_MaskBlendInterIntra420 LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlend444 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlend422 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_MaskBlend420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
diff --git a/libgav1/src/dsp/arm/motion_field_projection_neon.cc b/libgav1/src/dsp/arm/motion_field_projection_neon.cc
new file mode 100644
index 0000000..8caba7d
--- /dev/null
+++ b/libgav1/src/dsp/arm/motion_field_projection_neon.cc
@@ -0,0 +1,393 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/motion_field_projection.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+inline int16x8_t LoadDivision(const int8x8x2_t division_table,
+                              const int8x8_t reference_offset) {
+  const int8x8_t kOne = vcreate_s8(0x0100010001000100);
+  const int8x16_t kOneQ = vcombine_s8(kOne, kOne);
+  const int8x8_t t = vadd_s8(reference_offset, reference_offset);
+  const int8x8x2_t tt = vzip_s8(t, t);
+  const int8x16_t t1 = vcombine_s8(tt.val[0], tt.val[1]);
+  const int8x16_t idx = vaddq_s8(t1, kOneQ);
+  const int8x8_t idx_low = vget_low_s8(idx);
+  const int8x8_t idx_high = vget_high_s8(idx);
+  const int16x4_t d0 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_low));
+  const int16x4_t d1 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_high));
+  return vcombine_s16(d0, d1);
+}
+
+inline int16x4_t MvProjection(const int16x4_t mv, const int16x4_t denominator,
+                              const int numerator) {
+  const int32x4_t m0 = vmull_s16(mv, denominator);
+  const int32x4_t m = vmulq_n_s32(m0, numerator);
+  // Add the sign (0 or -1) to round towards zero.
+  const int32x4_t add_sign = vsraq_n_s32(m, m, 31);
+  return vqrshrn_n_s32(add_sign, 14);
+}
+
+inline int16x8_t MvProjectionClip(const int16x8_t mv,
+                                  const int16x8_t denominator,
+                                  const int numerator) {
+  const int16x4_t mv0 = vget_low_s16(mv);
+  const int16x4_t mv1 = vget_high_s16(mv);
+  const int16x4_t s0 = MvProjection(mv0, vget_low_s16(denominator), numerator);
+  const int16x4_t s1 = MvProjection(mv1, vget_high_s16(denominator), numerator);
+  const int16x8_t projection = vcombine_s16(s0, s1);
+  const int16x8_t projection_mv_clamp = vdupq_n_s16(kProjectionMvClamp);
+  const int16x8_t clamp = vminq_s16(projection, projection_mv_clamp);
+  return vmaxq_s16(clamp, vnegq_s16(projection_mv_clamp));
+}
+
+inline int8x8_t Project_NEON(const int16x8_t delta, const int16x8_t dst_sign) {
+  // Add 63 to negative delta so that it shifts towards zero.
+  const int16x8_t delta_sign = vshrq_n_s16(delta, 15);
+  const uint16x8_t delta_u = vreinterpretq_u16_s16(delta);
+  const uint16x8_t delta_sign_u = vreinterpretq_u16_s16(delta_sign);
+  const uint16x8_t delta_adjust_u = vsraq_n_u16(delta_u, delta_sign_u, 10);
+  const int16x8_t delta_adjust = vreinterpretq_s16_u16(delta_adjust_u);
+  const int16x8_t offset0 = vshrq_n_s16(delta_adjust, 6);
+  const int16x8_t offset1 = veorq_s16(offset0, dst_sign);
+  const int16x8_t offset2 = vsubq_s16(offset1, dst_sign);
+  return vqmovn_s16(offset2);
+}
+
+inline void GetPosition(
+    const int8x8x2_t division_table, const MotionVector* const mv,
+    const int numerator, const int x8_start, const int x8_end, const int x8,
+    const int8x8_t r_offsets, const int8x8_t source_reference_type8,
+    const int8x8_t skip_r, const int8x8_t y8_floor8, const int8x8_t y8_ceiling8,
+    const int16x8_t d_sign, const int delta, int8x8_t* const r,
+    int8x8_t* const position_y8, int8x8_t* const position_x8,
+    int64_t* const skip_64, int32x4_t mvs[2]) {
+  const auto* const mv_int = reinterpret_cast<const int32_t*>(mv + x8);
+  *r = vtbl1_s8(r_offsets, source_reference_type8);
+  const int16x8_t denorm = LoadDivision(division_table, source_reference_type8);
+  int16x8_t projection_mv[2];
+  mvs[0] = vld1q_s32(mv_int + 0);
+  mvs[1] = vld1q_s32(mv_int + 4);
+  // Deinterlace x and y components
+  const int16x8_t mv0 = vreinterpretq_s16_s32(mvs[0]);
+  const int16x8_t mv1 = vreinterpretq_s16_s32(mvs[1]);
+  const int16x8x2_t mv_yx = vuzpq_s16(mv0, mv1);
+  // numerator could be 0.
+  projection_mv[0] = MvProjectionClip(mv_yx.val[0], denorm, numerator);
+  projection_mv[1] = MvProjectionClip(mv_yx.val[1], denorm, numerator);
+  // Do not update the motion vector if the block position is not valid or
+  // if position_x8 is outside the current range of x8_start and x8_end.
+  // Note that position_y8 will always be within the range of y8_start and
+  // y8_end.
+  // After subtracting the base, valid projections are within 8-bit.
+  *position_y8 = Project_NEON(projection_mv[0], d_sign);
+  const int8x8_t position_x = Project_NEON(projection_mv[1], d_sign);
+  const int8x8_t k01234567 = vcreate_s8(uint64_t{0x0706050403020100});
+  *position_x8 = vqadd_s8(position_x, k01234567);
+  const int8x16_t position_xy = vcombine_s8(*position_x8, *position_y8);
+  const int x8_floor = std::max(
+      x8_start - x8, delta - kProjectionMvMaxHorizontalOffset);  // [-8, 8]
+  const int x8_ceiling = std::min(
+      x8_end - x8, delta + 8 + kProjectionMvMaxHorizontalOffset);  // [0, 16]
+  const int8x8_t x8_floor8 = vdup_n_s8(x8_floor);
+  const int8x8_t x8_ceiling8 = vdup_n_s8(x8_ceiling);
+  const int8x16_t floor_xy = vcombine_s8(x8_floor8, y8_floor8);
+  const int8x16_t ceiling_xy = vcombine_s8(x8_ceiling8, y8_ceiling8);
+  const uint8x16_t underflow = vcltq_s8(position_xy, floor_xy);
+  const uint8x16_t overflow = vcgeq_s8(position_xy, ceiling_xy);
+  const int8x16_t out = vreinterpretq_s8_u8(vorrq_u8(underflow, overflow));
+  const int8x8_t skip_low = vorr_s8(skip_r, vget_low_s8(out));
+  const int8x8_t skip = vorr_s8(skip_low, vget_high_s8(out));
+  *skip_64 = vget_lane_s64(vreinterpret_s64_s8(skip), 0);
+}
+
+template <int idx>
+inline void Store(const int16x8_t position, const int8x8_t reference_offset,
+                  const int32x4_t mv, int8_t* dst_reference_offset,
+                  MotionVector* dst_mv) {
+  const ptrdiff_t offset = vgetq_lane_s16(position, idx);
+  auto* const d_mv = reinterpret_cast<int32_t*>(&dst_mv[offset]);
+  vst1q_lane_s32(d_mv, mv, idx & 3);
+  vst1_lane_s8(&dst_reference_offset[offset], reference_offset, idx);
+}
+
+template <int idx>
+inline void CheckStore(const int8_t* skips, const int16x8_t position,
+                       const int8x8_t reference_offset, const int32x4_t mv,
+                       int8_t* dst_reference_offset, MotionVector* dst_mv) {
+  if (skips[idx] == 0) {
+    Store<idx>(position, reference_offset, mv, dst_reference_offset, dst_mv);
+  }
+}
+
+// 7.9.2.
+void MotionFieldProjectionKernel_NEON(const ReferenceInfo& reference_info,
+                                      const int reference_to_current_with_sign,
+                                      const int dst_sign, const int y8_start,
+                                      const int y8_end, const int x8_start,
+                                      const int x8_end,
+                                      TemporalMotionField* const motion_field) {
+  const ptrdiff_t stride = motion_field->mv.columns();
+  // The column range has to be offset by kProjectionMvMaxHorizontalOffset since
+  // coordinates in that range could end up being position_x8 because of
+  // projection.
+  const int adjusted_x8_start =
+      std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0);
+  const int adjusted_x8_end = std::min(
+      x8_end + kProjectionMvMaxHorizontalOffset, static_cast<int>(stride));
+  const int adjusted_x8_end8 = adjusted_x8_end & ~7;
+  const int leftover = adjusted_x8_end - adjusted_x8_end8;
+  const int8_t* const reference_offsets =
+      reference_info.relative_distance_to.data();
+  const bool* const skip_references = reference_info.skip_references.data();
+  const int16_t* const projection_divisions =
+      reference_info.projection_divisions.data();
+  const ReferenceFrameType* source_reference_types =
+      &reference_info.motion_field_reference_frame[y8_start][0];
+  const MotionVector* mv = &reference_info.motion_field_mv[y8_start][0];
+  int8_t* dst_reference_offset = motion_field->reference_offset[y8_start];
+  MotionVector* dst_mv = motion_field->mv[y8_start];
+  const int16x8_t d_sign = vdupq_n_s16(dst_sign);
+
+  static_assert(sizeof(int8_t) == sizeof(bool), "");
+  static_assert(sizeof(int8_t) == sizeof(ReferenceFrameType), "");
+  static_assert(sizeof(int32_t) == sizeof(MotionVector), "");
+  assert(dst_sign == 0 || dst_sign == -1);
+  assert(stride == motion_field->reference_offset.columns());
+  assert((y8_start & 7) == 0);
+  assert((adjusted_x8_start & 7) == 0);
+  // The final position calculation is represented with int16_t. Valid
+  // position_y8 from its base is at most 7. After considering the horizontal
+  // offset which is at most |stride - 1|, we have the following assertion,
+  // which means this optimization works for frame width up to 32K (each
+  // position is a 8x8 block).
+  assert(8 * stride <= 32768);
+  const int8x8_t skip_reference =
+      vld1_s8(reinterpret_cast<const int8_t*>(skip_references));
+  const int8x8_t r_offsets = vld1_s8(reference_offsets);
+  const int8x16_t table = vreinterpretq_s8_s16(vld1q_s16(projection_divisions));
+  int8x8x2_t division_table;
+  division_table.val[0] = vget_low_s8(table);
+  division_table.val[1] = vget_high_s8(table);
+
+  int y8 = y8_start;
+  do {
+    const int y8_floor = (y8 & ~7) - y8;                         // [-7, 0]
+    const int y8_ceiling = std::min(y8_end - y8, y8_floor + 8);  // [1, 8]
+    const int8x8_t y8_floor8 = vdup_n_s8(y8_floor);
+    const int8x8_t y8_ceiling8 = vdup_n_s8(y8_ceiling);
+    int x8;
+
+    for (x8 = adjusted_x8_start; x8 < adjusted_x8_end8; x8 += 8) {
+      const int8x8_t source_reference_type8 =
+          vld1_s8(reinterpret_cast<const int8_t*>(source_reference_types + x8));
+      const int8x8_t skip_r = vtbl1_s8(skip_reference, source_reference_type8);
+      const int64_t early_skip = vget_lane_s64(vreinterpret_s64_s8(skip_r), 0);
+      // Early termination #1 if all are skips. Chance is typically ~30-40%.
+      if (early_skip == -1) continue;
+      int64_t skip_64;
+      int8x8_t r, position_x8, position_y8;
+      int32x4_t mvs[2];
+      GetPosition(division_table, mv, reference_to_current_with_sign, x8_start,
+                  x8_end, x8, r_offsets, source_reference_type8, skip_r,
+                  y8_floor8, y8_ceiling8, d_sign, 0, &r, &position_y8,
+                  &position_x8, &skip_64, mvs);
+      // Early termination #2 if all are skips.
+      // Chance is typically ~15-25% after Early termination #1.
+      if (skip_64 == -1) continue;
+      const int16x8_t p_y = vmovl_s8(position_y8);
+      const int16x8_t p_x = vmovl_s8(position_x8);
+      const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride);
+      const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8));
+      if (skip_64 == 0) {
+        // Store all. Chance is typically ~70-85% after Early termination #2.
+        Store<0>(position, r, mvs[0], dst_reference_offset, dst_mv);
+        Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv);
+        Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv);
+        Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv);
+        Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv);
+        Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv);
+        Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv);
+        Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv);
+      } else {
+        // Check and store each.
+        // Chance is typically ~15-30% after Early termination #2.
+        // The compiler is smart enough to not create the local buffer skips[].
+        int8_t skips[8];
+        memcpy(skips, &skip_64, sizeof(skips));
+        CheckStore<0>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+        CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+        CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+        CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+        CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+        CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+        CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+        CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+      }
+    }
+
+    // The following leftover processing cannot be moved out of the do...while
+    // loop. Doing so may change the result storing orders of the same position.
+    if (leftover > 0) {
+      // Use SIMD only when leftover is at least 4, and there are at least 8
+      // elements in a row.
+      if (leftover >= 4 && adjusted_x8_start < adjusted_x8_end8) {
+        // Process the last 8 elements to avoid loading invalid memory. Some
+        // elements may have been processed in the above loop, which is OK.
+        const int delta = 8 - leftover;
+        x8 = adjusted_x8_end - 8;
+        const int8x8_t source_reference_type8 = vld1_s8(
+            reinterpret_cast<const int8_t*>(source_reference_types + x8));
+        const int8x8_t skip_r =
+            vtbl1_s8(skip_reference, source_reference_type8);
+        const int64_t early_skip =
+            vget_lane_s64(vreinterpret_s64_s8(skip_r), 0);
+        // Early termination #1 if all are skips.
+        if (early_skip != -1) {
+          int64_t skip_64;
+          int8x8_t r, position_x8, position_y8;
+          int32x4_t mvs[2];
+          GetPosition(division_table, mv, reference_to_current_with_sign,
+                      x8_start, x8_end, x8, r_offsets, source_reference_type8,
+                      skip_r, y8_floor8, y8_ceiling8, d_sign, delta, &r,
+                      &position_y8, &position_x8, &skip_64, mvs);
+          // Early termination #2 if all are skips.
+          if (skip_64 != -1) {
+            const int16x8_t p_y = vmovl_s8(position_y8);
+            const int16x8_t p_x = vmovl_s8(position_x8);
+            const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride);
+            const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8));
+            // Store up to 7 elements since leftover is at most 7.
+            if (skip_64 == 0) {
+              // Store all.
+              Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv);
+              Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv);
+              Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv);
+              Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv);
+              Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv);
+              Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv);
+              Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv);
+            } else {
+              // Check and store each.
+              // The compiler is smart enough to not create the local buffer
+              // skips[].
+              int8_t skips[8];
+              memcpy(skips, &skip_64, sizeof(skips));
+              CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset,
+                            dst_mv);
+              CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset,
+                            dst_mv);
+              CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset,
+                            dst_mv);
+              CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset,
+                            dst_mv);
+              CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset,
+                            dst_mv);
+              CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset,
+                            dst_mv);
+              CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset,
+                            dst_mv);
+            }
+          }
+        }
+      } else {
+        for (; x8 < adjusted_x8_end; ++x8) {
+          const int source_reference_type = source_reference_types[x8];
+          if (skip_references[source_reference_type]) continue;
+          MotionVector projection_mv;
+          // reference_to_current_with_sign could be 0.
+          GetMvProjection(mv[x8], reference_to_current_with_sign,
+                          projection_divisions[source_reference_type],
+                          &projection_mv);
+          // Do not update the motion vector if the block position is not valid
+          // or if position_x8 is outside the current range of x8_start and
+          // x8_end. Note that position_y8 will always be within the range of
+          // y8_start and y8_end.
+          const int position_y8 = Project(0, projection_mv.mv[0], dst_sign);
+          if (position_y8 < y8_floor || position_y8 >= y8_ceiling) continue;
+          const int x8_base = x8 & ~7;
+          const int x8_floor =
+              std::max(x8_start, x8_base - kProjectionMvMaxHorizontalOffset);
+          const int x8_ceiling =
+              std::min(x8_end, x8_base + 8 + kProjectionMvMaxHorizontalOffset);
+          const int position_x8 = Project(x8, projection_mv.mv[1], dst_sign);
+          if (position_x8 < x8_floor || position_x8 >= x8_ceiling) continue;
+          dst_mv[position_y8 * stride + position_x8] = mv[x8];
+          dst_reference_offset[position_y8 * stride + position_x8] =
+              reference_offsets[source_reference_type];
+        }
+      }
+    }
+
+    source_reference_types += stride;
+    mv += stride;
+    dst_reference_offset += stride;
+    dst_mv += stride;
+  } while (++y8 < y8_end);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON;
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON;
+}
+#endif
+
+}  // namespace
+
+void MotionFieldProjectionInit_NEON() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  Init10bpp();
+#endif
+}
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void MotionFieldProjectionInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/motion_field_projection_neon.h b/libgav1/src/dsp/arm/motion_field_projection_neon.h
new file mode 100644
index 0000000..41ab6a6
--- /dev/null
+++ b/libgav1/src/dsp/arm/motion_field_projection_neon.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::motion_field_projection_kernel. This function is not
+// thread-safe.
+void MotionFieldProjectionInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+
+#define LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel LIBGAV1_CPU_NEON
+
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_MOTION_FIELD_PROJECTION_NEON_H_
diff --git a/libgav1/src/dsp/arm/motion_vector_search_neon.cc b/libgav1/src/dsp/arm/motion_vector_search_neon.cc
new file mode 100644
index 0000000..8a403a6
--- /dev/null
+++ b/libgav1/src/dsp/arm/motion_vector_search_neon.cc
@@ -0,0 +1,267 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/motion_vector_search.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+inline int16x4_t MvProjection(const int16x4_t mv, const int16x4_t denominator,
+                              const int32x4_t numerator) {
+  const int32x4_t m0 = vmull_s16(mv, denominator);
+  const int32x4_t m = vmulq_s32(m0, numerator);
+  // Add the sign (0 or -1) to round towards zero.
+  const int32x4_t add_sign = vsraq_n_s32(m, m, 31);
+  return vqrshrn_n_s32(add_sign, 14);
+}
+
+inline int16x4_t MvProjectionCompound(const int16x4_t mv,
+                                      const int temporal_reference_offsets,
+                                      const int reference_offsets[2]) {
+  const int16x4_t denominator =
+      vdup_n_s16(kProjectionMvDivisionLookup[temporal_reference_offsets]);
+  const int32x2_t offset = vld1_s32(reference_offsets);
+  const int32x2x2_t offsets = vzip_s32(offset, offset);
+  const int32x4_t numerator = vcombine_s32(offsets.val[0], offsets.val[1]);
+  return MvProjection(mv, denominator, numerator);
+}
+
+inline int16x8_t ProjectionClip(const int16x4_t mv0, const int16x4_t mv1) {
+  const int16x8_t projection_mv_clamp = vdupq_n_s16(kProjectionMvClamp);
+  const int16x8_t mv = vcombine_s16(mv0, mv1);
+  const int16x8_t clamp = vminq_s16(mv, projection_mv_clamp);
+  return vmaxq_s16(clamp, vnegq_s16(projection_mv_clamp));
+}
+
+inline int16x8_t MvProjectionCompoundClip(
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets,
+    const int reference_offsets[2]) {
+  const auto* const tmvs = reinterpret_cast<const int32_t*>(temporal_mvs);
+  const int32x2_t temporal_mv = vld1_s32(tmvs);
+  const int16x4_t tmv0 = vreinterpret_s16_s32(vdup_lane_s32(temporal_mv, 0));
+  const int16x4_t tmv1 = vreinterpret_s16_s32(vdup_lane_s32(temporal_mv, 1));
+  const int16x4_t mv0 = MvProjectionCompound(
+      tmv0, temporal_reference_offsets[0], reference_offsets);
+  const int16x4_t mv1 = MvProjectionCompound(
+      tmv1, temporal_reference_offsets[1], reference_offsets);
+  return ProjectionClip(mv0, mv1);
+}
+
+inline int16x8_t MvProjectionSingleClip(
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets, const int reference_offset,
+    int16x4_t* const lookup) {
+  const auto* const tmvs = reinterpret_cast<const int16_t*>(temporal_mvs);
+  const int16x8_t temporal_mv = vld1q_s16(tmvs);
+  *lookup = vld1_lane_s16(
+      &kProjectionMvDivisionLookup[temporal_reference_offsets[0]], *lookup, 0);
+  *lookup = vld1_lane_s16(
+      &kProjectionMvDivisionLookup[temporal_reference_offsets[1]], *lookup, 1);
+  *lookup = vld1_lane_s16(
+      &kProjectionMvDivisionLookup[temporal_reference_offsets[2]], *lookup, 2);
+  *lookup = vld1_lane_s16(
+      &kProjectionMvDivisionLookup[temporal_reference_offsets[3]], *lookup, 3);
+  const int16x4x2_t denominator = vzip_s16(*lookup, *lookup);
+  const int16x4_t tmv0 = vget_low_s16(temporal_mv);
+  const int16x4_t tmv1 = vget_high_s16(temporal_mv);
+  const int32x4_t numerator = vdupq_n_s32(reference_offset);
+  const int16x4_t mv0 = MvProjection(tmv0, denominator.val[0], numerator);
+  const int16x4_t mv1 = MvProjection(tmv1, denominator.val[1], numerator);
+  return ProjectionClip(mv0, mv1);
+}
+
+inline void LowPrecision(const int16x8_t mv, void* const candidate_mvs) {
+  const int16x8_t kRoundDownMask = vdupq_n_s16(1);
+  const uint16x8_t mvu = vreinterpretq_u16_s16(mv);
+  const int16x8_t mv0 = vreinterpretq_s16_u16(vsraq_n_u16(mvu, mvu, 15));
+  const int16x8_t mv1 = vbicq_s16(mv0, kRoundDownMask);
+  vst1q_s16(static_cast<int16_t*>(candidate_mvs), mv1);
+}
+
+inline void ForceInteger(const int16x8_t mv, void* const candidate_mvs) {
+  const int16x8_t kRoundDownMask = vdupq_n_s16(7);
+  const uint16x8_t mvu = vreinterpretq_u16_s16(mv);
+  const int16x8_t mv0 = vreinterpretq_s16_u16(vsraq_n_u16(mvu, mvu, 15));
+  const int16x8_t mv1 = vaddq_s16(mv0, vdupq_n_s16(3));
+  const int16x8_t mv2 = vbicq_s16(mv1, kRoundDownMask);
+  vst1q_s16(static_cast<int16_t*>(candidate_mvs), mv2);
+}
+
+void MvProjectionCompoundLowPrecision_NEON(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offsets[2], const int count,
+    CompoundMotionVector* candidate_mvs) {
+  // |reference_offsets| non-zero check usually equals true and is ignored.
+  // To facilitate the compilers, make a local copy of |reference_offsets|.
+  const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+  // One more element could be calculated.
+  int loop_count = (count + 1) >> 1;
+  do {
+    const int16x8_t mv = MvProjectionCompoundClip(
+        temporal_mvs, temporal_reference_offsets, offsets);
+    LowPrecision(mv, candidate_mvs);
+    temporal_mvs += 2;
+    temporal_reference_offsets += 2;
+    candidate_mvs += 2;
+  } while (--loop_count);
+}
+
+void MvProjectionCompoundForceInteger_NEON(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offsets[2], const int count,
+    CompoundMotionVector* candidate_mvs) {
+  // |reference_offsets| non-zero check usually equals true and is ignored.
+  // To facilitate the compilers, make a local copy of |reference_offsets|.
+  const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+  // One more element could be calculated.
+  int loop_count = (count + 1) >> 1;
+  do {
+    const int16x8_t mv = MvProjectionCompoundClip(
+        temporal_mvs, temporal_reference_offsets, offsets);
+    ForceInteger(mv, candidate_mvs);
+    temporal_mvs += 2;
+    temporal_reference_offsets += 2;
+    candidate_mvs += 2;
+  } while (--loop_count);
+}
+
+void MvProjectionCompoundHighPrecision_NEON(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offsets[2], const int count,
+    CompoundMotionVector* candidate_mvs) {
+  // |reference_offsets| non-zero check usually equals true and is ignored.
+  // To facilitate the compilers, make a local copy of |reference_offsets|.
+  const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+  // One more element could be calculated.
+  int loop_count = (count + 1) >> 1;
+  do {
+    const int16x8_t mv = MvProjectionCompoundClip(
+        temporal_mvs, temporal_reference_offsets, offsets);
+    vst1q_s16(reinterpret_cast<int16_t*>(candidate_mvs), mv);
+    temporal_mvs += 2;
+    temporal_reference_offsets += 2;
+    candidate_mvs += 2;
+  } while (--loop_count);
+}
+
+void MvProjectionSingleLowPrecision_NEON(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+  // Up to three more elements could be calculated.
+  int loop_count = (count + 3) >> 2;
+  int16x4_t lookup = vdup_n_s16(0);
+  do {
+    const int16x8_t mv = MvProjectionSingleClip(
+        temporal_mvs, temporal_reference_offsets, reference_offset, &lookup);
+    LowPrecision(mv, candidate_mvs);
+    temporal_mvs += 4;
+    temporal_reference_offsets += 4;
+    candidate_mvs += 4;
+  } while (--loop_count);
+}
+
+void MvProjectionSingleForceInteger_NEON(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+  // Up to three more elements could be calculated.
+  int loop_count = (count + 3) >> 2;
+  int16x4_t lookup = vdup_n_s16(0);
+  do {
+    const int16x8_t mv = MvProjectionSingleClip(
+        temporal_mvs, temporal_reference_offsets, reference_offset, &lookup);
+    ForceInteger(mv, candidate_mvs);
+    temporal_mvs += 4;
+    temporal_reference_offsets += 4;
+    candidate_mvs += 4;
+  } while (--loop_count);
+}
+
+void MvProjectionSingleHighPrecision_NEON(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+  // Up to three more elements could be calculated.
+  int loop_count = (count + 3) >> 2;
+  int16x4_t lookup = vdup_n_s16(0);
+  do {
+    const int16x8_t mv = MvProjectionSingleClip(
+        temporal_mvs, temporal_reference_offsets, reference_offset, &lookup);
+    vst1q_s16(reinterpret_cast<int16_t*>(candidate_mvs), mv);
+    temporal_mvs += 4;
+    temporal_reference_offsets += 4;
+    candidate_mvs += 4;
+  } while (--loop_count);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON;
+  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_NEON;
+  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_NEON;
+  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_NEON;
+  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_NEON;
+  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON;
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON;
+  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_NEON;
+  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_NEON;
+  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_NEON;
+  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_NEON;
+  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON;
+}
+#endif
+
+}  // namespace
+
+void MotionVectorSearchInit_NEON() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  Init10bpp();
+#endif
+}
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_NEON
+namespace libgav1 {
+namespace dsp {
+
+void MotionVectorSearchInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/motion_vector_search_neon.h b/libgav1/src/dsp/arm/motion_vector_search_neon.h
new file mode 100644
index 0000000..19b4519
--- /dev/null
+++ b/libgav1/src/dsp/arm/motion_vector_search_neon.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::mv_projection_compound and Dsp::mv_projection_single. This
+// function is not thread-safe.
+void MotionVectorSearchInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+
+#define LIBGAV1_Dsp8bpp_MotionVectorSearch LIBGAV1_CPU_NEON
+
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_MOTION_VECTOR_SEARCH_NEON_H_
diff --git a/libgav1/src/dsp/arm/obmc_neon.cc b/libgav1/src/dsp/arm/obmc_neon.cc
index 0bd994d..66ad663 100644
--- a/libgav1/src/dsp/arm/obmc_neon.cc
+++ b/libgav1/src/dsp/arm/obmc_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/obmc.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -26,6 +26,8 @@
 #include <cstring>
 
 #include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"
 
 namespace libgav1 {
@@ -34,24 +36,11 @@
 
 #include "src/dsp/obmc.inc"
 
-inline uint8x8_t Load2(const uint8_t* src) {
-  uint16_t tmp;
-  memcpy(&tmp, src, 2);
-  uint16x4_t result = vcreate_u16(tmp);
-  return vreinterpret_u8_u16(result);
-}
-
-template <int lane>
-inline void StoreLane2(uint8_t* dst, uint8x8_t src) {
-  const uint16_t out_val = vget_lane_u16(vreinterpret_u16_u8(src), lane);
-  memcpy(dst, &out_val, 2);
-}
-
 inline void WriteObmcLine4(uint8_t* const pred, const uint8_t* const obmc_pred,
                            const uint8x8_t pred_mask,
                            const uint8x8_t obmc_pred_mask) {
-  const uint8x8_t pred_val = LoadLo4(pred, vdup_n_u8(0));
-  const uint8x8_t obmc_pred_val = LoadLo4(obmc_pred, vdup_n_u8(0));
+  const uint8x8_t pred_val = Load4(pred);
+  const uint8x8_t obmc_pred_val = Load4(obmc_pred);
   const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
   const uint8x8_t result =
       vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
@@ -79,18 +68,20 @@
     // Weights for the last line are all 64, which is a no-op.
     compute_height = height - 1;
   }
+  uint8x8_t pred_val = vdup_n_u8(0);
+  uint8x8_t obmc_pred_val = vdup_n_u8(0);
   int y = 0;
   do {
     if (!from_left) {
       pred_mask = vdup_n_u8(kObmcMask[mask_offset + y]);
       obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
     }
-    const uint8x8_t pred_val = Load2(pred);
+    pred_val = Load2<0>(pred, pred_val);
     const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
-    const uint8x8_t obmc_pred_val = Load2(obmc_pred);
+    obmc_pred_val = Load2<0>(obmc_pred, obmc_pred_val);
     const uint8x8_t result =
         vrshrn_n_u16(vmlal_u8(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
-    StoreLane2<0>(pred, result);
+    Store2<0>(pred, result);
 
     pred += prediction_stride;
     obmc_pred += obmc_prediction_stride;
@@ -105,7 +96,7 @@
   const uint8_t* obmc_pred = obmc_prediction;
 
   const uint8x8_t mask_inverter = vdup_n_u8(64);
-  const uint8x8_t pred_mask = LoadLo4(kObmcMask + 2, vdup_n_u8(0));
+  const uint8x8_t pred_mask = Load4(kObmcMask + 2);
   // 64 - mask
   const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
   int y = 0;
@@ -376,7 +367,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_NEON;
   dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_NEON;
@@ -389,7 +380,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 
 namespace libgav1 {
 namespace dsp {
diff --git a/libgav1/src/dsp/arm/obmc_neon.h b/libgav1/src/dsp/arm/obmc_neon.h
index 3155cf0..d5c9d9c 100644
--- a/libgav1/src/dsp/arm/obmc_neon.h
+++ b/libgav1/src/dsp/arm/obmc_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -31,8 +31,8 @@
 
 // If NEON is enabled, signal the NEON implementation should be used.
 #if LIBGAV1_ENABLE_NEON
-#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
diff --git a/libgav1/src/dsp/arm/super_res_neon.cc b/libgav1/src/dsp/arm/super_res_neon.cc
new file mode 100644
index 0000000..d77b9c7
--- /dev/null
+++ b/libgav1/src/dsp/arm/super_res_neon.cc
@@ -0,0 +1,92 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/super_res.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+
+namespace low_bitdepth {
+namespace {
+
+void ComputeSuperRes_NEON(const void* source, const int upscaled_width,
+                          const int initial_subpixel_x, const int step,
+                          void* const dest) {
+  const auto* src = static_cast<const uint8_t*>(source);
+  auto* dst = static_cast<uint8_t*>(dest);
+  src -= kSuperResFilterTaps >> 1;
+
+  int p = initial_subpixel_x;
+  uint16x8_t weighted_src[8];
+  for (int x = 0; x < upscaled_width; x += 8) {
+    for (int i = 0; i < kSuperResFilterTaps; ++i, p += step) {
+      const uint8x8_t src_x = vld1_u8(&src[p >> kSuperResScaleBits]);
+      const int remainder = p & kSuperResScaleMask;
+      const uint8x8_t filter =
+          vld1_u8(kUpscaleFilterUnsigned[remainder >> kSuperResExtraBits]);
+      weighted_src[i] = vmull_u8(src_x, filter);
+    }
+    Transpose8x8(weighted_src);
+
+    // Maximum sum of positive taps: 171 = 7 + 86 + 71 + 7
+    // Maximum sum: 255*171 == 0xAA55
+    // The sum is clipped to [0, 255], so adding all positive and then
+    // subtracting all negative with saturation is sufficient.
+    //           0 1 2 3 4 5 6 7
+    // tap sign: - + - + + - + -
+    uint16x8_t res = weighted_src[1];
+    res = vaddq_u16(res, weighted_src[3]);
+    res = vaddq_u16(res, weighted_src[4]);
+    res = vaddq_u16(res, weighted_src[6]);
+    res = vqsubq_u16(res, weighted_src[0]);
+    res = vqsubq_u16(res, weighted_src[2]);
+    res = vqsubq_u16(res, weighted_src[5]);
+    res = vqsubq_u16(res, weighted_src[7]);
+    vst1_u8(&dst[x], vqrshrn_n_u16(res, kFilterBits));
+  }
+}
+
+void Init8bpp() {
+  Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  dsp->super_res_row = ComputeSuperRes_NEON;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void SuperResInit_NEON() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void SuperResInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/super_res_neon.h b/libgav1/src/dsp/arm/super_res_neon.h
new file mode 100644
index 0000000..f51785d
--- /dev/null
+++ b/libgav1/src/dsp/arm/super_res_neon.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::super_res. This function is not thread-safe.
+void SuperResInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_SuperRes LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_SuperResClip LIBGAV1_CPU_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_SUPER_RES_NEON_H_
diff --git a/libgav1/src/dsp/arm/warp_neon.cc b/libgav1/src/dsp/arm/warp_neon.cc
index b25f183..7a41998 100644
--- a/libgav1/src/dsp/arm/warp_neon.cc
+++ b/libgav1/src/dsp/arm/warp_neon.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/warp.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_NEON
 
@@ -24,9 +24,11 @@
 #include <cstddef>
 #include <cstdint>
 #include <cstdlib>
+#include <type_traits>
 
 #include "src/dsp/arm/common_neon.h"
 #include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
 
@@ -37,32 +39,101 @@
 
 // Number of extra bits of precision in warped filtering.
 constexpr int kWarpedDiffPrecisionBits = 10;
+constexpr int kFirstPassOffset = 1 << 14;
+constexpr int kOffsetRemoval =
+    (kFirstPassOffset >> kInterRoundBitsHorizontal) * 128;
 
+// Applies the horizontal filter to one source row and stores the result in
+// |intermediate_result_row|. |intermediate_result_row| is a row in the 15x8
+// |intermediate_result| two-dimensional array.
+//
+// src_row_centered contains 16 "centered" samples of a source row. (We center
+// the samples by subtracting 128 from the samples.)
+void HorizontalFilter(const int sx4, const int16_t alpha,
+                      const int8x16_t src_row_centered,
+                      int16_t intermediate_result_row[8]) {
+  int sx = sx4 - MultiplyBy4(alpha);
+  int8x8_t filter[8];
+  for (int x = 0; x < 8; ++x) {
+    const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
+                       kWarpedPixelPrecisionShifts;
+    filter[x] = vld1_s8(kWarpedFilters8[offset]);
+    sx += alpha;
+  }
+  Transpose8x8(filter);
+  // Add kFirstPassOffset to ensure |sum| stays within uint16_t.
+  // Add 128 (offset) * 128 (filter sum) (also 1 << 14) to account for the
+  // centering of the source samples. These combined are 1 << 15 or -32768.
+  int16x8_t sum =
+      vdupq_n_s16(static_cast<int16_t>(kFirstPassOffset + 128 * 128));
+  // Unrolled k = 0..7 loop. We need to manually unroll the loop because the
+  // third argument (an index value) to vextq_s8() must be a constant
+  // (immediate). src_row_window is a sliding window of length 8 into
+  // src_row_centered.
+  // k = 0.
+  int8x8_t src_row_window = vget_low_s8(src_row_centered);
+  sum = vmlal_s8(sum, filter[0], src_row_window);
+  // k = 1.
+  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 1));
+  sum = vmlal_s8(sum, filter[1], src_row_window);
+  // k = 2.
+  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 2));
+  sum = vmlal_s8(sum, filter[2], src_row_window);
+  // k = 3.
+  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 3));
+  sum = vmlal_s8(sum, filter[3], src_row_window);
+  // k = 4.
+  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 4));
+  sum = vmlal_s8(sum, filter[4], src_row_window);
+  // k = 5.
+  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 5));
+  sum = vmlal_s8(sum, filter[5], src_row_window);
+  // k = 6.
+  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 6));
+  sum = vmlal_s8(sum, filter[6], src_row_window);
+  // k = 7.
+  src_row_window = vget_low_s8(vextq_s8(src_row_centered, src_row_centered, 7));
+  sum = vmlal_s8(sum, filter[7], src_row_window);
+  // End of unrolled k = 0..7 loop.
+  // Due to the offset |sum| is guaranteed to be unsigned.
+  uint16x8_t sum_unsigned = vreinterpretq_u16_s16(sum);
+  sum_unsigned = vrshrq_n_u16(sum_unsigned, kInterRoundBitsHorizontal);
+  // After the shift |sum_unsigned| will fit into int16_t.
+  vst1q_s16(intermediate_result_row, vreinterpretq_s16_u16(sum_unsigned));
+}
+
+template <bool is_compound>
 void Warp_NEON(const void* const source, const ptrdiff_t source_stride,
                const int source_width, const int source_height,
                const int* const warp_params, const int subsampling_x,
-               const int subsampling_y, const int inter_round_bits_vertical,
-               const int block_start_x, const int block_start_y,
-               const int block_width, const int block_height,
-               const int16_t alpha, const int16_t beta, const int16_t gamma,
-               const int16_t delta, uint16_t* dest,
+               const int subsampling_y, const int block_start_x,
+               const int block_start_y, const int block_width,
+               const int block_height, const int16_t alpha, const int16_t beta,
+               const int16_t gamma, const int16_t delta, void* dest,
                const ptrdiff_t dest_stride) {
-  constexpr int bitdepth = 8;
-  // Intermediate_result is the output of the horizontal filtering and rounding.
-  // The range is within 13 (= bitdepth + kFilterBits + 1 -
-  // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t type
-  // so that we can multiply it by kWarpedFilters (which has signed values)
-  // using vmlal_s16().
-  int16_t intermediate_result[15][8];  // 15 rows, 8 columns.
-  const int horizontal_offset = 1 << (bitdepth + kFilterBits - 1);
-  const int vertical_offset =
-      1 << (bitdepth + 2 * kFilterBits - kInterRoundBitsHorizontal);
+  constexpr int kRoundBitsVertical =
+      is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
+  union {
+    // Intermediate_result is the output of the horizontal filtering and
+    // rounding. The range is within 13 (= bitdepth + kFilterBits + 1 -
+    // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t
+    // type so that we can multiply it by kWarpedFilters (which has signed
+    // values) using vmlal_s16().
+    int16_t intermediate_result[15][8];  // 15 rows, 8 columns.
+    // In the simple special cases where the samples in each row are all the
+    // same, store one sample per row in a column vector.
+    int16_t intermediate_result_column[15];
+  };
+
   const auto* const src = static_cast<const uint8_t*>(source);
+  using DestType =
+      typename std::conditional<is_compound, int16_t, uint8_t>::type;
+  auto* dst = static_cast<DestType*>(dest);
 
   assert(block_width >= 8);
   assert(block_height >= 8);
 
-  // Warp process applies for each 8x8 block (or smaller).
+  // Warp process applies for each 8x8 block.
   int start_y = block_start_y;
   do {
     int start_x = block_start_x;
@@ -77,42 +148,197 @@
       const int y4 = dst_y >> subsampling_y;
       const int ix4 = x4 >> kWarpedModelPrecisionBits;
       const int iy4 = y4 >> kWarpedModelPrecisionBits;
+      // A prediction block may fall outside the frame's boundaries. If a
+      // prediction block is calculated using only samples outside the frame's
+      // boundary, the filtering can be simplified. We can divide the plane
+      // into several regions and handle them differently.
+      //
+      //                |           |
+      //            1   |     3     |   1
+      //                |           |
+      //         -------+-----------+-------
+      //                |***********|
+      //            2   |*****4*****|   2
+      //                |***********|
+      //         -------+-----------+-------
+      //                |           |
+      //            1   |     3     |   1
+      //                |           |
+      //
+      // At the center, region 4 represents the frame and is the general case.
+      //
+      // In regions 1 and 2, the prediction block is outside the frame's
+      // boundary horizontally. Therefore the horizontal filtering can be
+      // simplified. Furthermore, in the region 1 (at the four corners), the
+      // prediction is outside the frame's boundary both horizontally and
+      // vertically, so we get a constant prediction block.
+      //
+      // In region 3, the prediction block is outside the frame's boundary
+      // vertically. Unfortunately because we apply the horizontal filters
+      // first, by the time we apply the vertical filters, they no longer see
+      // simple inputs. So the only simplification is that all the rows are
+      // the same, but we still need to apply all the horizontal and vertical
+      // filters.
 
-      // Horizontal filter.
-      int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
-      for (int y = -7; y < 8; ++y) {
-        // TODO(chengchen):
-        // Because of warping, the index could be out of frame boundary. Thus
-        // clip is needed. However, can we remove or reduce usage of clip?
-        // Besides, special cases exist, for example,
-        // if iy4 - 7 >= source_height or iy4 + 7 < 0, there's no need to do the
-        // filtering.
-        const int row = Clip3(iy4 + y, 0, source_height - 1);
+      // Check for two simple special cases, where the horizontal filter can
+      // be significantly simplified.
+      //
+      // In general, for each row, the horizontal filter is calculated as
+      // follows:
+      //   for (int x = -4; x < 4; ++x) {
+      //     const int offset = ...;
+      //     int sum = first_pass_offset;
+      //     for (int k = 0; k < 8; ++k) {
+      //       const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
+      //       sum += kWarpedFilters[offset][k] * src_row[column];
+      //     }
+      //     ...
+      //   }
+      // The column index before clipping, ix4 + x + k - 3, varies in the range
+      // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1
+      // or ix4 + 7 <= 0, then all the column indexes are clipped to the same
+      // border index (source_width - 1 or 0, respectively). Then for each x,
+      // the inner for loop of the horizontal filter is reduced to multiplying
+      // the border pixel by the sum of the filter coefficients.
+      if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) {
+        // Regions 1 and 2.
+        // Points to the left or right border of the first row of |src|.
+        const uint8_t* first_row_border =
+            (ix4 + 7 <= 0) ? src : src + source_width - 1;
+        // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+        //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+        // In two special cases, iy4 + y is clipped to either 0 or
+        // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+        // bounded and we can avoid clipping iy4 + y by relying on a reference
+        // frame's boundary extension on the top and bottom.
+        if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+          // Region 1.
+          // Every sample used to calculate the prediction block has the same
+          // value. So the whole prediction block has the same value.
+          const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+          const uint8_t row_border_pixel =
+              first_row_border[row * source_stride];
+
+          DestType* dst_row = dst + start_x - block_start_x;
+          for (int y = 0; y < 8; ++y) {
+            if (is_compound) {
+              const int16x8_t sum =
+                  vdupq_n_s16(row_border_pixel << (kInterRoundBitsVertical -
+                                                   kRoundBitsVertical));
+              vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum);
+            } else {
+              memset(dst_row, row_border_pixel, 8);
+            }
+            dst_row += dest_stride;
+          }
+          // End of region 1. Continue the |start_x| do-while loop.
+          start_x += 8;
+          continue;
+        }
+
+        // Region 2.
+        // Horizontal filter.
+        // The input values in this region are generated by extending the border
+        // which makes them identical in the horizontal direction. This
+        // computation could be inlined in the vertical pass but most
+        // implementations will need a transpose of some sort.
+        // It is not necessary to use the offset values here because the
+        // horizontal pass is a simple shift and the vertical pass will always
+        // require using 32 bits.
+        for (int y = -7; y < 8; ++y) {
+          // We may over-read up to 13 pixels above the top source row, or up
+          // to 13 pixels below the bottom source row. This is proved in
+          // warp.cc.
+          const int row = iy4 + y;
+          int sum = first_row_border[row * source_stride];
+          sum <<= (kFilterBits - kInterRoundBitsHorizontal);
+          intermediate_result_column[y + 7] = sum;
+        }
+        // Vertical filter.
+        DestType* dst_row = dst + start_x - block_start_x;
+        int sy4 =
+            (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+        for (int y = 0; y < 8; ++y) {
+          int sy = sy4 - MultiplyBy4(gamma);
+#if defined(__aarch64__)
+          const int16x8_t intermediate =
+              vld1q_s16(&intermediate_result_column[y]);
+          int16_t tmp[8];
+          for (int x = 0; x < 8; ++x) {
+            const int offset =
+                RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+                kWarpedPixelPrecisionShifts;
+            const int16x8_t filter = vld1q_s16(kWarpedFilters[offset]);
+            const int32x4_t product_low =
+                vmull_s16(vget_low_s16(filter), vget_low_s16(intermediate));
+            const int32x4_t product_high =
+                vmull_s16(vget_high_s16(filter), vget_high_s16(intermediate));
+            // vaddvq_s32 is only available on __aarch64__.
+            const int32_t sum =
+                vaddvq_s32(product_low) + vaddvq_s32(product_high);
+            const int16_t sum_descale =
+                RightShiftWithRounding(sum, kRoundBitsVertical);
+            if (is_compound) {
+              dst_row[x] = sum_descale;
+            } else {
+              tmp[x] = sum_descale;
+            }
+            sy += gamma;
+          }
+          if (!is_compound) {
+            const int16x8_t sum = vld1q_s16(tmp);
+            vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum));
+          }
+#else  // !defined(__aarch64__)
+          int16x8_t filter[8];
+          for (int x = 0; x < 8; ++x) {
+            const int offset =
+                RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+                kWarpedPixelPrecisionShifts;
+            filter[x] = vld1q_s16(kWarpedFilters[offset]);
+            sy += gamma;
+          }
+          Transpose8x8(filter);
+          int32x4_t sum_low = vdupq_n_s32(0);
+          int32x4_t sum_high = sum_low;
+          for (int k = 0; k < 8; ++k) {
+            const int16_t intermediate = intermediate_result_column[y + k];
+            sum_low =
+                vmlal_n_s16(sum_low, vget_low_s16(filter[k]), intermediate);
+            sum_high =
+                vmlal_n_s16(sum_high, vget_high_s16(filter[k]), intermediate);
+          }
+          const int16x8_t sum =
+              vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
+                           vrshrn_n_s32(sum_high, kRoundBitsVertical));
+          if (is_compound) {
+            vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum);
+          } else {
+            vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum));
+          }
+#endif  // defined(__aarch64__)
+          dst_row += dest_stride;
+          sy4 += delta;
+        }
+        // End of region 2. Continue the |start_x| do-while loop.
+        start_x += 8;
+        continue;
+      }
+
+      // Regions 3 and 4.
+      // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
+
+      // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+      //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+      // In two special cases, iy4 + y is clipped to either 0 or
+      // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+      // bounded and we can avoid clipping iy4 + y by relying on a reference
+      // frame's boundary extension on the top and bottom.
+      if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+        // Region 3.
+        // Horizontal filter.
+        const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
         const uint8_t* const src_row = src + row * source_stride;
-        // Check for two simple special cases.
-        if (ix4 - 7 >= source_width - 1) {
-          // Every sample is equal to src_row[source_width - 1]. Since the sum
-          // of the warped filter coefficients is 128 (= 2^7), the filtering is
-          // equivalent to multiplying src_row[source_width - 1] by 128.
-          const int16_t s =
-              (horizontal_offset >> kInterRoundBitsHorizontal) +
-              (src_row[source_width - 1] << (7 - kInterRoundBitsHorizontal));
-          const int16x8_t sum = vdupq_n_s16(s);
-          vst1q_s16(intermediate_result[y + 7], sum);
-          sx4 += beta;
-          continue;
-        }
-        if (ix4 + 7 <= 0) {
-          // Every sample is equal to src_row[0]. Since the sum of the warped
-          // filter coefficients is 128 (= 2^7), the filtering is equivalent to
-          // multiplying src_row[0] by 128.
-          const int16_t s = (horizontal_offset >> kInterRoundBitsHorizontal) +
-                            (src_row[0] << (7 - kInterRoundBitsHorizontal));
-          const int16x8_t sum = vdupq_n_s16(s);
-          vst1q_s16(intermediate_result[y + 7], sum);
-          sx4 += beta;
-          continue;
-        }
         // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
         // read but is ignored.
         //
@@ -121,86 +347,50 @@
         // has left and right borders of at least 13 bytes that extend the
         // frame boundary pixels. We also assume there is at least one extra
         // padding byte after the right border of the last source row.
-        const uint8x16_t src_row_u8 = vld1q_u8(&src_row[ix4 - 7]);
-        const int16x8_t src_row_low_s16 =
-            vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(src_row_u8)));
-        const int16x8_t src_row_high_s16 =
-            vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(src_row_u8)));
-        int sx = sx4 - MultiplyBy4(alpha);
-        int16x8_t filter[8];
-        for (int x = 0; x < 8; ++x) {
-          const int offset =
-              RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
-              kWarpedPixelPrecisionShifts;
-          filter[x] = vld1q_s16(kWarpedFilters[offset]);
-          sx += alpha;
+        const uint8x16_t src_row_v = vld1q_u8(&src_row[ix4 - 7]);
+        // Convert src_row_v to int8 (subtract 128).
+        const int8x16_t src_row_centered =
+            vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128)));
+        int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+        for (int y = -7; y < 8; ++y) {
+          HorizontalFilter(sx4, alpha, src_row_centered,
+                           intermediate_result[y + 7]);
+          sx4 += beta;
         }
-        Transpose8x8(&filter[0], &filter[1], &filter[2], &filter[3], &filter[4],
-                     &filter[5], &filter[6], &filter[7]);
-        // For 8 bit, the range of sum is within uint16_t, if we add a
-        // horizontal offset. horizontal_offset guarantees sum is nonnegative.
-        //
-        // Proof:
-        // Given that the minimum (most negative) sum of the negative filter
-        // coefficients is -47 and the maximum sum of the positive filter
-        // coefficients is 175, the range of the horizontal filter output is
-        //   -47 * 255 <= output <= 175 * 255
-        // Since -2^14 < -47 * 255, adding -2^14 (= horizontal_offset) to the
-        // horizontal filter output produces a positive value:
-        //   0 < output + 2^14 <= 175 * 255 + 2^14
-        // The final rounding right shift by 3 (= kInterRoundBitsHorizontal)
-        // bits adds 2^2 to the sum:
-        //   0 < output + 2^14 + 2^2 <= 175 * 255 + 2^14 + 2^2 = 61013
-        // Since 61013 < 2^16, the final sum (right before the right shift by 3
-        // bits) will not overflow uint16_t. In addition, the value after the
-        // right shift by 3 bits is in the following range:
-        //   0 <= intermediate_result[y][x] < 2^13
-        // This property is used in determining the range of the vertical
-        // filtering output. [End of proof.]
-        //
-        // We can do signed int16_t arithmetic and just treat the final result
-        // as uint16_t when we shift it right.
-        int16x8_t sum = vdupq_n_s16(horizontal_offset);
-        // Unrolled k = 0..7 loop. We need to manually unroll the loop because
-        // the third argument (an index value) to vextq_s16() must be a
-        // constant (immediate).
-        // k = 0.
-        int16x8_t src_row_v_s16 = src_row_low_s16;
-        sum = vmlaq_s16(sum, filter[0], src_row_v_s16);
-        // k = 1.
-        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 1);
-        sum = vmlaq_s16(sum, filter[1], src_row_v_s16);
-        // k = 2.
-        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 2);
-        sum = vmlaq_s16(sum, filter[2], src_row_v_s16);
-        // k = 3.
-        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 3);
-        sum = vmlaq_s16(sum, filter[3], src_row_v_s16);
-        // k = 4.
-        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 4);
-        sum = vmlaq_s16(sum, filter[4], src_row_v_s16);
-        // k = 5.
-        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 5);
-        sum = vmlaq_s16(sum, filter[5], src_row_v_s16);
-        // k = 6.
-        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 6);
-        sum = vmlaq_s16(sum, filter[6], src_row_v_s16);
-        // k = 7.
-        src_row_v_s16 = vextq_s16(src_row_low_s16, src_row_high_s16, 7);
-        sum = vmlaq_s16(sum, filter[7], src_row_v_s16);
-        // End of unrolled k = 0..7 loop.
-        // Treat sum as unsigned for the right shift.
-        sum = vreinterpretq_s16_u16(vrshrq_n_u16(vreinterpretq_u16_s16(sum),
-                                                 kInterRoundBitsHorizontal));
-        vst1q_s16(intermediate_result[y + 7], sum);
-        sx4 += beta;
+      } else {
+        // Region 4.
+        // Horizontal filter.
+        int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+        for (int y = -7; y < 8; ++y) {
+          // We may over-read up to 13 pixels above the top source row, or up
+          // to 13 pixels below the bottom source row. This is proved in
+          // warp.cc.
+          const int row = iy4 + y;
+          const uint8_t* const src_row = src + row * source_stride;
+          // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+          // read but is ignored.
+          //
+          // NOTE: This may read up to 13 bytes before src_row[0] or up to 14
+          // bytes after src_row[source_width - 1]. We assume the source frame
+          // has left and right borders of at least 13 bytes that extend the
+          // frame boundary pixels. We also assume there is at least one extra
+          // padding byte after the right border of the last source row.
+          const uint8x16_t src_row_v = vld1q_u8(&src_row[ix4 - 7]);
+          // Convert src_row_v to int8 (subtract 128).
+          const int8x16_t src_row_centered =
+              vreinterpretq_s8_u8(vsubq_u8(src_row_v, vdupq_n_u8(128)));
+          HorizontalFilter(sx4, alpha, src_row_centered,
+                           intermediate_result[y + 7]);
+          sx4 += beta;
+        }
       }
 
+      // Regions 3 and 4.
       // Vertical filter.
-      uint16_t* dst_row = dest + start_x - block_start_x;
+      DestType* dst_row = dst + start_x - block_start_x;
       int sy4 =
           (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
-      for (int y = -4; y < 4; ++y) {
+      for (int y = 0; y < 8; ++y) {
         int sy = sy4 - MultiplyBy4(gamma);
         int16x8_t filter[8];
         for (int x = 0; x < 8; ++x) {
@@ -210,88 +400,39 @@
           filter[x] = vld1q_s16(kWarpedFilters[offset]);
           sy += gamma;
         }
-        Transpose8x8(&filter[0], &filter[1], &filter[2], &filter[3], &filter[4],
-                     &filter[5], &filter[6], &filter[7]);
-        // Similar to horizontal_offset, vertical_offset guarantees sum before
-        // shifting is nonnegative.
-        //
-        // Proof:
-        // The range of an entry in intermediate_result is
-        //   0 <= intermediate_result[y][x] < 2^13
-        // The range of the vertical filter output is
-        //   -47 * 2^13 < output < 175 * 2^13
-        // Since -2^19 < -47 * 2^13, adding -2^19 (= vertical_offset) to the
-        // vertical filter output produces a positive value:
-        //   0 < output + 2^19 < 175 * 2^13 + 2^19
-        // The final rounding right shift by either 7 or 11 bits adds at most
-        // 2^10 to the sum:
-        //   0 < output + 2^19 + rounding < 175 * 2^13 + 2^19 + 2^10 = 1958912
-        // Since 1958912 = 0x1DE400 < 2^22, shifting it right by 7 or 11 bits
-        // brings the value under 2^15, which fits in uint16_t.
-        int32x4_t sum_low = vdupq_n_s32(vertical_offset);
+        Transpose8x8(filter);
+        int32x4_t sum_low = vdupq_n_s32(-kOffsetRemoval);
         int32x4_t sum_high = sum_low;
         for (int k = 0; k < 8; ++k) {
-          const int16x8_t intermediate =
-              vld1q_s16(intermediate_result[y + 4 + k]);
+          const int16x8_t intermediate = vld1q_s16(intermediate_result[y + k]);
           sum_low = vmlal_s16(sum_low, vget_low_s16(filter[k]),
                               vget_low_s16(intermediate));
           sum_high = vmlal_s16(sum_high, vget_high_s16(filter[k]),
                                vget_high_s16(intermediate));
         }
-        assert(inter_round_bits_vertical == 7 ||
-               inter_round_bits_vertical == 11);
-        // Since inter_round_bits_vertical can be 7 or 11, and all the narrowing
-        // shift intrinsics require the shift argument to be a constant
-        // (immediate), we have two options:
-        // 1. Call a non-narrowing shift, followed by a narrowing extract.
-        // 2. Call a narrowing shift (with a constant shift of 7 or 11) in an
-        //    if-else statement.
-#if defined(__aarch64__)
-        // This version is slightly faster for arm64 (1106 ms vs 1112 ms).
-        // This version is slower for 32-bit arm (1235 ms vs 1149 ms).
-        const int32x4_t shift = vdupq_n_s32(-inter_round_bits_vertical);
-        const uint32x4_t sum_low_shifted =
-            vrshlq_u32(vreinterpretq_u32_s32(sum_low), shift);
-        const uint32x4_t sum_high_shifted =
-            vrshlq_u32(vreinterpretq_u32_s32(sum_high), shift);
-        const uint16x4_t sum_low_16 = vmovn_u32(sum_low_shifted);
-        const uint16x4_t sum_high_16 = vmovn_u32(sum_high_shifted);
-#else   // !defined(__aarch64__)
-        // This version is faster for 32-bit arm.
-        // This version is slightly slower for arm64.
-        uint16x4_t sum_low_16;
-        uint16x4_t sum_high_16;
-        if (inter_round_bits_vertical == 7) {
-          sum_low_16 = vrshrn_n_u32(vreinterpretq_u32_s32(sum_low), 7);
-          sum_high_16 = vrshrn_n_u32(vreinterpretq_u32_s32(sum_high), 7);
+        const int16x8_t sum =
+            vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
+                         vrshrn_n_s32(sum_high, kRoundBitsVertical));
+        if (is_compound) {
+          vst1q_s16(reinterpret_cast<int16_t*>(dst_row), sum);
         } else {
-          sum_low_16 = vrshrn_n_u32(vreinterpretq_u32_s32(sum_low), 11);
-          sum_high_16 = vrshrn_n_u32(vreinterpretq_u32_s32(sum_high), 11);
+          vst1_u8(reinterpret_cast<uint8_t*>(dst_row), vqmovun_s16(sum));
         }
-#endif  // defined(__aarch64__)
-        // vst1q_u16 can also be used:
-        //   vst1q_u16(dst_row, vcombine_u16(sum_low_16, sum_high_16));
-        // But it is slightly slower for arm64 (the same speed for 32-bit arm).
-        //
-        // vst1_u16_x2 could be used, but it is also slightly slower for arm64
-        // and causes a bus error for 32-bit arm. Also, it is not supported by
-        // gcc 7.2.0.
-        vst1_u16(dst_row, sum_low_16);
-        vst1_u16(dst_row + 4, sum_high_16);
         dst_row += dest_stride;
         sy4 += delta;
       }
       start_x += 8;
     } while (start_x < block_start_x + block_width);
-    dest += 8 * dest_stride;
+    dst += 8 * dest_stride;
     start_y += 8;
   } while (start_y < block_start_y + block_height);
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
-  dsp->warp = Warp_NEON;
+  dsp->warp = Warp_NEON</*is_compound=*/false>;
+  dsp->warp_compound = Warp_NEON</*is_compound=*/true>;
 }
 
 }  // namespace
@@ -301,7 +442,7 @@
 
 }  // namespace dsp
 }  // namespace libgav1
-#else   // !LIBGAV1_ENABLE_NEON
+#else  // !LIBGAV1_ENABLE_NEON
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/arm/warp_neon.h b/libgav1/src/dsp/arm/warp_neon.h
index 5722249..dbcaa23 100644
--- a/libgav1/src/dsp/arm/warp_neon.h
+++ b/libgav1/src/dsp/arm/warp_neon.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
 #define LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -30,7 +30,8 @@
 }  // namespace libgav1
 
 #if LIBGAV1_ENABLE_NEON
-#define LIBGAV1_Dsp8bpp_Warp LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp8bpp_Warp LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WarpCompound LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
diff --git a/libgav1/src/dsp/arm/weight_mask_neon.cc b/libgav1/src/dsp/arm/weight_mask_neon.cc
new file mode 100644
index 0000000..49d3be0
--- /dev/null
+++ b/libgav1/src/dsp/arm/weight_mask_neon.cc
@@ -0,0 +1,463 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/arm/weight_mask_neon.h"
+
+#include "src/dsp/weight_mask.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON
+
+#include <arm_neon.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+constexpr int kRoundingBits8bpp = 4;
+
+template <bool mask_is_inverse>
+inline void WeightMask8_NEON(const int16_t* prediction_0,
+                             const int16_t* prediction_1, uint8_t* mask) {
+  const int16x8_t pred_0 = vld1q_s16(prediction_0);
+  const int16x8_t pred_1 = vld1q_s16(prediction_1);
+  const uint8x8_t difference_offset = vdup_n_u8(38);
+  const uint8x8_t mask_ceiling = vdup_n_u8(64);
+  const uint16x8_t difference = vrshrq_n_u16(
+      vreinterpretq_u16_s16(vabdq_s16(pred_0, pred_1)), kRoundingBits8bpp);
+  const uint8x8_t adjusted_difference =
+      vqadd_u8(vqshrn_n_u16(difference, 4), difference_offset);
+  const uint8x8_t mask_value = vmin_u8(adjusted_difference, mask_ceiling);
+  if (mask_is_inverse) {
+    const uint8x8_t inverted_mask_value = vsub_u8(mask_ceiling, mask_value);
+    vst1_u8(mask, inverted_mask_value);
+  } else {
+    vst1_u8(mask, mask_value);
+  }
+}
+
+#define WEIGHT8_WITHOUT_STRIDE \
+  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask)
+
+#define WEIGHT8_AND_STRIDE \
+  WEIGHT8_WITHOUT_STRIDE;  \
+  pred_0 += 8;             \
+  pred_1 += 8;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask8x8_NEON(const void* prediction_0, const void* prediction_1,
+                        uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y = 0;
+  do {
+    WEIGHT8_AND_STRIDE;
+  } while (++y < 7);
+  WEIGHT8_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask8x16_NEON(const void* prediction_0, const void* prediction_1,
+                         uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT8_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask8x32_NEON(const void* prediction_0, const void* prediction_1,
+                         uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT8_AND_STRIDE;
+  WEIGHT8_WITHOUT_STRIDE;
+}
+
+#define WEIGHT16_WITHOUT_STRIDE                            \
+  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8)
+
+#define WEIGHT16_AND_STRIDE \
+  WEIGHT16_WITHOUT_STRIDE;  \
+  pred_0 += 16;             \
+  pred_1 += 16;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask16x8_NEON(const void* prediction_0, const void* prediction_1,
+                         uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+  } while (++y < 7);
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x16_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x32_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT16_AND_STRIDE;
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x64_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+  } while (++y3 < 21);
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+#define WEIGHT32_WITHOUT_STRIDE                                           \
+  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask);                \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8);    \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24)
+
+#define WEIGHT32_AND_STRIDE \
+  WEIGHT32_WITHOUT_STRIDE;  \
+  pred_0 += 32;             \
+  pred_1 += 32;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask32x8_NEON(const void* prediction_0, const void* prediction_1,
+                         uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x16_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x32_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x64_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+  } while (++y3 < 21);
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+#define WEIGHT64_WITHOUT_STRIDE                                           \
+  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask);                \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8);    \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 32, pred_1 + 32, mask + 32); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 40, pred_1 + 40, mask + 40); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 48, pred_1 + 48, mask + 48); \
+  WeightMask8_NEON<mask_is_inverse>(pred_0 + 56, pred_1 + 56, mask + 56)
+
+#define WEIGHT64_AND_STRIDE \
+  WEIGHT64_WITHOUT_STRIDE;  \
+  pred_0 += 64;             \
+  pred_1 += 64;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask64x16_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x32_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT64_AND_STRIDE;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x64_NEON(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y3 < 21);
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x128_NEON(const void* prediction_0, const void* prediction_1,
+                           uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y3 < 42);
+  WEIGHT64_AND_STRIDE;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask128x64_NEON(const void* prediction_0, const void* prediction_1,
+                           uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
+  do {
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+  } while (++y3 < 21);
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += 64;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask128x128_NEON(const void* prediction_0, const void* prediction_1,
+                            uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
+  do {
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+  } while (++y3 < 42);
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += 64;
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += adjusted_mask_stride;
+
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += 64;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+#define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \
+  dsp->weight_mask[w_index][h_index][0] =                      \
+      WeightMask##width##x##height##_NEON<0>;                  \
+  dsp->weight_mask[w_index][h_index][1] = WeightMask##width##x##height##_NEON<1>
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0);
+  INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1);
+  INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2);
+  INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0);
+  INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1);
+  INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2);
+  INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3);
+  INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0);
+  INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1);
+  INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2);
+  INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3);
+  INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1);
+  INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2);
+  INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3);
+  INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4);
+  INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3);
+  INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4);
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void WeightMaskInit_NEON() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_NEON
+
+namespace libgav1 {
+namespace dsp {
+
+void WeightMaskInit_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/weight_mask_neon.h b/libgav1/src/dsp/arm/weight_mask_neon.h
new file mode 100644
index 0000000..b4749ec
--- /dev/null
+++ b/libgav1/src/dsp/arm/weight_mask_neon.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_
+#define LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::weight_mask. This function is not thread-safe.
+void WeightMaskInit_NEON();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_8x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_8x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_8x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_16x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_32x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_64x128 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_128x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_WeightMask_128x128 LIBGAV1_CPU_NEON
+#endif  // LIBGAV1_ENABLE_NEON
+
+#endif  // LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_
diff --git a/libgav1/src/dsp/average_blend.cc b/libgav1/src/dsp/average_blend.cc
index 98f4059..a59abb0 100644
--- a/libgav1/src/dsp/average_blend.cc
+++ b/libgav1/src/dsp/average_blend.cc
@@ -17,6 +17,7 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <type_traits>
 
 #include "src/dsp/dsp.h"
 #include "src/utils/common.h"
@@ -26,19 +27,16 @@
 namespace {
 
 template <int bitdepth, typename Pixel>
-void AverageBlend_C(const uint16_t* prediction_0,
-                    const ptrdiff_t prediction_stride_0,
-                    const uint16_t* prediction_1,
-                    const ptrdiff_t prediction_stride_1, const int width,
-                    const int height, void* const dest,
+void AverageBlend_C(const void* prediction_0, const void* prediction_1,
+                    const int width, const int height, void* const dest,
                     const ptrdiff_t dest_stride) {
-  // An offset to cancel offsets used in compound predictor generation that
-  // make intermediate computations non negative.
-  constexpr int compound_round_offset =
-      (2 << (bitdepth + 4)) + (2 << (bitdepth + 3));
   // 7.11.3.2 Rounding variables derivation process
   //   2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7))
   constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
+  using PredType =
+      typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
+  const auto* pred_0 = static_cast<const PredType*>(prediction_0);
+  const auto* pred_1 = static_cast<const PredType*>(prediction_1);
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
 
@@ -46,17 +44,17 @@
   do {
     int x = 0;
     do {
-      // prediction range: 8bpp: [0, 15471] 10bpp: [0, 61983] 12bpp: [0, 62007].
-      int res = prediction_0[x] + prediction_1[x];
-      res -= compound_round_offset;
+      // See warp.cc and convolve.cc for detailed prediction ranges.
+      int res = pred_0[x] + pred_1[x];
+      res -= (bitdepth == 8) ? 0 : kCompoundOffset + kCompoundOffset;
       dst[x] = static_cast<Pixel>(
           Clip3(RightShiftWithRounding(res, inter_post_round_bits + 1), 0,
                 (1 << bitdepth) - 1));
     } while (++x < width);
 
     dst += dst_stride;
-    prediction_0 += prediction_stride_0;
-    prediction_1 += prediction_stride_1;
+    pred_0 += width;
+    pred_1 += width;
   } while (++y < height);
 }
 
diff --git a/libgav1/src/dsp/cdef.cc b/libgav1/src/dsp/cdef.cc
index cd4f125..95e5a4a 100644
--- a/libgav1/src/dsp/cdef.cc
+++ b/libgav1/src/dsp/cdef.cc
@@ -20,33 +20,22 @@
 #include <cstdint>
 #include <cstring>
 
+#include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
 #include "src/utils/common.h"
+#include "src/utils/constants.h"
 
 namespace libgav1 {
 namespace dsp {
 namespace {
 
-constexpr uint16_t kCdefLargeValue = 30000;
+#include "src/dsp/cdef.inc"
 
-constexpr int16_t kDivisionTable[] = {0,   840, 420, 280, 210,
-                                      168, 140, 120, 105};
-
-constexpr uint8_t kPrimaryTaps[2][2] = {{4, 2}, {3, 3}};
-
-constexpr uint8_t kSecondaryTaps[2][2] = {{2, 1}, {2, 1}};
-
-constexpr int8_t kCdefDirections[8][2][2] = {
-    {{-1, 1}, {-2, 2}}, {{0, 1}, {-1, 2}}, {{0, 1}, {0, 2}}, {{0, 1}, {1, 2}},
-    {{1, 1}, {2, 2}},   {{1, 0}, {2, 1}},  {{1, 0}, {2, 0}}, {{1, 0}, {2, -1}}};
-
-int Constrain(int diff, int threshold, int damping) {
-  if (threshold == 0) return 0;
-  damping = std::max(0, damping - FloorLog2(threshold));
-  const int sign = (diff < 0) ? -1 : 1;
-  return sign *
-         Clip3(threshold - (std::abs(diff) >> damping), 0, std::abs(diff));
-}
+// Silence unused function warnings when CdefDirection_C is obviated.
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||        \
+    !defined(LIBGAV1_Dsp8bpp_CdefDirection) || \
+    (LIBGAV1_MAX_BITDEPTH >= 10 && !defined(LIBGAV1_Dsp10bpp_CdefDirection))
+constexpr int16_t kDivisionTable[] = {840, 420, 280, 210, 168, 140, 120, 105};
 
 int32_t Square(int32_t x) { return x * x; }
 
@@ -79,24 +68,24 @@
     cost[2] += Square(partial[2][i]);
     cost[6] += Square(partial[6][i]);
   }
-  cost[2] *= kDivisionTable[8];
-  cost[6] *= kDivisionTable[8];
+  cost[2] *= kDivisionTable[7];
+  cost[6] *= kDivisionTable[7];
   for (int i = 0; i < 7; ++i) {
     cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) *
-               kDivisionTable[i + 1];
+               kDivisionTable[i];
     cost[4] += (Square(partial[4][i]) + Square(partial[4][14 - i])) *
-               kDivisionTable[i + 1];
+               kDivisionTable[i];
   }
-  cost[0] += Square(partial[0][7]) * kDivisionTable[8];
-  cost[4] += Square(partial[4][7]) * kDivisionTable[8];
+  cost[0] += Square(partial[0][7]) * kDivisionTable[7];
+  cost[4] += Square(partial[4][7]) * kDivisionTable[7];
   for (int i = 1; i < 8; i += 2) {
     for (int j = 0; j < 5; ++j) {
       cost[i] += Square(partial[i][3 + j]);
     }
-    cost[i] *= kDivisionTable[8];
+    cost[i] *= kDivisionTable[7];
     for (int j = 0; j < 3; ++j) {
       cost[i] += (Square(partial[i][j]) + Square(partial[i][10 - j])) *
-                 kDivisionTable[2 * j + 2];
+                 kDivisionTable[2 * j + 1];
     }
   }
   int32_t best_cost = 0;
@@ -109,27 +98,58 @@
   }
   *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10;
 }
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||
+        // !defined(LIBGAV1_Dsp8bpp_CdefDirection) ||
+        // (LIBGAV1_MAX_BITDEPTH >= 10 &&
+        // !defined(LIBGAV1_Dsp10bpp_CdefDirection))
+
+// Silence unused function warnings when CdefFilter_C is obviated.
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||      \
+    !defined(LIBGAV1_Dsp8bpp_CdefFilters) || \
+    (LIBGAV1_MAX_BITDEPTH >= 10 && !defined(LIBGAV1_Dsp10bpp_CdefFilters))
+
+int Constrain(int diff, int threshold, int damping) {
+  assert(threshold != 0);
+  damping = std::max(0, damping - FloorLog2(threshold));
+  const int sign = (diff < 0) ? -1 : 1;
+  return sign *
+         Clip3(threshold - (std::abs(diff) >> damping), 0, std::abs(diff));
+}
 
 // Filters the source block. It doesn't check whether the candidate pixel is
 // inside the frame. However it requires the source input to be padded with a
-// constant large value if at the boundary. And the input should be uint16_t.
-template <int bitdepth, typename Pixel>
-void CdefFilter_C(const void* const source, const ptrdiff_t source_stride,
-                  const int rows4x4, const int columns4x4, const int curr_x,
-                  const int curr_y, const int subsampling_x,
-                  const int subsampling_y, const int primary_strength,
+// constant large value (kCdefLargeValue) if at the boundary.
+template <int block_width, int bitdepth, typename Pixel,
+          bool enable_primary = true, bool enable_secondary = true>
+void CdefFilter_C(const uint16_t* src, const ptrdiff_t src_stride,
+                  const int block_height, const int primary_strength,
                   const int secondary_strength, const int damping,
                   const int direction, void* const dest,
                   const ptrdiff_t dest_stride) {
-  const int coeff_shift = bitdepth - 8;
-  const int plane_width = MultiplyBy4(columns4x4) >> subsampling_x;
-  const int plane_height = MultiplyBy4(rows4x4) >> subsampling_y;
-  const int block_width = std::min(8 >> subsampling_x, plane_width - curr_x);
-  const int block_height = std::min(8 >> subsampling_y, plane_height - curr_y);
-  const auto* src = static_cast<const uint16_t*>(source);
+  static_assert(block_width == 4 || block_width == 8, "Invalid CDEF width.");
+  static_assert(enable_primary || enable_secondary, "");
+  assert(block_height == 4 || block_height == 8);
+  assert(direction >= 0 && direction <= 7);
+  constexpr int coeff_shift = bitdepth - 8;
+  // Section 5.9.19. CDEF params syntax.
+  assert(primary_strength >= 0 && primary_strength <= 15 << coeff_shift);
+  assert(secondary_strength >= 0 && secondary_strength <= 4 << coeff_shift &&
+         secondary_strength != 3 << coeff_shift);
+  assert(primary_strength != 0 || secondary_strength != 0);
+  // damping is decreased by 1 for chroma.
+  assert((damping >= 3 && damping <= 6 + coeff_shift) ||
+         (damping >= 2 && damping <= 5 + coeff_shift));
+  // When only primary_strength or secondary_strength are non-zero the number
+  // of pixels inspected (4 for primary_strength, 8 for secondary_strength) and
+  // the taps used don't exceed the amount the sum is
+  // descaled by (16) so we can skip tracking and clipping to the minimum and
+  // maximum value observed.
+  constexpr bool clipping_required = enable_primary && enable_secondary;
+  static constexpr int kCdefSecondaryTaps[2] = {kCdefSecondaryTap0,
+                                                kCdefSecondaryTap1};
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
-  int y = 0;
+  int y = block_height;
   do {
     int x = 0;
     do {
@@ -138,60 +158,96 @@
       uint16_t max_value = pixel_value;
       uint16_t min_value = pixel_value;
       for (int k = 0; k < 2; ++k) {
-        const int signs[] = {-1, 1};
+        static constexpr int signs[] = {-1, 1};
         for (const int& sign : signs) {
-          int dy = sign * kCdefDirections[direction][k][0];
-          int dx = sign * kCdefDirections[direction][k][1];
-          uint16_t value = src[dy * source_stride + dx + x];
-          // Note: the summation can ignore the condition check in SIMD
-          // implementation, because Constrain() will return 0 when
-          // value == kCdefLargeValue.
-          if (value != kCdefLargeValue) {
-            sum += Constrain(value - pixel_value, primary_strength, damping) *
-                   kPrimaryTaps[(primary_strength >> coeff_shift) & 1][k];
-            max_value = std::max(value, max_value);
-            min_value = std::min(value, min_value);
-          }
-          const int offsets[] = {-2, 2};
-          for (const int& offset : offsets) {
-            dy = sign * kCdefDirections[(direction + offset) & 7][k][0];
-            dx = sign * kCdefDirections[(direction + offset) & 7][k][1];
-            value = src[dy * source_stride + dx + x];
+          if (enable_primary) {
+            const int dy = sign * kCdefDirections[direction][k][0];
+            const int dx = sign * kCdefDirections[direction][k][1];
+            const uint16_t value = src[dy * src_stride + dx + x];
             // Note: the summation can ignore the condition check in SIMD
-            // implementation.
+            // implementation, because Constrain() will return 0 when
+            // value == kCdefLargeValue.
             if (value != kCdefLargeValue) {
-              sum +=
-                  Constrain(value - pixel_value, secondary_strength, damping) *
-                  kSecondaryTaps[(primary_strength >> coeff_shift) & 1][k];
-              max_value = std::max(value, max_value);
-              min_value = std::min(value, min_value);
+              sum += Constrain(value - pixel_value, primary_strength, damping) *
+                     kCdefPrimaryTaps[(primary_strength >> coeff_shift) & 1][k];
+              if (clipping_required) {
+                max_value = std::max(value, max_value);
+                min_value = std::min(value, min_value);
+              }
+            }
+          }
+
+          if (enable_secondary) {
+            static constexpr int offsets[] = {-2, 2};
+            for (const int& offset : offsets) {
+              const int dy = sign * kCdefDirections[direction + offset][k][0];
+              const int dx = sign * kCdefDirections[direction + offset][k][1];
+              const uint16_t value = src[dy * src_stride + dx + x];
+              // Note: the summation can ignore the condition check in SIMD
+              // implementation.
+              if (value != kCdefLargeValue) {
+                sum += Constrain(value - pixel_value, secondary_strength,
+                                 damping) *
+                       kCdefSecondaryTaps[k];
+                if (clipping_required) {
+                  max_value = std::max(value, max_value);
+                  min_value = std::min(value, min_value);
+                }
+              }
             }
           }
         }
       }
 
-      dst[x] = static_cast<Pixel>(Clip3(
-          pixel_value + ((8 + sum - (sum < 0)) >> 4), min_value, max_value));
+      const int offset = (8 + sum - (sum < 0)) >> 4;
+      if (clipping_required) {
+        dst[x] = static_cast<Pixel>(
+            Clip3(pixel_value + offset, min_value, max_value));
+      } else {
+        dst[x] = static_cast<Pixel>(pixel_value + offset);
+      }
     } while (++x < block_width);
 
-    src += source_stride;
+    src += src_stride;
     dst += dst_stride;
-  } while (++y < block_height);
+  } while (--y != 0);
 }
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||
+        // !defined(LIBGAV1_Dsp8bpp_CdefFilters) ||
+        // (LIBGAV1_MAX_BITDEPTH >= 10 &&
+        // !defined(LIBGAV1_Dsp10bpp_CdefFilters))
 
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->cdef_direction = CdefDirection_C<8, uint8_t>;
-  dsp->cdef_filter = CdefFilter_C<8, uint8_t>;
+  dsp->cdef_filters[0][0] = CdefFilter_C<4, 8, uint8_t>;
+  dsp->cdef_filters[0][1] = CdefFilter_C<4, 8, uint8_t, /*enable_primary=*/true,
+                                         /*enable_secondary=*/false>;
+  dsp->cdef_filters[0][2] =
+      CdefFilter_C<4, 8, uint8_t, /*enable_primary=*/false>;
+  dsp->cdef_filters[1][0] = CdefFilter_C<8, 8, uint8_t>;
+  dsp->cdef_filters[1][1] = CdefFilter_C<8, 8, uint8_t, /*enable_primary=*/true,
+                                         /*enable_secondary=*/false>;
+  dsp->cdef_filters[1][2] =
+      CdefFilter_C<8, 8, uint8_t, /*enable_primary=*/false>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp8bpp_CdefDirection
   dsp->cdef_direction = CdefDirection_C<8, uint8_t>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_CdefFilter
-  dsp->cdef_filter = CdefFilter_C<8, uint8_t>;
+#ifndef LIBGAV1_Dsp8bpp_CdefFilters
+  dsp->cdef_filters[0][0] = CdefFilter_C<4, 8, uint8_t>;
+  dsp->cdef_filters[0][1] = CdefFilter_C<4, 8, uint8_t, /*enable_primary=*/true,
+                                         /*enable_secondary=*/false>;
+  dsp->cdef_filters[0][2] =
+      CdefFilter_C<4, 8, uint8_t, /*enable_primary=*/false>;
+  dsp->cdef_filters[1][0] = CdefFilter_C<8, 8, uint8_t>;
+  dsp->cdef_filters[1][1] = CdefFilter_C<8, 8, uint8_t, /*enable_primary=*/true,
+                                         /*enable_secondary=*/false>;
+  dsp->cdef_filters[1][2] =
+      CdefFilter_C<8, 8, uint8_t, /*enable_primary=*/false>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -202,14 +258,36 @@
   assert(dsp != nullptr);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   dsp->cdef_direction = CdefDirection_C<10, uint16_t>;
-  dsp->cdef_filter = CdefFilter_C<10, uint16_t>;
+  dsp->cdef_filters[0][0] = CdefFilter_C<4, 10, uint16_t>;
+  dsp->cdef_filters[0][1] =
+      CdefFilter_C<4, 10, uint16_t, /*enable_primary=*/true,
+                   /*enable_secondary=*/false>;
+  dsp->cdef_filters[0][2] =
+      CdefFilter_C<4, 10, uint16_t, /*enable_primary=*/false>;
+  dsp->cdef_filters[1][0] = CdefFilter_C<8, 10, uint16_t>;
+  dsp->cdef_filters[1][1] =
+      CdefFilter_C<8, 10, uint16_t, /*enable_primary=*/true,
+                   /*enable_secondary=*/false>;
+  dsp->cdef_filters[1][2] =
+      CdefFilter_C<8, 10, uint16_t, /*enable_primary=*/false>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp10bpp_CdefDirection
   dsp->cdef_direction = CdefDirection_C<10, uint16_t>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_CdefFilter
-  dsp->cdef_filter = CdefFilter_C<10, uint16_t>;
+#ifndef LIBGAV1_Dsp10bpp_CdefFilters
+  dsp->cdef_filters[0][0] = CdefFilter_C<4, 10, uint16_t>;
+  dsp->cdef_filters[0][1] =
+      CdefFilter_C<4, 10, uint16_t, /*enable_primary=*/true,
+                   /*enable_secondary=*/false>;
+  dsp->cdef_filters[0][2] =
+      CdefFilter_C<4, 10, uint16_t, /*enable_primary=*/false>;
+  dsp->cdef_filters[1][0] = CdefFilter_C<8, 10, uint16_t>;
+  dsp->cdef_filters[1][1] =
+      CdefFilter_C<8, 10, uint16_t, /*enable_primary=*/true,
+                   /*enable_secondary=*/false>;
+  dsp->cdef_filters[1][2] =
+      CdefFilter_C<8, 10, uint16_t, /*enable_primary=*/false>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
diff --git a/libgav1/src/dsp/cdef.h b/libgav1/src/dsp/cdef.h
index d97f4d2..2d70d2c 100644
--- a/libgav1/src/dsp/cdef.h
+++ b/libgav1/src/dsp/cdef.h
@@ -17,10 +17,27 @@
 #ifndef LIBGAV1_SRC_DSP_CDEF_H_
 #define LIBGAV1_SRC_DSP_CDEF_H_
 
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/cdef_neon.h"
+
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/cdef_sse4.h"
+// clang-format on
+// IWYU pragma: end_exports
+
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::cdef_direction and cdef::filter. This function is not
+// Initializes Dsp::cdef_direction and Dsp::cdef_filters. This function is not
 // thread-safe.
 void CdefInit_C();
 
diff --git a/libgav1/src/dsp/cdef.inc b/libgav1/src/dsp/cdef.inc
new file mode 100644
index 0000000..c1a3136
--- /dev/null
+++ b/libgav1/src/dsp/cdef.inc
@@ -0,0 +1,29 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Constants used for cdef implementations.
+// This will be included inside an anonymous namespace on files where these are
+// necessary.
+
+const int8_t (*const kCdefDirections)[2][2] = kCdefDirectionsPadded + 2;
+
+// Mirror values and pad to 16 elements.
+alignas(16) constexpr uint32_t kCdefDivisionTable[] = {
+    840, 420, 280, 210, 168, 140, 120, 105,
+    120, 140, 168, 210, 280, 420, 840, 0};
+
+// Used when calculating odd |cost[x]| values to mask off unwanted elements.
+// Holds elements 1 3 5 X 5 3 1 X
+alignas(16) constexpr uint32_t kCdefDivisionTableOdd[] = {420, 210, 140, 0,
+                                                          140, 210, 420, 0};
diff --git a/libgav1/src/dsp/common.h b/libgav1/src/dsp/common.h
index 62072c2..8ce3211 100644
--- a/libgav1/src/dsp/common.h
+++ b/libgav1/src/dsp/common.h
@@ -17,7 +17,6 @@
 #ifndef LIBGAV1_SRC_DSP_COMMON_H_
 #define LIBGAV1_SRC_DSP_COMMON_H_
 
-#include <cstddef>  // ptrdiff_t
 #include <cstdint>
 
 #include "src/dsp/constants.h"
@@ -26,10 +25,7 @@
 
 namespace libgav1 {
 
-struct LoopRestoration {
-  LoopRestorationType type[kMaxPlanes];
-  int unit_size[kMaxPlanes];
-};
+enum { kSgrStride = kRestorationUnitWidth + 8 };  // anonymous enum
 
 // Self guided projection filter.
 struct SgrProjInfo {
@@ -40,70 +36,44 @@
 struct WienerInfo {
   static const int kVertical = 0;
   static const int kHorizontal = 1;
-
-  alignas(kMaxAlignment) int16_t filter[2][kSubPixelTaps];
+  int16_t number_leading_zero_coefficients[2];
+  alignas(kMaxAlignment) int16_t filter[2][(kWienerFilterTaps + 1) / 2];
 };
 
-struct RestorationUnitInfo : public Allocable {
+struct RestorationUnitInfo : public MaxAlignedAllocable {
   LoopRestorationType type;
   SgrProjInfo sgr_proj_info;
   WienerInfo wiener_info;
 };
 
-struct RestorationBuffer {
-  // For self-guided filter.
-  int* box_filter_process_output[2];
-  ptrdiff_t box_filter_process_output_stride;
-  uint32_t* box_filter_process_intermediate[2];
-  ptrdiff_t box_filter_process_intermediate_stride;
-  // For wiener filter.
-  uint16_t* wiener_buffer;
-  ptrdiff_t wiener_buffer_stride;
+struct SgrBuffer {
+  alignas(kMaxAlignment) uint16_t sum3[4 * kSgrStride];
+  alignas(kMaxAlignment) uint16_t sum5[5 * kSgrStride];
+  alignas(kMaxAlignment) uint32_t square_sum3[4 * kSgrStride];
+  alignas(kMaxAlignment) uint32_t square_sum5[5 * kSgrStride];
+  alignas(kMaxAlignment) uint16_t ma343[4 * kRestorationUnitWidth];
+  alignas(kMaxAlignment) uint16_t ma444[3 * kRestorationUnitWidth];
+  alignas(kMaxAlignment) uint16_t ma565[2 * kRestorationUnitWidth];
+  alignas(kMaxAlignment) uint32_t b343[4 * kRestorationUnitWidth];
+  alignas(kMaxAlignment) uint32_t b444[3 * kRestorationUnitWidth];
+  alignas(kMaxAlignment) uint32_t b565[2 * kRestorationUnitWidth];
+  alignas(kMaxAlignment) uint16_t
+      temp_buffer[12 * (kRestorationUnitHeight + 2)];
+  alignas(kMaxAlignment) uint8_t ma[kSgrStride];  // [0, 255]
+  // b is less than 2^16 for 8-bit. However, making it a template slows down the
+  // C function by 5%. So b is fixed to 32-bit.
+  alignas(kMaxAlignment) uint32_t b[kSgrStride];
 };
 
-// Section 6.8.20.
-// Note: In spec, film grain section uses YCbCr to denote variable names,
-// such as num_cb_points, num_cr_points. To keep it consistent with other
-// parts of code, we use YUV, i.e., num_u_points, num_v_points, etc.
-struct FilmGrainParams {
-  bool apply_grain;
-  bool update_grain;
-  bool chroma_scaling_from_luma;
-  bool overlap_flag;
-  bool clip_to_restricted_range;
-
-  uint8_t num_y_points;  // [0, 14].
-  uint8_t num_u_points;  // [0, 10].
-  uint8_t num_v_points;  // [0, 10].
-  // Must be [0, 255]. 10/12 bit /= 4 or 16. Must be in increasing order.
-  uint8_t point_y_value[14];
-  uint8_t point_y_scaling[14];
-  uint8_t point_u_value[10];
-  uint8_t point_u_scaling[10];
-  uint8_t point_v_value[10];
-  uint8_t point_v_scaling[10];
-
-  uint8_t chroma_scaling;             // [8, 11].
-  uint8_t auto_regression_coeff_lag;  // [0, 3].
-  int auto_regression_coeff_y[24];
-  int auto_regression_coeff_u[25];
-  int auto_regression_coeff_v[25];
-  // Shift value: auto regression coeffs range
-  // 6: [-2, 2)
-  // 7: [-1, 1)
-  // 8: [-0.5, 0.5)
-  // 9: [-0.25, 0.25)
-  uint8_t auto_regression_shift;
-
-  uint16_t grain_seed;
-  int reference_index;
-  int grain_scale_shift;
-  int u_multiplier;
-  int u_luma_multiplier;
-  int u_offset;
-  int v_multiplier;
-  int v_luma_multiplier;
-  int v_offset;
+union RestorationBuffer {
+  // For self-guided filter.
+  SgrBuffer sgr_buffer;
+  // For wiener filter.
+  // The array |intermediate| in Section 7.17.4, the intermediate results
+  // between the horizontal and vertical filters.
+  alignas(kMaxAlignment) int16_t
+      wiener_buffer[(kRestorationUnitHeight + kWienerFilterTaps - 1) *
+                    kRestorationUnitWidth];
 };
 
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/constants.cc b/libgav1/src/dsp/constants.cc
index 7c83e24..0099ca3 100644
--- a/libgav1/src/dsp/constants.cc
+++ b/libgav1/src/dsp/constants.cc
@@ -79,4 +79,25 @@
     {0, 1177},   {0, 925},    {56, 0},    {22, 0},
 };
 
+const uint8_t kCdefPrimaryTaps[2][2] = {{4, 2}, {3, 3}};
+
+// This is Cdef_Directions (section 7.15.3) with 2 padding entries at the
+// beginning and end of the table. The cdef direction range is [0, 7] and the
+// first index is offset +/-2. This removes the need to constrain the first
+// index to the same range using e.g., & 7.
+const int8_t kCdefDirectionsPadded[12][2][2] = {
+    {{1, 0}, {2, 0}},    // Padding: Cdef_Directions[6]
+    {{1, 0}, {2, -1}},   // Padding: Cdef_Directions[7]
+    {{-1, 1}, {-2, 2}},  // Begin Cdef_Directions
+    {{0, 1}, {-1, 2}},   //
+    {{0, 1}, {0, 2}},    //
+    {{0, 1}, {1, 2}},    //
+    {{1, 1}, {2, 2}},    //
+    {{1, 0}, {2, 1}},    //
+    {{1, 0}, {2, 0}},    //
+    {{1, 0}, {2, -1}},   // End Cdef_Directions
+    {{-1, 1}, {-2, 2}},  // Padding: Cdef_Directions[0]
+    {{0, 1}, {-1, 2}},   // Padding: Cdef_Directions[1]
+};
+
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/constants.h b/libgav1/src/dsp/constants.h
index 00ebf12..7c1b62c 100644
--- a/libgav1/src/dsp/constants.h
+++ b/libgav1/src/dsp/constants.h
@@ -27,6 +27,10 @@
 namespace libgav1 {
 
 enum {
+  // Documentation variables.
+  kBitdepth8 = 8,
+  kBitdepth10 = 10,
+  kBitdepth12 = 12,
   // Weights are quadratic from '1' to '1 / block_size', scaled by
   // 2^kSmoothWeightScale.
   kSmoothWeightScale = 8,
@@ -34,8 +38,14 @@
   // InterRound0, Section 7.11.3.2.
   kInterRoundBitsHorizontal = 3,  // 8 & 10-bit.
   kInterRoundBitsHorizontal12bpp = 5,
-  kInterRoundBitsVertical = 11,  // 8 & 10-bit, single prediction.
+  kInterRoundBitsCompoundVertical = 7,  // 8, 10 & 12-bit compound prediction.
+  kInterRoundBitsVertical = 11,         // 8 & 10-bit, single prediction.
   kInterRoundBitsVertical12bpp = 9,
+  // Offset applied to 10bpp and 12bpp predictors to allow storing them in
+  // uint16_t. Removed before blending.
+  kCompoundOffset = (1 << 14) + (1 << 13),
+  kCdefSecondaryTap0 = 2,
+  kCdefSecondaryTap1 = 1,
 };  // anonymous enum
 
 extern const int8_t kFilterIntraTaps[kNumFilterIntraPredictors][8][8];
@@ -52,6 +62,10 @@
 
 extern const uint16_t kSgrScaleParameter[16][2];
 
+extern const uint8_t kCdefPrimaryTaps[2][2];
+
+extern const int8_t kCdefDirectionsPadded[12][2][2];
+
 }  // namespace libgav1
 
 #endif  // LIBGAV1_SRC_DSP_CONSTANTS_H_
diff --git a/libgav1/src/dsp/convolve.cc b/libgav1/src/dsp/convolve.cc
index 5358473..c8df357 100644
--- a/libgav1/src/dsp/convolve.cc
+++ b/libgav1/src/dsp/convolve.cc
@@ -29,33 +29,44 @@
 namespace dsp {
 namespace {
 
-constexpr int kSubPixelMask = (1 << kSubPixelBits) - 1;
 constexpr int kHorizontalOffset = 3;
 constexpr int kVerticalOffset = 3;
 
-int GetFilterIndex(const int filter_index, const int length) {
-  if (length <= 4) {
-    if (filter_index == kInterpolationFilterEightTap ||
-        filter_index == kInterpolationFilterEightTapSharp) {
-      return 4;
-    }
-    if (filter_index == kInterpolationFilterEightTapSmooth) {
-      return 5;
-    }
-  }
-  return filter_index;
-}
+// Compound prediction output ranges from ConvolveTest.ShowRange.
+// Bitdepth:  8 Input range:            [       0,      255]
+//   intermediate range:                [   -7140,    23460]
+//   first pass output range:           [   -1785,     5865]
+//   intermediate range:                [ -328440,   589560]
+//   second pass output range:          [       0,      255]
+//   compound second pass output range: [   -5132,     9212]
+//
+// Bitdepth: 10 Input range:            [       0,     1023]
+//   intermediate range:                [  -28644,    94116]
+//   first pass output range:           [   -7161,    23529]
+//   intermediate range:                [-1317624,  2365176]
+//   second pass output range:          [       0,     1023]
+//   compound second pass output range: [    3988,    61532]
+//
+// Bitdepth: 12 Input range:            [       0,     4095]
+//   intermediate range:                [ -114660,   376740]
+//   first pass output range:           [   -7166,    23546]
+//   intermediate range:                [-1318560,  2366880]
+//   second pass output range:          [       0,     4095]
+//   compound second pass output range: [    3974,    61559]
 
 template <int bitdepth, typename Pixel>
-void ConvolveScale2D_C(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int horizontal_filter_index, const int vertical_filter_index,
-    const int inter_round_bits_vertical, const int subpixel_x,
-    const int subpixel_y, const int step_x, const int step_y, const int width,
-    const int height, void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveScale2D_C(const void* const reference,
+                       const ptrdiff_t reference_stride,
+                       const int horizontal_filter_index,
+                       const int vertical_filter_index, const int subpixel_x,
+                       const int subpixel_y, const int step_x, const int step_y,
+                       const int width, const int height, void* prediction,
+                       const ptrdiff_t pred_stride) {
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
                                            : kInterRoundBitsHorizontal;
+  constexpr int kRoundBitsVertical =
+      (bitdepth == 12) ? kInterRoundBitsVertical12bpp : kInterRoundBitsVertical;
   const int intermediate_height =
       (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
        kScaleSubPixelBits) +
@@ -65,7 +76,6 @@
   int16_t intermediate_result[kMaxSuperBlockSizeInPixels *
                               (2 * kMaxSuperBlockSizeInPixels + 8)];
   const int intermediate_stride = kMaxSuperBlockSizeInPixels;
-  const int single_round_offset = (1 << bitdepth) + (1 << (bitdepth - 1));
   const int max_pixel_value = (1 << bitdepth) - 1;
 
   // Horizontal filter.
@@ -87,16 +97,13 @@
     int p = subpixel_x;
     int x = 0;
     do {
-      // An offset to guarantee the sum is non negative.
-      int sum = 1 << (bitdepth + kFilterBits - 1);
+      int sum = 0;
       const Pixel* src_x = &src[(p >> kScaleSubPixelBits) - ref_x];
       const int filter_id = (p >> 6) & kSubPixelMask;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] * src_x[k];
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] * src_x[k];
       }
-      assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
-      intermediate[x] = static_cast<int16_t>(
-          RightShiftWithRounding(sum, kRoundBitsHorizontal));
+      intermediate[x] = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1);
       p += step_x;
     } while (++x < width);
 
@@ -107,26 +114,21 @@
   // Vertical filter.
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
-  const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
   int p = subpixel_y & 1023;
   y = 0;
   do {
     const int filter_id = (p >> 6) & kSubPixelMask;
     int x = 0;
     do {
-      // An offset to guarantee the sum is non negative.
-      int sum = 1 << offset_bits;
+      int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
         sum +=
-            kSubPixelFilters[filter_index][filter_id][k] *
+            kHalfSubPixelFilters[filter_index][filter_id][k] *
             intermediate[((p >> kScaleSubPixelBits) + k) * intermediate_stride +
                          x];
       }
-      assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
-      dest[x] = static_cast<Pixel>(
-          Clip3(RightShiftWithRounding(sum, inter_round_bits_vertical) -
-                    single_round_offset,
-                0, max_pixel_value));
+      dest[x] = Clip3(RightShiftWithRounding(sum, kRoundBitsVertical - 1), 0,
+                      max_pixel_value);
     } while (++x < width);
 
     dest += dest_stride;
@@ -135,15 +137,23 @@
 }
 
 template <int bitdepth, typename Pixel>
-void ConvolveCompoundScale2D_C(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int horizontal_filter_index, const int vertical_filter_index,
-    const int inter_round_bits_vertical, const int subpixel_x,
-    const int subpixel_y, const int step_x, const int step_y, const int width,
-    const int height, void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveCompoundScale2D_C(const void* const reference,
+                               const ptrdiff_t reference_stride,
+                               const int horizontal_filter_index,
+                               const int vertical_filter_index,
+                               const int subpixel_x, const int subpixel_y,
+                               const int step_x, const int step_y,
+                               const int width, const int height,
+                               void* prediction, const ptrdiff_t pred_stride) {
+  // All compound functions output to the predictor buffer with |pred_stride|
+  // equal to |width|.
+  assert(pred_stride == width);
+  // Compound functions start at 4x4.
+  assert(width >= 4 && height >= 4);
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
                                            : kInterRoundBitsHorizontal;
+  constexpr int kRoundBitsVertical = kInterRoundBitsCompoundVertical;
   const int intermediate_height =
       (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
        kScaleSubPixelBits) +
@@ -172,16 +182,13 @@
     int p = subpixel_x;
     int x = 0;
     do {
-      // An offset to guarantee the sum is non negative.
-      int sum = 1 << (bitdepth + kFilterBits - 1);
+      int sum = 0;
       const Pixel* src_x = &src[(p >> kScaleSubPixelBits) - ref_x];
       const int filter_id = (p >> 6) & kSubPixelMask;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] * src_x[k];
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] * src_x[k];
       }
-      assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
-      intermediate[x] = static_cast<int16_t>(
-          RightShiftWithRounding(sum, kRoundBitsHorizontal));
+      intermediate[x] = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1);
       p += step_x;
     } while (++x < width);
 
@@ -192,24 +199,22 @@
   // Vertical filter.
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
-  const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
   int p = subpixel_y & 1023;
   y = 0;
   do {
     const int filter_id = (p >> 6) & kSubPixelMask;
     int x = 0;
     do {
-      // An offset to guarantee the sum is non negative.
-      int sum = 1 << offset_bits;
+      int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
         sum +=
-            kSubPixelFilters[filter_index][filter_id][k] *
+            kHalfSubPixelFilters[filter_index][filter_id][k] *
             intermediate[((p >> kScaleSubPixelBits) + k) * intermediate_stride +
                          x];
       }
-      assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
-      dest[x] = static_cast<uint16_t>(
-          RightShiftWithRounding(sum, inter_round_bits_vertical));
+      sum = RightShiftWithRounding(sum, kRoundBitsVertical - 1);
+      sum += (bitdepth == 8) ? 0 : kCompoundOffset;
+      dest[x] = sum;
     } while (++x < width);
 
     dest += pred_stride;
@@ -221,15 +226,19 @@
 void ConvolveCompound2D_C(const void* const reference,
                           const ptrdiff_t reference_stride,
                           const int horizontal_filter_index,
-                          const int vertical_filter_index,
-                          const int inter_round_bits_vertical,
-                          const int subpixel_x, const int subpixel_y,
-                          const int /*step_x*/, const int /*step_y*/,
-                          const int width, const int height, void* prediction,
+                          const int vertical_filter_index, const int subpixel_x,
+                          const int subpixel_y, const int width,
+                          const int height, void* prediction,
                           const ptrdiff_t pred_stride) {
+  // All compound functions output to the predictor buffer with |pred_stride|
+  // equal to |width|.
+  assert(pred_stride == width);
+  // Compound functions start at 4x4.
+  assert(width >= 4 && height >= 4);
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
                                            : kInterRoundBitsHorizontal;
+  constexpr int kRoundBitsVertical = kInterRoundBitsCompoundVertical;
   const int intermediate_height = height + kSubPixelTaps - 1;
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
@@ -249,18 +258,17 @@
                     kVerticalOffset * src_stride - kHorizontalOffset;
   auto* dest = static_cast<uint16_t*>(prediction);
   int filter_id = (subpixel_x >> 6) & kSubPixelMask;
+  // If |filter_id| == 0 then ConvolveVertical() should be called.
+  assert(filter_id != 0);
   int y = 0;
   do {
     int x = 0;
     do {
-      // An offset to guarantee the sum is non negative.
-      int sum = 1 << (bitdepth + kFilterBits - 1);
+      int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] * src[x + k];
       }
-      assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
-      intermediate[x] = static_cast<int16_t>(
-          RightShiftWithRounding(sum, kRoundBitsHorizontal));
+      intermediate[x] = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1);
     } while (++x < width);
 
     src += src_stride;
@@ -271,20 +279,20 @@
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
   filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
-  const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
+  // If |filter_id| == 0 then ConvolveHorizontal() should be called.
+  assert(filter_id != 0);
   y = 0;
   do {
     int x = 0;
     do {
-      // An offset to guarantee the sum is non negative.
-      int sum = 1 << offset_bits;
+      int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] *
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] *
                intermediate[k * intermediate_stride + x];
       }
-      assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
-      dest[x] = static_cast<uint16_t>(
-          RightShiftWithRounding(sum, inter_round_bits_vertical));
+      sum = RightShiftWithRounding(sum, kRoundBitsVertical - 1);
+      sum += (bitdepth == 8) ? 0 : kCompoundOffset;
+      dest[x] = sum;
     } while (++x < width);
 
     dest += pred_stride;
@@ -300,21 +308,20 @@
 template <int bitdepth, typename Pixel>
 void Convolve2D_C(const void* const reference, const ptrdiff_t reference_stride,
                   const int horizontal_filter_index,
-                  const int vertical_filter_index,
-                  const int inter_round_bits_vertical, const int subpixel_x,
-                  const int subpixel_y, const int /*step_x*/,
-                  const int /*step_y*/, const int width, const int height,
+                  const int vertical_filter_index, const int subpixel_x,
+                  const int subpixel_y, const int width, const int height,
                   void* prediction, const ptrdiff_t pred_stride) {
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
                                            : kInterRoundBitsHorizontal;
+  constexpr int kRoundBitsVertical =
+      (bitdepth == 12) ? kInterRoundBitsVertical12bpp : kInterRoundBitsVertical;
   const int intermediate_height = height + kSubPixelTaps - 1;
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
   int16_t intermediate_result[kMaxSuperBlockSizeInPixels *
                               (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
   const int intermediate_stride = kMaxSuperBlockSizeInPixels;
-  const int single_round_offset = (1 << bitdepth) + (1 << (bitdepth - 1));
   const int max_pixel_value = (1 << bitdepth) - 1;
 
   // Horizontal filter.
@@ -330,18 +337,17 @@
   auto* dest = static_cast<Pixel*>(prediction);
   const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel);
   int filter_id = (subpixel_x >> 6) & kSubPixelMask;
+  // If |filter_id| == 0 then ConvolveVertical() should be called.
+  assert(filter_id != 0);
   int y = 0;
   do {
     int x = 0;
     do {
-      // An offset to guarantee the sum is non negative.
-      int sum = 1 << (bitdepth + kFilterBits - 1);
+      int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] * src[x + k];
       }
-      assert(sum >= 0 && sum < (1 << (bitdepth + kFilterBits + 1)));
-      intermediate[x] = static_cast<int16_t>(
-          RightShiftWithRounding(sum, kRoundBitsHorizontal));
+      intermediate[x] = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1);
     } while (++x < width);
 
     src += src_stride;
@@ -352,22 +358,19 @@
   filter_index = GetFilterIndex(vertical_filter_index, height);
   intermediate = intermediate_result;
   filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
-  const int offset_bits = bitdepth + 2 * kFilterBits - kRoundBitsHorizontal;
+  // If |filter_id| == 0 then ConvolveHorizontal() should be called.
+  assert(filter_id != 0);
   y = 0;
   do {
     int x = 0;
     do {
-      // An offset to guarantee the sum is non negative.
-      int sum = 1 << offset_bits;
+      int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] *
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] *
                intermediate[k * intermediate_stride + x];
       }
-      assert(sum >= 0 && sum < (1 << (offset_bits + 2)));
-      dest[x] = static_cast<Pixel>(
-          Clip3(RightShiftWithRounding(sum, inter_round_bits_vertical) -
-                    single_round_offset,
-                0, max_pixel_value));
+      dest[x] = Clip3(RightShiftWithRounding(sum, kRoundBitsVertical - 1), 0,
+                      max_pixel_value);
     } while (++x < width);
 
     dest += dest_stride;
@@ -385,9 +388,7 @@
                           const ptrdiff_t reference_stride,
                           const int horizontal_filter_index,
                           const int /*vertical_filter_index*/,
-                          const int /*inter_round_bits_vertical*/,
                           const int subpixel_x, const int /*subpixel_y*/,
-                          const int /*step_x*/, const int /*step_y*/,
                           const int width, const int height, void* prediction,
                           const ptrdiff_t pred_stride) {
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
@@ -407,11 +408,10 @@
     do {
       int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] * src[x + k];
       }
-      sum = RightShiftWithRounding(sum, kRoundBitsHorizontal);
-      dest[x] = static_cast<Pixel>(
-          Clip3(RightShiftWithRounding(sum, bits), 0, max_pixel_value));
+      sum = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1);
+      dest[x] = Clip3(RightShiftWithRounding(sum, bits), 0, max_pixel_value);
     } while (++x < width);
 
     src += src_stride;
@@ -429,9 +429,7 @@
                         const ptrdiff_t reference_stride,
                         const int /*horizontal_filter_index*/,
                         const int vertical_filter_index,
-                        const int /*inter_round_bits_vertical*/,
                         const int /*subpixel_x*/, const int subpixel_y,
-                        const int /*step_x*/, const int /*step_y*/,
                         const int width, const int height, void* prediction,
                         const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
@@ -441,18 +439,9 @@
   auto* dest = static_cast<Pixel*>(prediction);
   const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel);
   const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
-  // First filter is always a copy.
-  if (filter_id == 0) {
-    // Move |src| down the actual values and not the start of the context.
-    src = static_cast<const Pixel*>(reference);
-    int y = 0;
-    do {
-      memcpy(dest, src, width * sizeof(src[0]));
-      src += src_stride;
-      dest += dest_stride;
-    } while (++y < height);
-    return;
-  }
+  // Copy filters must call ConvolveCopy().
+  assert(filter_id != 0);
+
   const int max_pixel_value = (1 << bitdepth) - 1;
   int y = 0;
   do {
@@ -460,11 +449,11 @@
     do {
       int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] *
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] *
                src[k * src_stride + x];
       }
-      dest[x] = static_cast<Pixel>(
-          Clip3(RightShiftWithRounding(sum, kFilterBits), 0, max_pixel_value));
+      dest[x] = Clip3(RightShiftWithRounding(sum, kFilterBits - 1), 0,
+                      max_pixel_value);
     } while (++x < width);
 
     src += src_stride;
@@ -477,10 +466,8 @@
                     const ptrdiff_t reference_stride,
                     const int /*horizontal_filter_index*/,
                     const int /*vertical_filter_index*/,
-                    const int /*inter_round_bits_vertical*/,
                     const int /*subpixel_x*/, const int /*subpixel_y*/,
-                    const int /*step_x*/, const int /*step_y*/, const int width,
-                    const int height, void* prediction,
+                    const int width, const int height, void* prediction,
                     const ptrdiff_t pred_stride) {
   const auto* src = static_cast<const uint8_t*>(reference);
   auto* dest = static_cast<uint8_t*>(prediction);
@@ -497,23 +484,29 @@
                             const ptrdiff_t reference_stride,
                             const int /*horizontal_filter_index*/,
                             const int /*vertical_filter_index*/,
-                            const int /*inter_round_bits_vertical*/,
                             const int /*subpixel_x*/, const int /*subpixel_y*/,
-                            const int /*step_x*/, const int /*step_y*/,
                             const int width, const int height, void* prediction,
                             const ptrdiff_t pred_stride) {
+  // All compound functions output to the predictor buffer with |pred_stride|
+  // equal to |width|.
+  assert(pred_stride == width);
+  // Compound functions start at 4x4.
+  assert(width >= 4 && height >= 4);
+  constexpr int kRoundBitsVertical =
+      ((bitdepth == 12) ? kInterRoundBitsVertical12bpp
+                        : kInterRoundBitsVertical) -
+      kInterRoundBitsCompoundVertical;
   const auto* src = static_cast<const Pixel*>(reference);
   const ptrdiff_t src_stride = reference_stride / sizeof(Pixel);
   auto* dest = static_cast<uint16_t*>(prediction);
-  const int compound_round_offset =
-      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
   int y = 0;
   do {
     int x = 0;
     do {
-      dest[x] = (src[x] << 4) + compound_round_offset;
+      int sum = (bitdepth == 8) ? 0 : ((1 << bitdepth) + (1 << (bitdepth - 1)));
+      sum += src[x];
+      dest[x] = sum << kRoundBitsVertical;
     } while (++x < width);
-
     src += src_stride;
     dest += pred_stride;
   } while (++y < height);
@@ -528,10 +521,13 @@
 void ConvolveCompoundHorizontal_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int horizontal_filter_index, const int /*vertical_filter_index*/,
-    const int inter_round_bits_vertical, const int subpixel_x,
-    const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t pred_stride) {
+    const int subpixel_x, const int /*subpixel_y*/, const int width,
+    const int height, void* prediction, const ptrdiff_t pred_stride) {
+  // All compound functions output to the predictor buffer with |pred_stride|
+  // equal to |width|.
+  assert(pred_stride == width);
+  // Compound functions start at 4x4.
+  assert(width >= 4 && height >= 4);
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
                                            : kInterRoundBitsHorizontal;
@@ -540,19 +536,19 @@
   const ptrdiff_t src_stride = reference_stride / sizeof(Pixel);
   auto* dest = static_cast<uint16_t*>(prediction);
   const int filter_id = (subpixel_x >> 6) & kSubPixelMask;
-  const int bits_shift = kFilterBits - inter_round_bits_vertical;
-  const int compound_round_offset =
-      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
+  // Copy filters must call ConvolveCopy().
+  assert(filter_id != 0);
   int y = 0;
   do {
     int x = 0;
     do {
       int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] * src[x + k];
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] * src[x + k];
       }
-      sum = RightShiftWithRounding(sum, kRoundBitsHorizontal) << bits_shift;
-      dest[x] = sum + compound_round_offset;
+      sum = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1);
+      sum += (bitdepth == 8) ? 0 : kCompoundOffset;
+      dest[x] = sum;
     } while (++x < width);
 
     src += src_stride;
@@ -570,11 +566,14 @@
                                 const ptrdiff_t reference_stride,
                                 const int /*horizontal_filter_index*/,
                                 const int vertical_filter_index,
-                                const int inter_round_bits_vertical,
                                 const int /*subpixel_x*/, const int subpixel_y,
-                                const int /*step_x*/, const int /*step_y*/,
                                 const int width, const int height,
                                 void* prediction, const ptrdiff_t pred_stride) {
+  // All compound functions output to the predictor buffer with |pred_stride|
+  // equal to |width|.
+  assert(pred_stride == width);
+  // Compound functions start at 4x4.
+  assert(width >= 4 && height >= 4);
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
                                            : kInterRoundBitsHorizontal;
@@ -584,23 +583,21 @@
       static_cast<const Pixel*>(reference) - kVerticalOffset * src_stride;
   auto* dest = static_cast<uint16_t*>(prediction);
   const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
-  const int bits_shift = kFilterBits - kRoundBitsHorizontal;
-  const int compound_round_offset =
-      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
+  // Copy filters must call ConvolveCopy().
+  assert(filter_id != 0);
   int y = 0;
   do {
     int x = 0;
     do {
       int sum = 0;
       for (int k = 0; k < kSubPixelTaps; ++k) {
-        sum += kSubPixelFilters[filter_index][filter_id][k] *
+        sum += kHalfSubPixelFilters[filter_index][filter_id][k] *
                src[k * src_stride + x];
       }
-      dest[x] = RightShiftWithRounding(LeftShift(sum, bits_shift),
-                                       inter_round_bits_vertical) +
-                compound_round_offset;
+      sum = RightShiftWithRounding(sum, kRoundBitsHorizontal - 1);
+      sum += (bitdepth == 8) ? 0 : kCompoundOffset;
+      dest[x] = sum;
     } while (++x < width);
-
     src += src_stride;
     dest += pred_stride;
   } while (++y < height);
@@ -616,13 +613,11 @@
 void ConvolveIntraBlockCopy2D_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
-    const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t pred_stride) {
-  const auto* src = reinterpret_cast<const Pixel*>(reference);
+    const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
+    const int height, void* prediction, const ptrdiff_t pred_stride) {
+  const auto* src = static_cast<const Pixel*>(reference);
   const ptrdiff_t src_stride = reference_stride / sizeof(Pixel);
-  auto* dest = reinterpret_cast<Pixel*>(prediction);
+  auto* dest = static_cast<Pixel*>(prediction);
   const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel);
   const int intermediate_height = height + 1;
   uint16_t intermediate_result[kMaxSuperBlockSizeInPixels *
@@ -647,8 +642,8 @@
   do {
     int x = 0;
     do {
-      dest[x] = static_cast<Pixel>(
-          RightShiftWithRounding(intermediate[x] + intermediate[x + width], 2));
+      dest[x] =
+          RightShiftWithRounding(intermediate[x] + intermediate[x + width], 2);
     } while (++x < width);
 
     intermediate += width;
@@ -668,21 +663,18 @@
 void ConvolveIntraBlockCopy1D_C(
     const void* const reference, const ptrdiff_t reference_stride,
     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*inter_round_bits_vertical*/, const int /*subpixel_x*/,
-    const int /*subpixel_y*/, const int /*step_x*/, const int /*step_y*/,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t pred_stride) {
-  const auto* src = reinterpret_cast<const Pixel*>(reference);
+    const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
+    const int height, void* prediction, const ptrdiff_t pred_stride) {
+  const auto* src = static_cast<const Pixel*>(reference);
   const ptrdiff_t src_stride = reference_stride / sizeof(Pixel);
-  auto* dest = reinterpret_cast<Pixel*>(prediction);
+  auto* dest = static_cast<Pixel*>(prediction);
   const ptrdiff_t dest_stride = pred_stride / sizeof(Pixel);
   const ptrdiff_t offset = is_horizontal ? 1 : src_stride;
   int y = 0;
   do {
     int x = 0;
     do {
-      dest[x] = static_cast<Pixel>(
-          RightShiftWithRounding(src[x] + src[x + offset], 1));
+      dest[x] = RightShiftWithRounding(src[x] + src[x + offset], 1);
     } while (++x < width);
 
     src += src_stride;
diff --git a/libgav1/src/dsp/cpu.h b/libgav1/src/dsp/cpu.h
deleted file mode 100644
index 70816f5..0000000
--- a/libgav1/src/dsp/cpu.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Copyright 2019 The libgav1 Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBGAV1_SRC_DSP_CPU_H_
-#define LIBGAV1_SRC_DSP_CPU_H_
-
-#include <cstdint>
-
-namespace libgav1 {
-namespace dsp {
-
-enum CpuFeatures : uint8_t {
-  kSSE2 = 1 << 0,
-#define LIBGAV1_DSP_SSE2 (1 << 0)
-  kSSSE3 = 1 << 1,
-#define LIBGAV1_DSP_SSSE3 (1 << 1)
-  kSSE4_1 = 1 << 2,
-#define LIBGAV1_DSP_SSE4_1 (1 << 2)
-  kAVX = 1 << 3,
-#define LIBGAV1_DSP_AVX (1 << 3)
-  kAVX2 = 1 << 4,
-#define LIBGAV1_DSP_AVX2 (1 << 4)
-  kNEON = 1 << 5,
-#define LIBGAV1_DSP_NEON (1 << 5)
-};
-
-// Returns a bit-wise OR of CpuFeatures supported by this platform.
-uint32_t GetCpuInfo();
-
-}  // namespace dsp
-}  // namespace libgav1
-
-#endif  // LIBGAV1_SRC_DSP_CPU_H_
diff --git a/libgav1/src/dsp/distance_weighted_blend.cc b/libgav1/src/dsp/distance_weighted_blend.cc
index 49326a7..a035fbe 100644
--- a/libgav1/src/dsp/distance_weighted_blend.cc
+++ b/libgav1/src/dsp/distance_weighted_blend.cc
@@ -17,6 +17,7 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <type_traits>
 
 #include "src/dsp/dsp.h"
 #include "src/utils/common.h"
@@ -26,20 +27,17 @@
 namespace {
 
 template <int bitdepth, typename Pixel>
-void DistanceWeightedBlend_C(const uint16_t* prediction_0,
-                             const ptrdiff_t prediction_stride_0,
-                             const uint16_t* prediction_1,
-                             const ptrdiff_t prediction_stride_1,
+void DistanceWeightedBlend_C(const void* prediction_0, const void* prediction_1,
                              const uint8_t weight_0, const uint8_t weight_1,
                              const int width, const int height,
                              void* const dest, const ptrdiff_t dest_stride) {
-  // An offset to cancel offsets used in compound predictor generation that
-  // make intermediate computations non negative.
-  constexpr int compound_round_offset =
-      (16 << (bitdepth + 4)) + (16 << (bitdepth + 3));
   // 7.11.3.2 Rounding variables derivation process
   //   2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7))
   constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
+  using PredType =
+      typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
+  const auto* pred_0 = static_cast<const PredType*>(prediction_0);
+  const auto* pred_1 = static_cast<const PredType*>(prediction_1);
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
 
@@ -47,18 +45,18 @@
   do {
     int x = 0;
     do {
-      // prediction range: 8bpp: [0, 15471] 10bpp: [0, 61983] 12bpp: [0, 62007].
+      // See warp.cc and convolve.cc for detailed prediction ranges.
       // weight_0 + weight_1 = 16.
-      int res = prediction_0[x] * weight_0 + prediction_1[x] * weight_1;
-      res -= compound_round_offset;
+      int res = pred_0[x] * weight_0 + pred_1[x] * weight_1;
+      res -= (bitdepth == 8) ? 0 : kCompoundOffset * 16;
       dst[x] = static_cast<Pixel>(
           Clip3(RightShiftWithRounding(res, inter_post_round_bits + 4), 0,
                 (1 << bitdepth) - 1));
     } while (++x < width);
 
     dst += dst_stride;
-    prediction_0 += prediction_stride_0;
-    prediction_1 += prediction_stride_1;
+    pred_0 += width;
+    pred_1 += width;
   } while (++y < height);
 }
 
diff --git a/libgav1/src/dsp/dsp.cc b/libgav1/src/dsp/dsp.cc
index 381c047..c1df276 100644
--- a/libgav1/src/dsp/dsp.cc
+++ b/libgav1/src/dsp/dsp.cc
@@ -16,10 +16,10 @@
 
 #include <mutex>  // NOLINT (unapproved c++11 header)
 
+#include "src/dsp/arm/weight_mask_neon.h"
 #include "src/dsp/average_blend.h"
 #include "src/dsp/cdef.h"
 #include "src/dsp/convolve.h"
-#include "src/dsp/cpu.h"
 #include "src/dsp/distance_weighted_blend.h"
 #include "src/dsp/film_grain.h"
 #include "src/dsp/intra_edge.h"
@@ -28,8 +28,13 @@
 #include "src/dsp/loop_filter.h"
 #include "src/dsp/loop_restoration.h"
 #include "src/dsp/mask_blend.h"
+#include "src/dsp/motion_field_projection.h"
+#include "src/dsp/motion_vector_search.h"
 #include "src/dsp/obmc.h"
+#include "src/dsp/super_res.h"
 #include "src/dsp/warp.h"
+#include "src/dsp/weight_mask.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp_internal {
@@ -68,12 +73,17 @@
     LoopFilterInit_C();
     LoopRestorationInit_C();
     MaskBlendInit_C();
+    MotionFieldProjectionInit_C();
+    MotionVectorSearchInit_C();
     ObmcInit_C();
+    SuperResInit_C();
     WarpInit_C();
+    WeightMaskInit_C();
 #if LIBGAV1_ENABLE_SSE4_1
     const uint32_t cpu_features = GetCpuInfo();
     if ((cpu_features & kSSE4_1) != 0) {
       AverageBlendInit_SSE4_1();
+      CdefInit_SSE4_1();
       ConvolveInit_SSE4_1();
       DistanceWeightedBlendInit_SSE4_1();
       IntraEdgeInit_SSE4_1();
@@ -83,13 +93,21 @@
       InverseTransformInit_SSE4_1();
       LoopFilterInit_SSE4_1();
       LoopRestorationInit_SSE4_1();
+      MaskBlendInit_SSE4_1();
+      MotionFieldProjectionInit_SSE4_1();
+      MotionVectorSearchInit_SSE4_1();
       ObmcInit_SSE4_1();
+      SuperResInit_SSE4_1();
+      WarpInit_SSE4_1();
+      WeightMaskInit_SSE4_1();
     }
 #endif  // LIBGAV1_ENABLE_SSE4_1
 #if LIBGAV1_ENABLE_NEON
     AverageBlendInit_NEON();
+    CdefInit_NEON();
     ConvolveInit_NEON();
     DistanceWeightedBlendInit_NEON();
+    FilmGrainInit_NEON();
     IntraEdgeInit_NEON();
     IntraPredCflInit_NEON();
     IntraPredDirectionalInit_NEON();
@@ -100,8 +118,12 @@
     LoopFilterInit_NEON();
     LoopRestorationInit_NEON();
     MaskBlendInit_NEON();
+    MotionFieldProjectionInit_NEON();
+    MotionVectorSearchInit_NEON();
     ObmcInit_NEON();
+    SuperResInit_NEON();
     WarpInit_NEON();
+    WeightMaskInit_NEON();
 #endif  // LIBGAV1_ENABLE_NEON
   });
 }
diff --git a/libgav1/src/dsp/dsp.h b/libgav1/src/dsp/dsp.h
index b9f20de..1fa1560 100644
--- a/libgav1/src/dsp/dsp.h
+++ b/libgav1/src/dsp/dsp.h
@@ -23,7 +23,10 @@
 
 #include "src/dsp/common.h"
 #include "src/dsp/constants.h"
-#include "src/dsp/cpu.h"
+#include "src/dsp/film_grain_common.h"
+#include "src/utils/cpu.h"
+#include "src/utils/reference_info.h"
+#include "src/utils/types.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -32,29 +35,6 @@
 #define LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS 0
 #endif
 
-#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))
-#define LIBGAV1_X86_MSVC
-#endif
-
-#if !defined(LIBGAV1_ENABLE_SSE4_1)
-#if defined(__SSE4_1__) || defined(LIBGAV1_X86_MSVC)
-#define LIBGAV1_ENABLE_SSE4_1 1
-#else
-#define LIBGAV1_ENABLE_SSE4_1 0
-#endif
-#endif  // !defined(LIBGAV1_ENABLE_SSE4_1)
-
-#undef LIBGAV1_X86_MSVC
-
-#if !defined(LIBGAV1_ENABLE_NEON)
-#if defined(__ARM_NEON__) || defined(__aarch64__) || \
-    (defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)))
-#define LIBGAV1_ENABLE_NEON 1
-#else
-#define LIBGAV1_ENABLE_NEON 0
-#endif
-#endif  // !defined(LIBGAV1_ENABLE_NEON)
-
 enum IntraPredictor : uint8_t {
   kIntraPredictorDcFill,
   kIntraPredictorDcTop,
@@ -318,7 +298,7 @@
 // 7.13.3).
 // Apply the inverse transforms and add the residual to the destination frame
 // for the transform type and block size |tx_size| starting at position
-// |start_x| and |start_y|.  |dst_frame| is a pointer to an Array2d. |is_row|
+// |start_x| and |start_y|.  |dst_frame| is a pointer to an Array2D. |is_row|
 // signals the direction of the transform loop. |non_zero_coeff_count| is the
 // number of non zero coefficients in the block.
 using InverseTransformAddFunc = void (*)(TransformType tx_type,
@@ -348,28 +328,42 @@
                                    int* direction, int* variance);
 
 // Cdef filtering function signature. Section 7.15.3.
-// |source| is a pointer to the input block. |source_stride| is given in bytes.
-// |rows4x4| and |columns4x4| are frame sizes in units of 4x4 pixels.
-// |curr_x| and |curr_y| are current position in units of pixels.
-// |subsampling_x|, |subsampling_y| are the subsampling factors of current
-// plane.
+// |source| is a pointer to the input block padded with kCdefLargeValue if at a
+// frame border. |source_stride| is given in units of uint16_t.
+// |block_width|, |block_height| are the width/height of the input block.
 // |primary_strength|, |secondary_strength|, and |damping| are Cdef filtering
 // parameters.
 // |direction| is the filtering direction.
 // |dest| is the output buffer. |dest_stride| is given in bytes.
-using CdefFilteringFunc = void (*)(const void* source, ptrdiff_t source_stride,
-                                   int rows4x4, int columns4x4, int curr_x,
-                                   int curr_y, int subsampling_x,
-                                   int subsampling_y, int primary_strength,
-                                   int secondary_strength, int damping,
-                                   int direction, void* dest,
+using CdefFilteringFunc = void (*)(const uint16_t* source,
+                                   ptrdiff_t source_stride, int block_height,
+                                   int primary_strength, int secondary_strength,
+                                   int damping, int direction, void* dest,
                                    ptrdiff_t dest_stride);
 
+// The first index is block width: [0]: 4, [1]: 8. The second is based on
+// non-zero strengths: [0]: |primary_strength| and |secondary_strength|, [1]:
+// |primary_strength| only, [2]: |secondary_strength| only.
+using CdefFilteringFuncs = CdefFilteringFunc[2][3];
+
+// Upscaling process function signature. Section 7.16.
+// Operates on a single row.
+// |source| is the input frame buffer at the given row.
+// |dest| is the output row.
+// |upscaled_width| is the width of the output frame.
+// |step| is the number of subpixels to move the kernel for the next destination
+// pixel.
+// |initial_subpixel_x| is a base offset from which |step| increments.
+using SuperResRowFunc = void (*)(const void* source, const int upscaled_width,
+                                 const int initial_subpixel_x, const int step,
+                                 void* const dest);
+
 // Loop restoration function signature. Sections 7.16, 7.17.
 // |source| is the input frame buffer, which is deblocked and cdef filtered.
 // |dest| is the output.
 // |restoration_info| contains loop restoration information, such as filter
-// type, strength. |source| and |dest| share the same stride given in bytes.
+// type, strength.
+// |source_stride| and |dest_stride| are given in pixels.
 // |buffer| contains buffers required for self guided filter and wiener filter.
 // They must be initialized before calling.
 using LoopRestorationFunc = void (*)(
@@ -389,39 +383,81 @@
 // |vertical_filter_index|/|horizontal_filter_index| is the index to
 // retrieve the type of filter to be applied for vertical/horizontal direction
 // from the filter lookup table 'kSubPixelFilters'.
-// |inter_round_bits_vertical| is the rounding precision used after vertical
-// filtering (7 or 11). kInterRoundBitsHorizontal &
-// kInterRoundBitsHorizontal12bpp can be used after the horizontal pass.
+// |subpixel_x| and |subpixel_y| are starting positions in units of 1/1024.
+// |width| and |height| are width and height of the block to be filtered.
+// |ref_last_x| and |ref_last_y| are the last pixel of the reference frame in
+// x/y direction.
+// |prediction| is the output block (output frame buffer).
+// Rounding precision is derived from the function being called. For horizontal
+// filtering kInterRoundBitsHorizontal & kInterRoundBitsHorizontal12bpp will be
+// used. For compound vertical filtering kInterRoundBitsCompoundVertical will be
+// used. Otherwise kInterRoundBitsVertical & kInterRoundBitsVertical12bpp will
+// be used.
+using ConvolveFunc = void (*)(const void* reference, ptrdiff_t reference_stride,
+                              int horizontal_filter_index,
+                              int vertical_filter_index, int subpixel_x,
+                              int subpixel_y, int width, int height,
+                              void* prediction, ptrdiff_t pred_stride);
+
+// Convolve functions signature. Each points to one convolve function with
+// a specific setting:
+// ConvolveFunc[is_intra_block_copy][is_compound][has_vertical_filter]
+// [has_horizontal_filter].
+// If is_compound is false, the prediction is clipped to Pixel.
+// If is_compound is true, the range of prediction is:
+//   8bpp:  [-5132,  9212] (int16_t)
+//   10bpp: [ 3988, 61532] (uint16_t)
+//   12bpp: [ 3974, 61559] (uint16_t)
+// See src/dsp/convolve.cc
+using ConvolveFuncs = ConvolveFunc[2][2][2][2];
+
+// Convolve + scale function signature. Section 7.11.3.4.
+// This function applies a horizontal filter followed by a vertical filter.
+// |reference| is the input block (reference frame buffer). |reference_stride|
+// is the corresponding frame stride.
+// |vertical_filter_index|/|horizontal_filter_index| is the index to
+// retrieve the type of filter to be applied for vertical/horizontal direction
+// from the filter lookup table 'kSubPixelFilters'.
 // |subpixel_x| and |subpixel_y| are starting positions in units of 1/1024.
 // |step_x| and |step_y| are step sizes in units of 1/1024 of a pixel.
 // |width| and |height| are width and height of the block to be filtered.
 // |ref_last_x| and |ref_last_y| are the last pixel of the reference frame in
 // x/y direction.
 // |prediction| is the output block (output frame buffer).
-using ConvolveFunc = void (*)(const void* reference, ptrdiff_t reference_stride,
-                              int vertical_filter_index,
-                              int horizontal_filter_index,
-                              int inter_round_bits_vertical, int subpixel_x,
-                              int subpixel_y, int step_x, int step_y, int width,
-                              int height, void* prediction,
-                              ptrdiff_t pred_stride);
-
-// Convolve functions signature. Each points to one convolve function with
-// a specific setting:
-// ConvolveFunc[is_intra_block_copy][is_compound][has_vertical_filter]
-// [has_horizontal_filter].
-// If is_compound is false, the prediction is clipped to pixel.
-// If is_compound is true, the range of prediction is:
-//   8bpp: [0, 15471]
-//   10bpp: [0, 61983]
-//   12bpp: [0, 62007]
-// See:
-// https://docs.google.com/document/d/1f5YlLk02ETNxpilvsmjBtWgDXjtZYO33hjl6bAdvmxc
-using ConvolveFuncs = ConvolveFunc[2][2][2][2];
+// Rounding precision is derived from the function being called. For horizontal
+// filtering kInterRoundBitsHorizontal & kInterRoundBitsHorizontal12bpp will be
+// used. For compound vertical filtering kInterRoundBitsCompoundVertical will be
+// used. Otherwise kInterRoundBitsVertical & kInterRoundBitsVertical12bpp will
+// be used.
+using ConvolveScaleFunc = void (*)(const void* reference,
+                                   ptrdiff_t reference_stride,
+                                   int horizontal_filter_index,
+                                   int vertical_filter_index, int subpixel_x,
+                                   int subpixel_y, int step_x, int step_y,
+                                   int width, int height, void* prediction,
+                                   ptrdiff_t pred_stride);
 
 // Convolve functions signature for scaling version.
 // 0: single predictor. 1: compound predictor.
-using ConvolveScaleFuncs = ConvolveFunc[2];
+using ConvolveScaleFuncs = ConvolveScaleFunc[2];
+
+// Weight mask function signature. Section 7.11.3.12.
+// |prediction_0| is the first input block.
+// |prediction_1| is the second input block. Both blocks are int16_t* when
+// bitdepth == 8 and uint16_t* otherwise.
+// |width| and |height| are the prediction width and height.
+// The stride for the input buffers is equal to |width|.
+// The valid range of block size is [8x8, 128x128] for the luma plane.
+// |mask| is the output buffer. |mask_stride| is the output buffer stride.
+using WeightMaskFunc = void (*)(const void* prediction_0,
+                                const void* prediction_1, uint8_t* mask,
+                                ptrdiff_t mask_stride);
+
+// Weight mask functions signature. The dimensions (in order) are:
+//   * Width index (4 => 0, 8 => 1, 16 => 2 and so on).
+//   * Height index (4 => 0, 8 => 1, 16 => 2 and so on).
+//   * mask_is_inverse.
+using WeightMaskFuncs = WeightMaskFunc[6][6][2];
 
 // Average blending function signature.
 // Two predictors are averaged to generate the output.
@@ -429,15 +465,14 @@
 // range of Pixel value.
 // Average blending is in the bottom of Section 7.11.3.1 (COMPOUND_AVERAGE).
 // |prediction_0| is the first input block.
-// |prediction_1| is the second input block.
-// |prediction_stride_0| and |prediction_stride_1| are corresponding strides.
+// |prediction_1| is the second input block. Both blocks are int16_t* when
+// bitdepth == 8 and uint16_t* otherwise.
 // |width| and |height| are the same for the first and second input blocks.
+// The stride for the input buffers is equal to |width|.
 // The valid range of block size is [8x8, 128x128] for the luma plane.
 // |dest| is the output buffer. |dest_stride| is the output buffer stride.
-using AverageBlendFunc = void (*)(const uint16_t* prediction_0,
-                                  ptrdiff_t prediction_stride_0,
-                                  const uint16_t* prediction_1,
-                                  ptrdiff_t prediction_stride_1, int width,
+using AverageBlendFunc = void (*)(const void* prediction_0,
+                                  const void* prediction_1, int width,
                                   int height, void* dest,
                                   ptrdiff_t dest_stride);
 
@@ -447,19 +482,18 @@
 // This function takes two blocks (inter frame prediction) and produces a
 // weighted output.
 // |prediction_0| is the first input block.
-// |prediction_1| is the second input block.
-// |prediction_stride_0| and |prediction_stride_1| are corresponding strides.
+// |prediction_1| is the second input block. Both blocks are int16_t* when
+// bitdepth == 8 and uint16_t* otherwise.
 // |weight_0| is the weight for the first block. It is derived from the relative
 // distance of the first reference frame and the current frame.
 // |weight_1| is the weight for the second block. It is derived from the
 // relative distance of the second reference frame and the current frame.
 // |width| and |height| are the same for the first and second input blocks.
+// The stride for the input buffers is equal to |width|.
 // The valid range of block size is [8x8, 128x128] for the luma plane.
 // |dest| is the output buffer. |dest_stride| is the output buffer stride.
-using DistanceWeightedBlendFunc = void (*)(const uint16_t* prediction_0,
-                                           ptrdiff_t prediction_stride_0,
-                                           const uint16_t* prediction_1,
-                                           ptrdiff_t prediction_stride_1,
+using DistanceWeightedBlendFunc = void (*)(const void* prediction_0,
+                                           const void* prediction_1,
                                            uint8_t weight_0, uint8_t weight_1,
                                            int width, int height, void* dest,
                                            ptrdiff_t dest_stride);
@@ -469,11 +503,16 @@
 // output block |dest|. The blending is a weighted average process, controlled
 // by values of the mask.
 // |prediction_0| is the first input block. When prediction mode is inter_intra
-// (or wedge_inter_intra), this refers to the inter frame prediction.
-// |prediction_stride_0| is the stride, given in units of uint16_t.
+// (or wedge_inter_intra), this refers to the inter frame prediction. It is
+// int16_t* when bitdepth == 8 and uint16_t* otherwise.
+// The stride for |prediction_0| is equal to |width|.
 // |prediction_1| is the second input block. When prediction mode is inter_intra
-// (or wedge_inter_intra), this refers to the intra frame prediction.
-// |prediction_stride_1| is the stride, given in units of uint16_t.
+// (or wedge_inter_intra), this refers to the intra frame prediction and uses
+// Pixel values. It is only used for intra frame prediction when bitdepth >= 10.
+// It is int16_t* when bitdepth == 8 and uint16_t* otherwise.
+// |prediction_stride_1| is the stride, given in units of [u]int16_t. When
+// |is_inter_intra| is false (compound prediction) then |prediction_stride_1| is
+// equal to |width|.
 // |mask| is an integer array, whose value indicates the weight of the blending.
 // |mask_stride| is corresponding stride.
 // |width|, |height| are the same for both input blocks.
@@ -489,9 +528,8 @@
 // |is_wedge_inter_intra| indicates if the mask is for the wedge prediction.
 // |dest| is the output block.
 // |dest_stride| is the corresponding stride for dest.
-using MaskBlendFunc = void (*)(const uint16_t* prediction_0,
-                               ptrdiff_t prediction_stride_0,
-                               const uint16_t* prediction_1,
+using MaskBlendFunc = void (*)(const void* prediction_0,
+                               const void* prediction_1,
                                ptrdiff_t prediction_stride_1,
                                const uint8_t* mask, ptrdiff_t mask_stride,
                                int width, int height, void* dest,
@@ -502,6 +540,22 @@
 // MaskBlendFunc[subsampling_x + subsampling_y][is_inter_intra].
 using MaskBlendFuncs = MaskBlendFunc[3][2];
 
+// This function is similar to the MaskBlendFunc. It is only used when
+// |is_inter_intra| is true and |bitdepth| == 8.
+// |prediction_[01]| are Pixel values (uint8_t).
+// |prediction_1| is also the output buffer.
+using InterIntraMaskBlendFunc8bpp = void (*)(const uint8_t* prediction_0,
+                                             uint8_t* prediction_1,
+                                             ptrdiff_t prediction_stride_1,
+                                             const uint8_t* mask,
+                                             ptrdiff_t mask_stride, int width,
+                                             int height);
+
+// InterIntra8bpp mask blending functions signature. When is_wedge_inter_intra
+// is false, the function at index 0 must be used. Otherwise, the function at
+// index subsampling_x + subsampling_y must be used.
+using InterIntraMaskBlendFuncs8bpp = InterIntraMaskBlendFunc8bpp[3];
+
 // Obmc (overlapped block motion compensation) blending function signature.
 // Section 7.11.3.10.
 // This function takes two blocks and produces a blended output stored into the
@@ -536,9 +590,6 @@
 //     z .  y'  =   m4 m5 m1 *  y
 //          1]      m6 m7 1)    1]
 // |subsampling_x/y| is the current frame's plane subsampling factor.
-// |inter_round_bits_vertical| is the rounding precision used after vertical
-// filtering (7 or 11). kInterRoundBitsHorizontal &
-// kInterRoundBitsHorizontal12bpp can be used for the horizontal pass.
 // |block_start_x| and |block_start_y| are the starting position the current
 // coding block.
 // |block_width| and |block_height| are width and height of the current coding
@@ -546,79 +597,233 @@
 // |alpha|, |beta|, |gamma|, |delta| are valid warp parameters. See the
 // comments in the definition of struct GlobalMotion for the range of their
 // values.
-// |dest| is the output buffer. It is a predictor, whose type is int16_t.
-// |dest_stride| is the stride, in units of int16_t.
+// |dest| is the output buffer of type Pixel. The output values are clipped to
+// Pixel values.
+// |dest_stride| is the stride, in units of bytes.
+// Rounding precision is derived from the function being called. For horizontal
+// filtering kInterRoundBitsHorizontal & kInterRoundBitsHorizontal12bpp will be
+// used. For vertical filtering kInterRoundBitsVertical &
+// kInterRoundBitsVertical12bpp will be used.
 //
-// NOTE: WarpFunc assumes the source frame has left and right borders that
-// extend the frame boundary pixels. The left and right borders must be at
-// least 13 pixels wide. In addition, Warp_NEON() may read up to 14 bytes after
-// a row in the |source| buffer. Therefore, there must be at least one extra
-// padding byte after the right border of the last row in the source buffer.
+// NOTE: WarpFunc assumes the source frame has left, right, top, and bottom
+// borders that extend the frame boundary pixels.
+// * The left and right borders must be at least 13 pixels wide. In addition,
+//   Warp_NEON() may read up to 14 bytes after a row in the |source| buffer.
+//   Therefore, there must be at least one extra padding byte after the right
+//   border of the last row in the source buffer.
+// * The top and bottom borders must be at least 13 pixels high.
 using WarpFunc = void (*)(const void* source, ptrdiff_t source_stride,
                           int source_width, int source_height,
                           const int* warp_params, int subsampling_x,
-                          int subsampling_y, int inter_round_bits_vertical,
-                          int block_start_x, int block_start_y, int block_width,
-                          int block_height, int16_t alpha, int16_t beta,
-                          int16_t gamma, int16_t delta, uint16_t* dest,
-                          ptrdiff_t dest_stride);
+                          int subsampling_y, int block_start_x,
+                          int block_start_y, int block_width, int block_height,
+                          int16_t alpha, int16_t beta, int16_t gamma,
+                          int16_t delta, void* dest, ptrdiff_t dest_stride);
 
-// Film grain synthesis function signature. Section 7.18.3.
-// This function generates film grain noise and blends the noise with the
-// decoded frame.
-// |source_plane_y|, |source_plane_u|, and |source_plane_v| are the plane
-// buffers of the decoded frame. They are blended with the film grain noise and
-// written to |dest_plane_y|, |dest_plane_u|, and |dest_plane_v| as final
-// output for display. |source_plane_p| and |dest_plane_p| (where p is y, u, or
-// v) may point to the same buffer, in which case the film grain noise is added
-// in place.
-// |film_grain_params| are parameters read from frame header.
-// |is_monochrome| is true indicates only Y plane needs to be processed.
-// |color_matrix_is_identity| is true if the matrix_coefficients field in the
-// sequence header's color config is is MC_IDENTITY.
-// |width| is the upscaled width of the frame.
-// |height| is the frame height.
-// |subsampling_x| and |subsampling_y| are subsamplings for UV planes, not used
-// if |is_monochrome| is true.
-// Returns true on success, or false on failure (e.g., out of memory).
-using FilmGrainSynthesisFunc = bool (*)(
+// Warp for compound predictions. Section 7.11.3.5.
+// Similar to WarpFunc, but |dest| is a uint16_t predictor buffer,
+// |dest_stride| is given in units of uint16_t and |inter_round_bits_vertical|
+// is always 7 (kCompoundInterRoundBitsVertical).
+// Rounding precision is derived from the function being called. For horizontal
+// filtering kInterRoundBitsHorizontal & kInterRoundBitsHorizontal12bpp will be
+// used. For vertical filtering kInterRoundBitsCompondVertical will be used.
+using WarpCompoundFunc = WarpFunc;
+
+constexpr int kNumAutoRegressionLags = 4;
+// Applies an auto-regressive filter to the white noise in |luma_grain_buffer|.
+// Section 7.18.3.3, second code block
+// |params| are parameters read from frame header, mainly providing
+// auto_regression_coeff_y for the filter and auto_regression_shift to right
+// shift the filter sum by. Note: This method assumes
+// params.auto_regression_coeff_lag is not 0. Do not call this method if
+// params.auto_regression_coeff_lag is 0.
+using LumaAutoRegressionFunc = void (*)(const FilmGrainParams& params,
+                                        void* luma_grain_buffer);
+// Function index is auto_regression_coeff_lag - 1.
+using LumaAutoRegressionFuncs =
+    LumaAutoRegressionFunc[kNumAutoRegressionLags - 1];
+
+// Applies an auto-regressive filter to the white noise in u_grain and v_grain.
+// Section 7.18.3.3, third code block
+// The |luma_grain_buffer| provides samples that are added to the autoregressive
+// sum when num_y_points > 0.
+// |u_grain_buffer| and |v_grain_buffer| point to the buffers of chroma noise
+// that were generated from the stored Gaussian sequence, and are overwritten
+// with the results of the autoregressive filter. |params| are parameters read
+// from frame header, mainly providing auto_regression_coeff_u and
+// auto_regression_coeff_v for each chroma plane's filter, and
+// auto_regression_shift to right shift the filter sums by.
+using ChromaAutoRegressionFunc = void (*)(const FilmGrainParams& params,
+                                          const void* luma_grain_buffer,
+                                          int subsampling_x, int subsampling_y,
+                                          void* u_grain_buffer,
+                                          void* v_grain_buffer);
+using ChromaAutoRegressionFuncs =
+    ChromaAutoRegressionFunc[/*use_luma*/ 2][kNumAutoRegressionLags];
+
+// Build an image-wide "stripe" of grain noise for every 32 rows in the image.
+// Section 7.18.3.5, first code block.
+// Each 32x32 luma block is copied at a random offset specified via
+// |grain_seed| from the grain template produced by autoregression, and the same
+// is done for chroma grains, subject to subsampling.
+// |width| and |height| are the dimensions of the overall image.
+// |noise_stripes_buffer| points to an Array2DView with one row for each stripe.
+// Because this function treats all planes identically and independently, it is
+// simplified to take one grain buffer at a time. This means duplicating some
+// random number generations, but that work can be reduced in other ways.
+using ConstructNoiseStripesFunc = void (*)(const void* grain_buffer,
+                                           int grain_seed, int width,
+                                           int height, int subsampling_x,
+                                           int subsampling_y,
+                                           void* noise_stripes_buffer);
+using ConstructNoiseStripesFuncs =
+    ConstructNoiseStripesFunc[/*overlap_flag*/ 2];
+
+// Compute the one or two overlap rows for each stripe copied to the noise
+// image.
+// Section 7.18.3.5, second code block. |width| and |height| are the
+// dimensions of the overall image. |noise_stripes_buffer| points to an
+// Array2DView with one row for each stripe. |noise_image_buffer| points to an
+// Array2D containing the allocated plane for this frame. Because this function
+// treats all planes identically and independently, it is simplified to take one
+// grain buffer at a time.
+using ConstructNoiseImageOverlapFunc =
+    void (*)(const void* noise_stripes_buffer, int width, int height,
+             int subsampling_x, int subsampling_y, void* noise_image_buffer);
+
+// Populate a scaling lookup table with interpolated values of a piecewise
+// linear function where values in |point_value| are mapped to the values in
+// |point_scaling|.
+// |num_points| can be between 0 and 15. When 0, the lookup table is set to
+// zero.
+// |point_value| and |point_scaling| have |num_points| valid elements.
+using InitializeScalingLutFunc = void (*)(
+    int num_points, const uint8_t point_value[], const uint8_t point_scaling[],
+    uint8_t scaling_lut[kScalingLookupTableSize]);
+
+// Blend noise with image. Section 7.18.3.5, third code block.
+// |width| is the width of each row, while |height| is how many rows to compute.
+// |start_height| is an offset for the noise image, to support multithreading.
+// |min_value|, |max_luma|, and |max_chroma| are computed by the caller of these
+// functions, according to the code in the spec.
+// |source_plane_y| and |source_plane_uv| are the plane buffers of the decoded
+// frame. They are blended with the film grain noise and written to
+// |dest_plane_y| and |dest_plane_uv| as final output for display.
+// source_plane_* and dest_plane_* may point to the same buffer, in which case
+// the film grain noise is added in place.
+// |scaling_lut_y|  and |scaling_lut| represent a piecewise linear mapping from
+// the frame's raw pixel value, to a scaling factor for the noise sample.
+// |scaling_shift| is applied as a right shift after scaling, so that scaling
+// down is possible. It is found in FilmGrainParams, but supplied directly to
+// BlendNoiseWithImageLumaFunc because it's the only member used.
+using BlendNoiseWithImageLumaFunc =
+    void (*)(const void* noise_image_ptr, int min_value, int max_value,
+             int scaling_shift, int width, int height, int start_height,
+             const uint8_t scaling_lut_y[kScalingLookupTableSize],
+             const void* source_plane_y, ptrdiff_t source_stride_y,
+             void* dest_plane_y, ptrdiff_t dest_stride_y);
+
+using BlendNoiseWithImageChromaFunc = void (*)(
+    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
+    int min_value, int max_value, int width, int height, int start_height,
+    int subsampling_x, int subsampling_y,
+    const uint8_t scaling_lut[kScalingLookupTableSize],
     const void* source_plane_y, ptrdiff_t source_stride_y,
-    const void* source_plane_u, ptrdiff_t source_stride_u,
-    const void* source_plane_v, ptrdiff_t source_stride_v,
-    const FilmGrainParams& film_grain_params, bool is_monochrome,
-    bool color_matrix_is_identity, int width, int height, int subsampling_x,
-    int subsampling_y, void* dest_plane_y, ptrdiff_t dest_stride_y,
-    void* dest_plane_u, ptrdiff_t dest_stride_u, void* dest_plane_v,
-    ptrdiff_t dest_stride_v);
+    const void* source_plane_uv, ptrdiff_t source_stride_uv,
+    void* dest_plane_uv, ptrdiff_t dest_stride_uv);
+
+using BlendNoiseWithImageChromaFuncs =
+    BlendNoiseWithImageChromaFunc[/*chroma_scaling_from_luma*/ 2];
+
 //------------------------------------------------------------------------------
 
+struct FilmGrainFuncs {
+  LumaAutoRegressionFuncs luma_auto_regression;
+  ChromaAutoRegressionFuncs chroma_auto_regression;
+  ConstructNoiseStripesFuncs construct_noise_stripes;
+  ConstructNoiseImageOverlapFunc construct_noise_image_overlap;
+  InitializeScalingLutFunc initialize_scaling_lut;
+  BlendNoiseWithImageLumaFunc blend_noise_luma;
+  BlendNoiseWithImageChromaFuncs blend_noise_chroma;
+};
+
+// Motion field projection function signature. Section 7.9.
+// |reference_info| provides reference information for motion field projection.
+// |reference_to_current_with_sign| is the precalculated reference frame id
+// distance from current frame.
+// |dst_sign| is -1 for LAST_FRAME and LAST2_FRAME, or 0 (1 in spec) for others.
+// |y8_start| and |y8_end| are the start and end 8x8 rows of the current tile.
+// |x8_start| and |x8_end| are the start and end 8x8 columns of the current
+// tile.
+// |motion_field| is the output which saves the projected motion field
+// information.
+using MotionFieldProjectionKernelFunc = void (*)(
+    const ReferenceInfo& reference_info, int reference_to_current_with_sign,
+    int dst_sign, int y8_start, int y8_end, int x8_start, int x8_end,
+    TemporalMotionField* motion_field);
+
+// Compound temporal motion vector projection function signature.
+// Section 7.9.3 and 7.10.2.10.
+// |temporal_mvs| is the set of temporal reference motion vectors.
+// |temporal_reference_offsets| specifies the number of frames covered by the
+// original motion vector.
+// |reference_offsets| specifies the number of frames to be covered by the
+// projected motion vector.
+// |count| is the number of the temporal motion vectors.
+// |candidate_mvs| is the set of projected motion vectors.
+using MvProjectionCompoundFunc = void (*)(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offsets[2], int count,
+    CompoundMotionVector* candidate_mvs);
+
+// Single temporal motion vector projection function signature.
+// Section 7.9.3 and 7.10.2.10.
+// |temporal_mvs| is the set of temporal reference motion vectors.
+// |temporal_reference_offsets| specifies the number of frames covered by the
+// original motion vector.
+// |reference_offset| specifies the number of frames to be covered by the
+// projected motion vector.
+// |count| is the number of the temporal motion vectors.
+// |candidate_mvs| is the set of projected motion vectors.
+using MvProjectionSingleFunc = void (*)(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    int reference_offset, int count, MotionVector* candidate_mvs);
+
 struct Dsp {
-  IntraPredictorFuncs intra_predictors;
+  AverageBlendFunc average_blend;
+  CdefDirectionFunc cdef_direction;
+  CdefFilteringFuncs cdef_filters;
+  CflIntraPredictorFuncs cfl_intra_predictors;
+  CflSubsamplerFuncs cfl_subsamplers;
+  ConvolveFuncs convolve;
+  ConvolveScaleFuncs convolve_scale;
   DirectionalIntraPredictorZone1Func directional_intra_predictor_zone1;
   DirectionalIntraPredictorZone2Func directional_intra_predictor_zone2;
   DirectionalIntraPredictorZone3Func directional_intra_predictor_zone3;
+  DistanceWeightedBlendFunc distance_weighted_blend;
+  FilmGrainFuncs film_grain;
   FilterIntraPredictorFunc filter_intra_predictor;
-  CflIntraPredictorFuncs cfl_intra_predictors;
-  CflSubsamplerFuncs cfl_subsamplers;
+  InterIntraMaskBlendFuncs8bpp inter_intra_mask_blend_8bpp;
   IntraEdgeFilterFunc intra_edge_filter;
   IntraEdgeUpsamplerFunc intra_edge_upsampler;
+  IntraPredictorFuncs intra_predictors;
   InverseTransformAddFuncs inverse_transforms;
   LoopFilterFuncs loop_filters;
-  CdefDirectionFunc cdef_direction;
-  CdefFilteringFunc cdef_filter;
   LoopRestorationFuncs loop_restorations;
-  ConvolveFuncs convolve;
-  ConvolveScaleFuncs convolve_scale;
-  AverageBlendFunc average_blend;
-  DistanceWeightedBlendFunc distance_weighted_blend;
   MaskBlendFuncs mask_blend;
+  MotionFieldProjectionKernelFunc motion_field_projection_kernel;
+  MvProjectionCompoundFunc mv_projection_compound[3];
+  MvProjectionSingleFunc mv_projection_single[3];
   ObmcBlendFuncs obmc_blend;
+  SuperResRowFunc super_res_row;
+  WarpCompoundFunc warp_compound;
   WarpFunc warp;
-  FilmGrainSynthesisFunc film_grain_synthesis;
+  WeightMaskFuncs weight_mask;
 };
 
-// Initializes function pointers based on build config and runtime environment.
-// Must be called once before first use. This function is thread-safe.
+// Initializes function pointers based on build config and runtime
+// environment. Must be called once before first use. This function is
+// thread-safe.
 void DspInit();
 
 // Returns the appropriate Dsp table for |bitdepth| or nullptr if one doesn't
@@ -645,10 +850,10 @@
 //  true and can be omitted.
 #define DSP_ENABLED_8BPP_SSE4_1(func)  \
   (LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
-   LIBGAV1_Dsp8bpp_##func == LIBGAV1_DSP_SSE4_1)
+   LIBGAV1_Dsp8bpp_##func == LIBGAV1_CPU_SSE4_1)
 #define DSP_ENABLED_10BPP_SSE4_1(func) \
   (LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
-   LIBGAV1_Dsp10bpp_##func == LIBGAV1_DSP_SSE4_1)
+   LIBGAV1_Dsp10bpp_##func == LIBGAV1_CPU_SSE4_1)
 
 // Returns the appropriate Dsp table for |bitdepth| or nullptr if one doesn't
 // exist. This version is meant for use by test or dsp/*Init() functions only.
diff --git a/libgav1/src/dsp/film_grain.cc b/libgav1/src/dsp/film_grain.cc
index 924ff40..2ee290b 100644
--- a/libgav1/src/dsp/film_grain.cc
+++ b/libgav1/src/dsp/film_grain.cc
@@ -22,513 +22,32 @@
 #include <new>
 
 #include "src/dsp/common.h"
+#include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
+#include "src/dsp/film_grain_common.h"
+#include "src/utils/array_2d.h"
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/logging.h"
 
 namespace libgav1 {
 namespace dsp {
+namespace film_grain {
 namespace {
 
-// The kGaussianSequence array contains random samples from a Gaussian
-// distribution with zero mean and standard deviation of about 512 clipped to
-// the range of [-2048, 2047] (representable by a signed integer using 12 bits
-// of precision) and rounded to the nearest multiple of 4.
-//
-// Note: It is important that every element in the kGaussianSequence array be
-// less than 2040, so that RightShiftWithRounding(kGaussianSequence[i], 4) is
-// less than 128 for bitdepth=8 (GrainType=int8_t).
-constexpr int16_t kGaussianSequence[/*2048*/] = {
-    56,    568,   -180,  172,   124,   -84,   172,   -64,   -900,  24,   820,
-    224,   1248,  996,   272,   -8,    -916,  -388,  -732,  -104,  -188, 800,
-    112,   -652,  -320,  -376,  140,   -252,  492,   -168,  44,    -788, 588,
-    -584,  500,   -228,  12,    680,   272,   -476,  972,   -100,  652,  368,
-    432,   -196,  -720,  -192,  1000,  -332,  652,   -136,  -552,  -604, -4,
-    192,   -220,  -136,  1000,  -52,   372,   -96,   -624,  124,   -24,  396,
-    540,   -12,   -104,  640,   464,   244,   -208,  -84,   368,   -528, -740,
-    248,   -968,  -848,  608,   376,   -60,   -292,  -40,   -156,  252,  -292,
-    248,   224,   -280,  400,   -244,  244,   -60,   76,    -80,   212,  532,
-    340,   128,   -36,   824,   -352,  -60,   -264,  -96,   -612,  416,  -704,
-    220,   -204,  640,   -160,  1220,  -408,  900,   336,   20,    -336, -96,
-    -792,  304,   48,    -28,   -1232, -1172, -448,  104,   -292,  -520, 244,
-    60,    -948,  0,     -708,  268,   108,   356,   -548,  488,   -344, -136,
-    488,   -196,  -224,  656,   -236,  -1128, 60,    4,     140,   276,  -676,
-    -376,  168,   -108,  464,   8,     564,   64,    240,   308,   -300, -400,
-    -456,  -136,  56,    120,   -408,  -116,  436,   504,   -232,  328,  844,
-    -164,  -84,   784,   -168,  232,   -224,  348,   -376,  128,   568,  96,
-    -1244, -288,  276,   848,   832,   -360,  656,   464,   -384,  -332, -356,
-    728,   -388,  160,   -192,  468,   296,   224,   140,   -776,  -100, 280,
-    4,     196,   44,    -36,   -648,  932,   16,    1428,  28,    528,  808,
-    772,   20,    268,   88,    -332,  -284,  124,   -384,  -448,  208,  -228,
-    -1044, -328,  660,   380,   -148,  -300,  588,   240,   540,   28,   136,
-    -88,   -436,  256,   296,   -1000, 1400,  0,     -48,   1056,  -136, 264,
-    -528,  -1108, 632,   -484,  -592,  -344,  796,   124,   -668,  -768, 388,
-    1296,  -232,  -188,  -200,  -288,  -4,    308,   100,   -168,  256,  -500,
-    204,   -508,  648,   -136,  372,   -272,  -120,  -1004, -552,  -548, -384,
-    548,   -296,  428,   -108,  -8,    -912,  -324,  -224,  -88,   -112, -220,
-    -100,  996,   -796,  548,   360,   -216,  180,   428,   -200,  -212, 148,
-    96,    148,   284,   216,   -412,  -320,  120,   -300,  -384,  -604, -572,
-    -332,  -8,    -180,  -176,  696,   116,   -88,   628,   76,    44,   -516,
-    240,   -208,  -40,   100,   -592,  344,   -308,  -452,  -228,  20,   916,
-    -1752, -136,  -340,  -804,  140,   40,    512,   340,   248,   184,  -492,
-    896,   -156,  932,   -628,  328,   -688,  -448,  -616,  -752,  -100, 560,
-    -1020, 180,   -800,  -64,   76,    576,   1068,  396,   660,   552,  -108,
-    -28,   320,   -628,  312,   -92,   -92,   -472,  268,   16,    560,  516,
-    -672,  -52,   492,   -100,  260,   384,   284,   292,   304,   -148, 88,
-    -152,  1012,  1064,  -228,  164,   -376,  -684,  592,   -392,  156,  196,
-    -524,  -64,   -884,  160,   -176,  636,   648,   404,   -396,  -436, 864,
-    424,   -728,  988,   -604,  904,   -592,  296,   -224,  536,   -176, -920,
-    436,   -48,   1176,  -884,  416,   -776,  -824,  -884,  524,   -548, -564,
-    -68,   -164,  -96,   692,   364,   -692,  -1012, -68,   260,   -480, 876,
-    -1116, 452,   -332,  -352,  892,   -1088, 1220,  -676,  12,    -292, 244,
-    496,   372,   -32,   280,   200,   112,   -440,  -96,   24,    -644, -184,
-    56,    -432,  224,   -980,  272,   -260,  144,   -436,  420,   356,  364,
-    -528,  76,    172,   -744,  -368,  404,   -752,  -416,  684,   -688, 72,
-    540,   416,   92,    444,   480,   -72,   -1416, 164,   -1172, -68,  24,
-    424,   264,   1040,  128,   -912,  -524,  -356,  64,    876,   -12,  4,
-    -88,   532,   272,   -524,  320,   276,   -508,  940,   24,    -400, -120,
-    756,   60,    236,   -412,  100,   376,   -484,  400,   -100,  -740, -108,
-    -260,  328,   -268,  224,   -200,  -416,  184,   -604,  -564,  -20,  296,
-    60,    892,   -888,  60,    164,   68,    -760,  216,   -296,  904,  -336,
-    -28,   404,   -356,  -568,  -208,  -1480, -512,  296,   328,   -360, -164,
-    -1560, -776,  1156,  -428,  164,   -504,  -112,  120,   -216,  -148, -264,
-    308,   32,    64,    -72,   72,    116,   176,   -64,   -272,  460,  -536,
-    -784,  -280,  348,   108,   -752,  -132,  524,   -540,  -776,  116,  -296,
-    -1196, -288,  -560,  1040,  -472,  116,   -848,  -1116, 116,   636,  696,
-    284,   -176,  1016,  204,   -864,  -648,  -248,  356,   972,   -584, -204,
-    264,   880,   528,   -24,   -184,  116,   448,   -144,  828,   524,  212,
-    -212,  52,    12,    200,   268,   -488,  -404,  -880,  824,   -672, -40,
-    908,   -248,  500,   716,   -576,  492,   -576,  16,    720,   -108, 384,
-    124,   344,   280,   576,   -500,  252,   104,   -308,  196,   -188, -8,
-    1268,  296,   1032,  -1196, 436,   316,   372,   -432,  -200,  -660, 704,
-    -224,  596,   -132,  268,   32,    -452,  884,   104,   -1008, 424,  -1348,
-    -280,  4,     -1168, 368,   476,   696,   300,   -8,    24,    180,  -592,
-    -196,  388,   304,   500,   724,   -160,  244,   -84,   272,   -256, -420,
-    320,   208,   -144,  -156,  156,   364,   452,   28,    540,   316,  220,
-    -644,  -248,  464,   72,    360,   32,    -388,  496,   -680,  -48,  208,
-    -116,  -408,  60,    -604,  -392,  548,   -840,  784,   -460,  656,  -544,
-    -388,  -264,  908,   -800,  -628,  -612,  -568,  572,   -220,  164,  288,
-    -16,   -308,  308,   -112,  -636,  -760,  280,   -668,  432,   364,  240,
-    -196,  604,   340,   384,   196,   592,   -44,   -500,  432,   -580, -132,
-    636,   -76,   392,   4,     -412,  540,   508,   328,   -356,  -36,  16,
-    -220,  -64,   -248,  -60,   24,    -192,  368,   1040,  92,    -24,  -1044,
-    -32,   40,    104,   148,   192,   -136,  -520,  56,    -816,  -224, 732,
-    392,   356,   212,   -80,   -424,  -1008, -324,  588,   -1496, 576,  460,
-    -816,  -848,  56,    -580,  -92,   -1372, -112,  -496,  200,   364,  52,
-    -140,  48,    -48,   -60,   84,    72,    40,    132,   -356,  -268, -104,
-    -284,  -404,  732,   -520,  164,   -304,  -540,  120,   328,   -76,  -460,
-    756,   388,   588,   236,   -436,  -72,   -176,  -404,  -316,  -148, 716,
-    -604,  404,   -72,   -88,   -888,  -68,   944,   88,    -220,  -344, 960,
-    472,   460,   -232,  704,   120,   832,   -228,  692,   -508,  132,  -476,
-    844,   -748,  -364,  -44,   1116,  -1104, -1056, 76,    428,   552,  -692,
-    60,    356,   96,    -384,  -188,  -612,  -576,  736,   508,   892,  352,
-    -1132, 504,   -24,   -352,  324,   332,   -600,  -312,  292,   508,  -144,
-    -8,    484,   48,    284,   -260,  -240,  256,   -100,  -292,  -204, -44,
-    472,   -204,  908,   -188,  -1000, -256,  92,    1164,  -392,  564,  356,
-    652,   -28,   -884,  256,   484,   -192,  760,   -176,  376,   -524, -452,
-    -436,  860,   -736,  212,   124,   504,   -476,  468,   76,    -472, 552,
-    -692,  -944,  -620,  740,   -240,  400,   132,   20,    192,   -196, 264,
-    -668,  -1012, -60,   296,   -316,  -828,  76,    -156,  284,   -768, -448,
-    -832,  148,   248,   652,   616,   1236,  288,   -328,  -400,  -124, 588,
-    220,   520,   -696,  1032,  768,   -740,  -92,   -272,  296,   448,  -464,
-    412,   -200,  392,   440,   -200,  264,   -152,  -260,  320,   1032, 216,
-    320,   -8,    -64,   156,   -1016, 1084,  1172,  536,   484,   -432, 132,
-    372,   -52,   -256,  84,    116,   -352,  48,    116,   304,   -384, 412,
-    924,   -300,  528,   628,   180,   648,   44,    -980,  -220,  1320, 48,
-    332,   748,   524,   -268,  -720,  540,   -276,  564,   -344,  -208, -196,
-    436,   896,   88,    -392,  132,   80,    -964,  -288,  568,   56,   -48,
-    -456,  888,   8,     552,   -156,  -292,  948,   288,   128,   -716, -292,
-    1192,  -152,  876,   352,   -600,  -260,  -812,  -468,  -28,   -120, -32,
-    -44,   1284,  496,   192,   464,   312,   -76,   -516,  -380,  -456, -1012,
-    -48,   308,   -156,  36,    492,   -156,  -808,  188,   1652,  68,   -120,
-    -116,  316,   160,   -140,  352,   808,   -416,  592,   316,   -480, 56,
-    528,   -204,  -568,  372,   -232,  752,   -344,  744,   -4,    324,  -416,
-    -600,  768,   268,   -248,  -88,   -132,  -420,  -432,  80,    -288, 404,
-    -316,  -1216, -588,  520,   -108,  92,    -320,  368,   -480,  -216, -92,
-    1688,  -300,  180,   1020,  -176,  820,   -68,   -228,  -260,  436,  -904,
-    20,    40,    -508,  440,   -736,  312,   332,   204,   760,   -372, 728,
-    96,    -20,   -632,  -520,  -560,  336,   1076,  -64,   -532,  776,  584,
-    192,   396,   -728,  -520,  276,   -188,  80,    -52,   -612,  -252, -48,
-    648,   212,   -688,  228,   -52,   -260,  428,   -412,  -272,  -404, 180,
-    816,   -796,  48,    152,   484,   -88,   -216,  988,   696,   188,  -528,
-    648,   -116,  -180,  316,   476,   12,    -564,  96,    476,   -252, -364,
-    -376,  -392,  556,   -256,  -576,  260,   -352,  120,   -16,   -136, -260,
-    -492,  72,    556,   660,   580,   616,   772,   436,   424,   -32,  -324,
-    -1268, 416,   -324,  -80,   920,   160,   228,   724,   32,    -516, 64,
-    384,   68,    -128,  136,   240,   248,   -204,  -68,   252,   -932, -120,
-    -480,  -628,  -84,   192,   852,   -404,  -288,  -132,  204,   100,  168,
-    -68,   -196,  -868,  460,   1080,  380,   -80,   244,   0,     484,  -888,
-    64,    184,   352,   600,   460,   164,   604,   -196,  320,   -64,  588,
-    -184,  228,   12,    372,   48,    -848,  -344,  224,   208,   -200, 484,
-    128,   -20,   272,   -468,  -840,  384,   256,   -720,  -520,  -464, -580,
-    112,   -120,  644,   -356,  -208,  -608,  -528,  704,   560,   -424, 392,
-    828,   40,    84,    200,   -152,  0,     -144,  584,   280,   -120, 80,
-    -556,  -972,  -196,  -472,  724,   80,    168,   -32,   88,    160,  -688,
-    0,     160,   356,   372,   -776,  740,   -128,  676,   -248,  -480, 4,
-    -364,  96,    544,   232,   -1032, 956,   236,   356,   20,    -40,  300,
-    24,    -676,  -596,  132,   1120,  -104,  532,   -1096, 568,   648,  444,
-    508,   380,   188,   -376,  -604,  1488,  424,   24,    756,   -220, -192,
-    716,   120,   920,   688,   168,   44,    -460,  568,   284,   1144, 1160,
-    600,   424,   888,   656,   -356,  -320,  220,   316,   -176,  -724, -188,
-    -816,  -628,  -348,  -228,  -380,  1012,  -452,  -660,  736,   928,  404,
-    -696,  -72,   -268,  -892,  128,   184,   -344,  -780,  360,   336,  400,
-    344,   428,   548,   -112,  136,   -228,  -216,  -820,  -516,  340,  92,
-    -136,  116,   -300,  376,   -244,  100,   -316,  -520,  -284,  -12,  824,
-    164,   -548,  -180,  -128,  116,   -924,  -828,  268,   -368,  -580, 620,
-    192,   160,   0,     -1676, 1068,  424,   -56,   -360,  468,   -156, 720,
-    288,   -528,  556,   -364,  548,   -148,  504,   316,   152,   -648, -620,
-    -684,  -24,   -376,  -384,  -108,  -920,  -1032, 768,   180,   -264, -508,
-    -1268, -260,  -60,   300,   -240,  988,   724,   -376,  -576,  -212, -736,
-    556,   192,   1092,  -620,  -880,  376,   -56,   -4,    -216,  -32,  836,
-    268,   396,   1332,  864,   -600,  100,   56,    -412,  -92,   356,  180,
-    884,   -468,  -436,  292,   -388,  -804,  -704,  -840,  368,   -348, 140,
-    -724,  1536,  940,   372,   112,   -372,  436,   -480,  1136,  296,  -32,
-    -228,  132,   -48,   -220,  868,   -1016, -60,   -1044, -464,  328,  916,
-    244,   12,    -736,  -296,  360,   468,   -376,  -108,  -92,   788,  368,
-    -56,   544,   400,   -672,  -420,  728,   16,    320,   44,    -284, -380,
-    -796,  488,   132,   204,   -596,  -372,  88,    -152,  -908,  -636, -572,
-    -624,  -116,  -692,  -200,  -56,   276,   -88,   484,   -324,  948,  864,
-    1000,  -456,  -184,  -276,  292,   -296,  156,   676,   320,   160,  908,
-    -84,   -1236, -288,  -116,  260,   -372,  -644,  732,   -756,  -96,  84,
-    344,   -520,  348,   -688,  240,   -84,   216,   -1044, -136,  -676, -396,
-    -1500, 960,   -40,   176,   168,   1516,  420,   -504,  -344,  -364, -360,
-    1216,  -940,  -380,  -212,  252,   -660,  -708,  484,   -444,  -152, 928,
-    -120,  1112,  476,   -260,  560,   -148,  -344,  108,   -196,  228,  -288,
-    504,   560,   -328,  -88,   288,   -1008, 460,   -228,  468,   -836, -196,
-    76,    388,   232,   412,   -1168, -716,  -644,  756,   -172,  -356, -504,
-    116,   432,   528,   48,    476,   -168,  -608,  448,   160,   -532, -272,
-    28,    -676,  -12,   828,   980,   456,   520,   104,   -104,  256,  -344,
-    -4,    -28,   -368,  -52,   -524,  -572,  -556,  -200,  768,   1124, -208,
-    -512,  176,   232,   248,   -148,  -888,  604,   -600,  -304,  804,  -156,
-    -212,  488,   -192,  -804,  -256,  368,   -360,  -916,  -328,  228,  -240,
-    -448,  -472,  856,   -556,  -364,  572,   -12,   -156,  -368,  -340, 432,
-    252,   -752,  -152,  288,   268,   -580,  -848,  -592,  108,   -76,  244,
-    312,   -716,  592,   -80,   436,   360,   4,     -248,  160,   516,  584,
-    732,   44,    -468,  -280,  -292,  -156,  -588,  28,    308,   912,  24,
-    124,   156,   180,   -252,  944,   -924,  -772,  -520,  -428,  -624, 300,
-    -212,  -1144, 32,    -724,  800,   -1128, -212,  -1288, -848,  180,  -416,
-    440,   192,   -576,  -792,  -76,   -1080, 80,    -532,  -352,  -132, 380,
-    -820,  148,   1112,  128,   164,   456,   700,   -924,  144,   -668, -384,
-    648,   -832,  508,   552,   -52,   -100,  -656,  208,   -568,  748,  -88,
-    680,   232,   300,   192,   -408,  -1012, -152,  -252,  -268,  272,  -876,
-    -664,  -648,  -332,  -136,  16,    12,    1152,  -28,   332,   -536, 320,
-    -672,  -460,  -316,  532,   -260,  228,   -40,   1052,  -816,  180,  88,
-    -496,  -556,  -672,  -368,  428,   92,    356,   404,   -408,  252,  196,
-    -176,  -556,  792,   268,   32,    372,   40,    96,    -332,  328,  120,
-    372,   -900,  -40,   472,   -264,  -592,  952,   128,   656,   112,  664,
-    -232,  420,   4,     -344,  -464,  556,   244,   -416,  -32,   252,  0,
-    -412,  188,   -696,  508,   -476,  324,   -1096, 656,   -312,  560,  264,
-    -136,  304,   160,   -64,   -580,  248,   336,   -720,  560,   -348, -288,
-    -276,  -196,  -500,  852,   -544,  -236,  -1128, -992,  -776,  116,  56,
-    52,    860,   884,   212,   -12,   168,   1020,  512,   -552,  924,  -148,
-    716,   188,   164,   -340,  -520,  -184,  880,   -152,  -680,  -208, -1156,
-    -300,  -528,  -472,  364,   100,   -744,  -1056, -32,   540,   280,  144,
-    -676,  -32,   -232,  -280,  -224,  96,    568,   -76,   172,   148,  148,
-    104,   32,    -296,  -32,   788,   -80,   32,    -16,   280,   288,  944,
-    428,   -484};
-static_assert(sizeof(kGaussianSequence) / sizeof(kGaussianSequence[0]) == 2048,
-              "");
-
-// Section 7.18.3.1.
-template <int bitdepth>
-bool FilmGrainSynthesis_C(const void* source_plane_y, ptrdiff_t source_stride_y,
-                          const void* source_plane_u, ptrdiff_t source_stride_u,
-                          const void* source_plane_v, ptrdiff_t source_stride_v,
-                          const FilmGrainParams& film_grain_params,
-                          const bool is_monochrome,
-                          const bool color_matrix_is_identity, const int width,
-                          const int height, const int subsampling_x,
-                          const int subsampling_y, void* dest_plane_y,
-                          ptrdiff_t dest_stride_y, void* dest_plane_u,
-                          ptrdiff_t dest_stride_u, void* dest_plane_v,
-                          ptrdiff_t dest_stride_v) {
-  FilmGrain<bitdepth> film_grain(film_grain_params, is_monochrome,
-                                 color_matrix_is_identity, subsampling_x,
-                                 subsampling_y, width, height);
-  return film_grain.AddNoise(source_plane_y, source_stride_y, source_plane_u,
-                             source_stride_u, source_plane_v, source_stride_v,
-                             dest_plane_y, dest_stride_y, dest_plane_u,
-                             dest_stride_u, dest_plane_v, dest_stride_v);
-}
-
-void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
-  assert(dsp != nullptr);
-#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  dsp->film_grain_synthesis = FilmGrainSynthesis_C<8>;
-#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  static_cast<void>(dsp);
-#ifndef LIBGAV1_Dsp8bpp_FilmGrainSynthesis
-  dsp->film_grain_synthesis = FilmGrainSynthesis_C<8>;
-#endif
-#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-}
-
-#if LIBGAV1_MAX_BITDEPTH >= 10
-void Init10bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
-  assert(dsp != nullptr);
-#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  dsp->film_grain_synthesis = FilmGrainSynthesis_C<10>;
-#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  static_cast<void>(dsp);
-#ifndef LIBGAV1_Dsp10bpp_FilmGrainSynthesis
-  dsp->film_grain_synthesis = FilmGrainSynthesis_C<10>;
-#endif
-#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-}
-#endif
-
-}  // namespace
-
-void FilmGrainInit_C() {
-  Init8bpp();
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  Init10bpp();
-#endif
-}
-
-// Static data member definitions.
-template <int bitdepth>
-constexpr int FilmGrain<bitdepth>::kLumaWidth;
-template <int bitdepth>
-constexpr int FilmGrain<bitdepth>::kLumaHeight;
-template <int bitdepth>
-constexpr int FilmGrain<bitdepth>::kMinChromaWidth;
-template <int bitdepth>
-constexpr int FilmGrain<bitdepth>::kMaxChromaWidth;
-template <int bitdepth>
-constexpr int FilmGrain<bitdepth>::kMinChromaHeight;
-template <int bitdepth>
-constexpr int FilmGrain<bitdepth>::kMaxChromaHeight;
-
-template <int bitdepth>
-FilmGrain<bitdepth>::FilmGrain(const FilmGrainParams& params,
-                               bool is_monochrome,
-                               bool color_matrix_is_identity, int subsampling_x,
-                               int subsampling_y, int width, int height)
-    : params_(params),
-      is_monochrome_(is_monochrome),
-      color_matrix_is_identity_(color_matrix_is_identity),
-      subsampling_x_(subsampling_x),
-      subsampling_y_(subsampling_y),
-      width_(width),
-      height_(height),
-      chroma_width_((subsampling_x != 0) ? kMinChromaWidth : kMaxChromaWidth),
-      chroma_height_((subsampling_y != 0) ? kMinChromaHeight
-                                          : kMaxChromaHeight) {
-  // bitdepth  grain_min_  grain_max_
-  // --------------------------------
-  //     8        -128         127
-  //    10        -512         511
-  //    12       -2048        2047
-  const int grain_center = 128 << (bitdepth - 8);
-  grain_min_ = -grain_center;
-  grain_max_ = grain_center - 1;
-}
-
-template <int bitdepth>
-bool FilmGrain<bitdepth>::Init() {
-  // Section 7.18.3.3. Generate grain process.
-  GenerateLumaGrain(params_, luma_grain_);
-  ApplyAutoRegressiveFilterToLumaGrain(params_, grain_min_, grain_max_,
-                                       luma_grain_);
-  if (!is_monochrome_) {
-    GenerateChromaGrains(params_, chroma_width_, chroma_height_, u_grain_,
-                         v_grain_);
-    ApplyAutoRegressiveFilterToChromaGrains(
-        params_, grain_min_, grain_max_, luma_grain_, subsampling_x_,
-        subsampling_y_, chroma_width_, chroma_height_, u_grain_, v_grain_);
-  }
-
-  // Section 7.18.3.4. Scaling lookup initialization process.
-  InitializeScalingLookupTable(params_.num_y_points, params_.point_y_value,
-                               params_.point_y_scaling, scaling_lut_y_);
-  if (!is_monochrome_) {
-    if (params_.chroma_scaling_from_luma) {
-      scaling_lut_u_ = scaling_lut_y_;
-      scaling_lut_v_ = scaling_lut_y_;
-    } else {
-      scaling_lut_chroma_buffer_.reset(new (std::nothrow) uint8_t[256 * 2]);
-      if (scaling_lut_chroma_buffer_ == nullptr) return false;
-      scaling_lut_u_ = &scaling_lut_chroma_buffer_[0];
-      scaling_lut_v_ = &scaling_lut_chroma_buffer_[256];
-      InitializeScalingLookupTable(params_.num_u_points, params_.point_u_value,
-                                   params_.point_u_scaling, scaling_lut_u_);
-      InitializeScalingLookupTable(params_.num_v_points, params_.point_v_value,
-                                   params_.point_v_scaling, scaling_lut_v_);
-    }
-  }
-  return true;
-}
-
-// Section 7.18.3.2.
-// |bits| is the number of random bits to return.
-template <int bitdepth>
-int FilmGrain<bitdepth>::GetRandomNumber(int bits, uint16_t* seed) {
-  uint16_t s = *seed;
-  uint16_t bit = (s ^ (s >> 1) ^ (s >> 3) ^ (s >> 12)) & 1;
-  s = (s >> 1) | (bit << 15);
-  *seed = s;
-  return s >> (16 - bits);
-}
-
-template <int bitdepth>
-void FilmGrain<bitdepth>::GenerateLumaGrain(const FilmGrainParams& params,
-                                            GrainType* luma_grain) {
-  const int shift = 12 - bitdepth + params.grain_scale_shift;
-  if (params.num_y_points == 0) {
-    memset(luma_grain, 0, kLumaHeight * kLumaWidth * sizeof(*luma_grain));
-  } else {
-    uint16_t seed = params.grain_seed;
-    GrainType* luma_grain_row = luma_grain;
-    for (int y = 0; y < kLumaHeight; ++y) {
-      for (int x = 0; x < kLumaWidth; ++x) {
-        luma_grain_row[x] = RightShiftWithRounding(
-            kGaussianSequence[GetRandomNumber(11, &seed)], shift);
-      }
-      luma_grain_row += kLumaWidth;
-    }
-  }
-}
-
-// Applies an auto-regressive filter to the white noise in luma_grain.
-template <int bitdepth>
-void FilmGrain<bitdepth>::ApplyAutoRegressiveFilterToLumaGrain(
-    const FilmGrainParams& params, int grain_min, int grain_max,
-    GrainType* luma_grain) {
-  assert(params.auto_regression_coeff_lag <= 3);
-  const int shift = params.auto_regression_shift;
-  for (int y = 3; y < kLumaHeight; ++y) {
-    for (int x = 3; x < kLumaWidth - 3; ++x) {
-      int sum = 0;
-      int pos = 0;
-      int delta_row = -params.auto_regression_coeff_lag;
-      do {
-        int delta_column = -params.auto_regression_coeff_lag;
-        do {
-          if (delta_row == 0 && delta_column == 0) {
-            break;
-          }
-          const int coeff = params.auto_regression_coeff_y[pos];
-          sum += luma_grain[(y + delta_row) * kLumaWidth + (x + delta_column)] *
-                 coeff;
-          ++pos;
-        } while (++delta_column <= params.auto_regression_coeff_lag);
-      } while (++delta_row <= 0);
-      luma_grain[y * kLumaWidth + x] = Clip3(
-          luma_grain[y * kLumaWidth + x] + RightShiftWithRounding(sum, shift),
-          grain_min, grain_max);
-    }
-  }
-}
-
-template <int bitdepth>
-void FilmGrain<bitdepth>::GenerateChromaGrains(const FilmGrainParams& params,
-                                               int chroma_width,
-                                               int chroma_height,
-                                               GrainType* u_grain,
-                                               GrainType* v_grain) {
-  const int shift = 12 - bitdepth + params.grain_scale_shift;
-  if (params.num_u_points == 0 && !params.chroma_scaling_from_luma) {
-    memset(u_grain, 0, chroma_height * chroma_width * sizeof(*u_grain));
-  } else {
-    uint16_t seed = params.grain_seed ^ 0xb524;
-    GrainType* u_grain_row = u_grain;
-    assert(chroma_width > 0);
-    assert(chroma_height > 0);
-    int y = 0;
-    do {
-      int x = 0;
-      do {
-        u_grain_row[x] = RightShiftWithRounding(
-            kGaussianSequence[GetRandomNumber(11, &seed)], shift);
-      } while (++x < chroma_width);
-
-      u_grain_row += chroma_width;
-    } while (++y < chroma_height);
-  }
-  if (params.num_v_points == 0 && !params.chroma_scaling_from_luma) {
-    memset(v_grain, 0, chroma_height * chroma_width * sizeof(*v_grain));
-  } else {
-    GrainType* v_grain_row = v_grain;
-    uint16_t seed = params.grain_seed ^ 0x49d8;
-    int y = 0;
-    do {
-      int x = 0;
-      do {
-        v_grain_row[x] = RightShiftWithRounding(
-            kGaussianSequence[GetRandomNumber(11, &seed)], shift);
-      } while (++x < chroma_width);
-
-      v_grain_row += chroma_width;
-    } while (++y < chroma_height);
-  }
-}
-
-template <int bitdepth>
-void FilmGrain<bitdepth>::ApplyAutoRegressiveFilterToChromaGrains(
-    const FilmGrainParams& params, int grain_min, int grain_max,
-    const GrainType* luma_grain, int subsampling_x, int subsampling_y,
-    int chroma_width, int chroma_height, GrainType* u_grain,
-    GrainType* v_grain) {
-  assert(params.auto_regression_coeff_lag <= 3);
-  const int shift = params.auto_regression_shift;
-  for (int y = 3; y < chroma_height; ++y) {
-    for (int x = 3; x < chroma_width - 3; ++x) {
-      int sum_u = 0;
-      int sum_v = 0;
-      int pos = 0;
-      int delta_row = -params.auto_regression_coeff_lag;
-      do {
-        int delta_column = -params.auto_regression_coeff_lag;
-        do {
-          const int coeff_u = params.auto_regression_coeff_u[pos];
-          const int coeff_v = params.auto_regression_coeff_v[pos];
-          if (delta_row == 0 && delta_column == 0) {
-            if (params.num_y_points > 0) {
-              int luma = 0;
-              const int luma_x = ((x - 3) << subsampling_x) + 3;
-              const int luma_y = ((y - 3) << subsampling_y) + 3;
-              int i = 0;
-              do {
-                int j = 0;
-                do {
-                  luma += luma_grain[(luma_y + i) * kLumaWidth + (luma_x + j)];
-                } while (++j <= subsampling_x);
-              } while (++i <= subsampling_y);
-              luma =
-                  RightShiftWithRounding(luma, subsampling_x + subsampling_y);
-              sum_u += luma * coeff_u;
-              sum_v += luma * coeff_v;
-            }
-            break;
-          }
-          sum_u +=
-              u_grain[(y + delta_row) * chroma_width + (x + delta_column)] *
-              coeff_u;
-          sum_v +=
-              v_grain[(y + delta_row) * chroma_width + (x + delta_column)] *
-              coeff_v;
-          ++pos;
-        } while (++delta_column <= params.auto_regression_coeff_lag);
-      } while (++delta_row <= 0);
-      u_grain[y * chroma_width + x] = Clip3(
-          u_grain[y * chroma_width + x] + RightShiftWithRounding(sum_u, shift),
-          grain_min, grain_max);
-      v_grain[y * chroma_width + x] = Clip3(
-          v_grain[y * chroma_width + x] + RightShiftWithRounding(sum_v, shift),
-          grain_min, grain_max);
-    }
-  }
-}
-
-template <int bitdepth>
-void FilmGrain<bitdepth>::InitializeScalingLookupTable(
+// Making this a template function prevents it from adding to code size when it
+// is not placed in the DSP table. Most functions in the dsp directory change
+// behavior by bitdepth, but because this one doesn't, it receives a dummy
+// parameter with one enforced value, ensuring only one copy is made.
+template <int singleton>
+void InitializeScalingLookupTable_C(
     int num_points, const uint8_t point_value[], const uint8_t point_scaling[],
-    uint8_t scaling_lut[256]) {
+    uint8_t scaling_lut[kScalingLookupTableSize]) {
+  static_assert(singleton == 0,
+                "Improper instantiation of InitializeScalingLookupTable_C. "
+                "There should be only one copy of this function.");
   if (num_points == 0) {
-    memset(scaling_lut, 0, sizeof(scaling_lut[0]) * 256);
+    memset(scaling_lut, 0, sizeof(scaling_lut[0]) * kScalingLookupTableSize);
     return;
   }
   static_assert(sizeof(scaling_lut[0]) == 1, "");
@@ -545,143 +64,223 @@
   }
   const uint8_t last_point_value = point_value[num_points - 1];
   memset(&scaling_lut[last_point_value], point_scaling[num_points - 1],
-         256 - last_point_value);
+         kScalingLookupTableSize - last_point_value);
 }
 
 // Section 7.18.3.5.
 // Performs a piecewise linear interpolation into the scaling table.
 template <int bitdepth>
-int ScaleLut(const uint8_t scaling_lut[256], int index) {
+int ScaleLut(const uint8_t scaling_lut[kScalingLookupTableSize], int index) {
   const int shift = bitdepth - 8;
   const int quotient = index >> shift;
   const int remainder = index - (quotient << shift);
-  if (bitdepth == 8 || quotient == 255) {
+  if (bitdepth == 8) {
+    assert(quotient < kScalingLookupTableSize);
     return scaling_lut[quotient];
   }
+  assert(quotient + 1 < kScalingLookupTableSize);
   const int start = scaling_lut[quotient];
   const int end = scaling_lut[quotient + 1];
   return start + RightShiftWithRounding((end - start) * remainder, shift);
 }
 
-template <int bitdepth>
-bool FilmGrain<bitdepth>::AllocateNoiseStripes() {
-  const int num_planes = is_monochrome_ ? kMaxPlanesMonochrome : kMaxPlanes;
-  const int half_height = DivideBy2(height_ + 1);
-  // ceil(half_height / 16.0)
-  const int max_luma_num = DivideBy16(half_height + 15);
-  if (!noise_stripe_.Reset(max_luma_num, num_planes,
-                           /*zero_initialize=*/false)) {
-    return false;
-  }
-  size_t noise_buffer_size = max_luma_num * 34 * width_;
-  if (!is_monochrome_) {
-    noise_buffer_size += max_luma_num * 2 * (34 >> subsampling_y_) *
-                         RightShiftWithRounding(width_, subsampling_x_);
-  }
-  noise_buffer_.reset(new (std::nothrow) GrainType[noise_buffer_size]);
-  if (noise_buffer_ == nullptr) return false;
-  GrainType* noise_block = noise_buffer_.get();
-  int luma_num = 0;
-  assert(half_height > 0);
-  int y = 0;
-  do {
-    noise_stripe_[luma_num][kPlaneY] = noise_block;
-    noise_block += 34 * width_;
-    if (!is_monochrome_) {
-      noise_stripe_[luma_num][kPlaneU] = noise_block;
-      noise_block += (34 >> subsampling_y_) *
-                     RightShiftWithRounding(width_, subsampling_x_);
-      noise_stripe_[luma_num][kPlaneV] = noise_block;
-      noise_block += (34 >> subsampling_y_) *
-                     RightShiftWithRounding(width_, subsampling_x_);
+// Applies an auto-regressive filter to the white noise in luma_grain.
+template <int bitdepth, typename GrainType>
+void ApplyAutoRegressiveFilterToLumaGrain_C(const FilmGrainParams& params,
+                                            void* luma_grain_buffer) {
+  auto* luma_grain = static_cast<GrainType*>(luma_grain_buffer);
+  const int grain_min = GetGrainMin<bitdepth>();
+  const int grain_max = GetGrainMax<bitdepth>();
+  const int auto_regression_coeff_lag = params.auto_regression_coeff_lag;
+  assert(auto_regression_coeff_lag > 0 && auto_regression_coeff_lag <= 3);
+  // A pictorial representation of the auto-regressive filter for various values
+  // of auto_regression_coeff_lag. The letter 'O' represents the current sample.
+  // (The filter always operates on the current sample with filter
+  // coefficient 1.) The letters 'X' represent the neighboring samples that the
+  // filter operates on.
+  //
+  // auto_regression_coeff_lag == 3:
+  //   X X X X X X X
+  //   X X X X X X X
+  //   X X X X X X X
+  //   X X X O
+  // auto_regression_coeff_lag == 2:
+  //     X X X X X
+  //     X X X X X
+  //     X X O
+  // auto_regression_coeff_lag == 1:
+  //       X X X
+  //       X O
+  // auto_regression_coeff_lag == 0:
+  //         O
+  //
+  // Note that if auto_regression_coeff_lag is 0, the filter is the identity
+  // filter and therefore can be skipped. This implementation assumes it is not
+  // called in that case.
+  const int shift = params.auto_regression_shift;
+  for (int y = kAutoRegressionBorder; y < kLumaHeight; ++y) {
+    for (int x = kAutoRegressionBorder; x < kLumaWidth - kAutoRegressionBorder;
+         ++x) {
+      int sum = 0;
+      int pos = 0;
+      int delta_row = -auto_regression_coeff_lag;
+      // The last iteration (delta_row == 0) is shorter and is handled
+      // separately.
+      do {
+        int delta_column = -auto_regression_coeff_lag;
+        do {
+          const int coeff = params.auto_regression_coeff_y[pos];
+          sum += luma_grain[(y + delta_row) * kLumaWidth + (x + delta_column)] *
+                 coeff;
+          ++pos;
+        } while (++delta_column <= auto_regression_coeff_lag);
+      } while (++delta_row < 0);
+      // Last iteration: delta_row == 0.
+      {
+        int delta_column = -auto_regression_coeff_lag;
+        do {
+          const int coeff = params.auto_regression_coeff_y[pos];
+          sum += luma_grain[y * kLumaWidth + (x + delta_column)] * coeff;
+          ++pos;
+        } while (++delta_column < 0);
+      }
+      luma_grain[y * kLumaWidth + x] = Clip3(
+          luma_grain[y * kLumaWidth + x] + RightShiftWithRounding(sum, shift),
+          grain_min, grain_max);
     }
-    ++luma_num;
-    y += 16;
-  } while (y < half_height);
-  assert(noise_block == noise_buffer_.get() + noise_buffer_size);
-  return true;
+  }
 }
 
-template <int bitdepth>
-void FilmGrain<bitdepth>::ConstructNoiseStripes() {
-  const int num_planes = is_monochrome_ ? kMaxPlanesMonochrome : kMaxPlanes;
-  const int half_width = DivideBy2(width_ + 1);
-  const int half_height = DivideBy2(height_ + 1);
-  int luma_num = 0;
-  assert(half_width > 0);
-  assert(half_height > 0);
-  int y = 0;
-  do {
-    uint16_t seed = params_.grain_seed;
-    seed ^= ((luma_num * 37 + 178) & 255) << 8;
-    seed ^= ((luma_num * 173 + 105) & 255);
-    int x = 0;
-    do {
-      const int rand = GetRandomNumber(8, &seed);
-      const int offset_x = rand >> 4;
-      const int offset_y = rand & 15;
-      for (int plane = kPlaneY; plane < num_planes; ++plane) {
-        const int plane_sub_x = (plane > kPlaneY) ? subsampling_x_ : 0;
-        const int plane_sub_y = (plane > kPlaneY) ? subsampling_y_ : 0;
-        const int plane_offset_x =
-            (plane_sub_x != 0) ? 6 + offset_x : 9 + offset_x * 2;
-        const int plane_offset_y =
-            (plane_sub_y != 0) ? 6 + offset_y : 9 + offset_y * 2;
-        GrainType* const noise_block = noise_stripe_[luma_num][plane];
-        const int noise_block_width = (width_ + plane_sub_x) >> plane_sub_x;
+template <int bitdepth, typename GrainType, int auto_regression_coeff_lag,
+          bool use_luma>
+void ApplyAutoRegressiveFilterToChromaGrains_C(const FilmGrainParams& params,
+                                               const void* luma_grain_buffer,
+                                               int subsampling_x,
+                                               int subsampling_y,
+                                               void* u_grain_buffer,
+                                               void* v_grain_buffer) {
+  static_assert(
+      auto_regression_coeff_lag >= 0 && auto_regression_coeff_lag <= 3,
+      "Unsupported autoregression lag for chroma.");
+  const auto* luma_grain = static_cast<const GrainType*>(luma_grain_buffer);
+  const int grain_min = GetGrainMin<bitdepth>();
+  const int grain_max = GetGrainMax<bitdepth>();
+  auto* u_grain = static_cast<GrainType*>(u_grain_buffer);
+  auto* v_grain = static_cast<GrainType*>(v_grain_buffer);
+  const int shift = params.auto_regression_shift;
+  const int chroma_height =
+      (subsampling_y == 0) ? kMaxChromaHeight : kMinChromaHeight;
+  const int chroma_width =
+      (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth;
+  for (int y = kAutoRegressionBorder; y < chroma_height; ++y) {
+    const int luma_y =
+        ((y - kAutoRegressionBorder) << subsampling_y) + kAutoRegressionBorder;
+    for (int x = kAutoRegressionBorder;
+         x < chroma_width - kAutoRegressionBorder; ++x) {
+      int sum_u = 0;
+      int sum_v = 0;
+      int pos = 0;
+      int delta_row = -auto_regression_coeff_lag;
+      do {
+        int delta_column = -auto_regression_coeff_lag;
+        do {
+          if (delta_row == 0 && delta_column == 0) {
+            break;
+          }
+          const int coeff_u = params.auto_regression_coeff_u[pos];
+          const int coeff_v = params.auto_regression_coeff_v[pos];
+          sum_u +=
+              u_grain[(y + delta_row) * chroma_width + (x + delta_column)] *
+              coeff_u;
+          sum_v +=
+              v_grain[(y + delta_row) * chroma_width + (x + delta_column)] *
+              coeff_v;
+          ++pos;
+        } while (++delta_column <= auto_regression_coeff_lag);
+      } while (++delta_row <= 0);
+      if (use_luma) {
+        int luma = 0;
+        const int luma_x = ((x - kAutoRegressionBorder) << subsampling_x) +
+                           kAutoRegressionBorder;
         int i = 0;
         do {
           int j = 0;
           do {
-            int grain;
-            if (plane == kPlaneY) {
-              grain = luma_grain_[(plane_offset_y + i) * kLumaWidth +
-                                  (plane_offset_x + j)];
-            } else if (plane == kPlaneU) {
-              grain = u_grain_[(plane_offset_y + i) * chroma_width_ +
-                               (plane_offset_x + j)];
-            } else {
-              grain = v_grain_[(plane_offset_y + i) * chroma_width_ +
-                               (plane_offset_x + j)];
-            }
-            // Section 7.18.3.5 says:
-            //   noiseStripe[ lumaNum ][ 0 ] is 34 samples high and w samples
-            //   wide (a few additional samples across are actually written to
-            //   the array, but these are never read) ...
-            //
-            // Note: The warning in the parentheses also applies to
-            // noiseStripe[ lumaNum ][ 1 ] and noiseStripe[ lumaNum ][ 2 ].
-            //
-            // The writes beyond the width of each row would happen below. To
-            // prevent those writes, we skip the write if the column index
-            // (x * 2 + j or x + j) is >= noise_block_width.
-            if (plane_sub_x == 0) {
-              if (x * 2 + j >= noise_block_width) continue;
-              if (j < 2 && params_.overlap_flag && x > 0) {
-                const int old =
-                    noise_block[i * noise_block_width + (x * 2 + j)];
-                if (j == 0) {
-                  grain = old * 27 + grain * 17;
-                } else {
-                  grain = old * 17 + grain * 27;
-                }
-                grain = Clip3(RightShiftWithRounding(grain, 5), grain_min_,
-                              grain_max_);
-              }
-              noise_block[i * noise_block_width + (x * 2 + j)] = grain;
-            } else {
-              if (x + j >= noise_block_width) continue;
-              if (j == 0 && params_.overlap_flag && x > 0) {
-                const int old = noise_block[i * noise_block_width + (x + j)];
-                grain = old * 23 + grain * 22;
-                grain = Clip3(RightShiftWithRounding(grain, 5), grain_min_,
-                              grain_max_);
-              }
-              noise_block[i * noise_block_width + (x + j)] = grain;
-            }
-          } while (++j < (34 >> plane_sub_x));
-        } while (++i < (34 >> plane_sub_y));
+            luma += luma_grain[(luma_y + i) * kLumaWidth + (luma_x + j)];
+          } while (++j <= subsampling_x);
+        } while (++i <= subsampling_y);
+        luma = RightShiftWithRounding(luma, subsampling_x + subsampling_y);
+        const int coeff_u = params.auto_regression_coeff_u[pos];
+        const int coeff_v = params.auto_regression_coeff_v[pos];
+        sum_u += luma * coeff_u;
+        sum_v += luma * coeff_v;
       }
+      u_grain[y * chroma_width + x] = Clip3(
+          u_grain[y * chroma_width + x] + RightShiftWithRounding(sum_u, shift),
+          grain_min, grain_max);
+      v_grain[y * chroma_width + x] = Clip3(
+          v_grain[y * chroma_width + x] + RightShiftWithRounding(sum_v, shift),
+          grain_min, grain_max);
+    }
+  }
+}
+
+// This implementation is for the condition overlap_flag == false.
+template <int bitdepth, typename GrainType>
+void ConstructNoiseStripes_C(const void* grain_buffer, int grain_seed,
+                             int width, int height, int subsampling_x,
+                             int subsampling_y, void* noise_stripes_buffer) {
+  auto* noise_stripes =
+      static_cast<Array2DView<GrainType>*>(noise_stripes_buffer);
+  const auto* grain = static_cast<const GrainType*>(grain_buffer);
+  const int half_width = DivideBy2(width + 1);
+  const int half_height = DivideBy2(height + 1);
+  assert(half_width > 0);
+  assert(half_height > 0);
+  static_assert(kLumaWidth == kMaxChromaWidth,
+                "kLumaWidth width should be equal to kMaxChromaWidth");
+  const int grain_width =
+      (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth;
+  const int plane_width = (width + subsampling_x) >> subsampling_x;
+  constexpr int kNoiseStripeHeight = 34;
+  int luma_num = 0;
+  int y = 0;
+  do {
+    GrainType* const noise_stripe = (*noise_stripes)[luma_num];
+    uint16_t seed = grain_seed;
+    seed ^= ((luma_num * 37 + 178) & 255) << 8;
+    seed ^= ((luma_num * 173 + 105) & 255);
+    int x = 0;
+    do {
+      const int rand = GetFilmGrainRandomNumber(8, &seed);
+      const int offset_x = rand >> 4;
+      const int offset_y = rand & 15;
+      const int plane_offset_x =
+          (subsampling_x != 0) ? 6 + offset_x : 9 + offset_x * 2;
+      const int plane_offset_y =
+          (subsampling_y != 0) ? 6 + offset_y : 9 + offset_y * 2;
+      int i = 0;
+      do {
+        // Section 7.18.3.5 says:
+        //   noiseStripe[ lumaNum ][ 0 ] is 34 samples high and w samples
+        //   wide (a few additional samples across are actually written to
+        //   the array, but these are never read) ...
+        //
+        // Note: The warning in the parentheses also applies to
+        // noiseStripe[ lumaNum ][ 1 ] and noiseStripe[ lumaNum ][ 2 ].
+        //
+        // Writes beyond the width of each row could happen below. To
+        // prevent those writes, we clip the number of pixels to copy against
+        // the remaining width.
+        // TODO(petersonab): Allocate aligned stripes with extra width to cover
+        // the size of the final stripe block, then remove this call to min.
+        const int copy_size =
+            std::min(kNoiseStripeHeight >> subsampling_x,
+                     plane_width - (x << (1 - subsampling_x)));
+        memcpy(&noise_stripe[i * plane_width + (x << (1 - subsampling_x))],
+               &grain[(plane_offset_y + i) * grain_width + plane_offset_x],
+               copy_size * sizeof(noise_stripe[0]));
+      } while (++i < (kNoiseStripeHeight >> subsampling_y));
       x += 16;
     } while (x < half_width);
 
@@ -690,115 +289,278 @@
   } while (y < half_height);
 }
 
-template <int bitdepth>
-bool FilmGrain<bitdepth>::AllocateNoiseImage() {
-  if (!noise_image_[kPlaneY].Reset(height_, width_,
-                                   /*zero_initialize=*/false)) {
-    return false;
-  }
-  if (!is_monochrome_) {
-    if (!noise_image_[kPlaneU].Reset(
-            (height_ + subsampling_y_) >> subsampling_y_,
-            (width_ + subsampling_x_) >> subsampling_x_,
-            /*zero_initialize=*/false)) {
-      return false;
-    }
-    if (!noise_image_[kPlaneV].Reset(
-            (height_ + subsampling_y_) >> subsampling_y_,
-            (width_ + subsampling_x_) >> subsampling_x_,
-            /*zero_initialize=*/false)) {
-      return false;
-    }
-  }
-  return true;
-}
-
-template <int bitdepth>
-void FilmGrain<bitdepth>::ConstructNoiseImage() {
-  const int num_planes = is_monochrome_ ? kMaxPlanesMonochrome : kMaxPlanes;
-  for (int plane = kPlaneY; plane < num_planes; ++plane) {
-    const int plane_sub_x = (plane > kPlaneY) ? subsampling_x_ : 0;
-    const int plane_sub_y = (plane > kPlaneY) ? subsampling_y_ : 0;
-    const int noise_block_width = (width_ + plane_sub_x) >> plane_sub_x;
-    int y = 0;
+// This implementation is for the condition overlap_flag == true.
+template <int bitdepth, typename GrainType>
+void ConstructNoiseStripesWithOverlap_C(const void* grain_buffer,
+                                        int grain_seed, int width, int height,
+                                        int subsampling_x, int subsampling_y,
+                                        void* noise_stripes_buffer) {
+  auto* noise_stripes =
+      static_cast<Array2DView<GrainType>*>(noise_stripes_buffer);
+  const auto* grain = static_cast<const GrainType*>(grain_buffer);
+  const int half_width = DivideBy2(width + 1);
+  const int half_height = DivideBy2(height + 1);
+  assert(half_width > 0);
+  assert(half_height > 0);
+  static_assert(kLumaWidth == kMaxChromaWidth,
+                "kLumaWidth width should be equal to kMaxChromaWidth");
+  const int grain_width =
+      (subsampling_x == 0) ? kMaxChromaWidth : kMinChromaWidth;
+  const int plane_width = (width + subsampling_x) >> subsampling_x;
+  constexpr int kNoiseStripeHeight = 34;
+  int luma_num = 0;
+  int y = 0;
+  do {
+    GrainType* const noise_stripe = (*noise_stripes)[luma_num];
+    uint16_t seed = grain_seed;
+    seed ^= ((luma_num * 37 + 178) & 255) << 8;
+    seed ^= ((luma_num * 173 + 105) & 255);
+    // Begin special iteration for x == 0.
+    const int rand = GetFilmGrainRandomNumber(8, &seed);
+    const int offset_x = rand >> 4;
+    const int offset_y = rand & 15;
+    const int plane_offset_x =
+        (subsampling_x != 0) ? 6 + offset_x : 9 + offset_x * 2;
+    const int plane_offset_y =
+        (subsampling_y != 0) ? 6 + offset_y : 9 + offset_y * 2;
+    // The overlap computation only occurs when x > 0, so it is omitted here.
+    int i = 0;
     do {
-      const int luma_num = y >> (5 - plane_sub_y);
-      const int i = y - (luma_num << (5 - plane_sub_y));
-      int x = 0;
+      // TODO(petersonab): Allocate aligned stripes with extra width to cover
+      // the size of the final stripe block, then remove this call to min.
+      const int copy_size =
+          std::min(kNoiseStripeHeight >> subsampling_x, plane_width);
+      memcpy(&noise_stripe[i * plane_width],
+             &grain[(plane_offset_y + i) * grain_width + plane_offset_x],
+             copy_size * sizeof(noise_stripe[0]));
+    } while (++i < (kNoiseStripeHeight >> subsampling_y));
+    // End special iteration for x == 0.
+    for (int x = 16; x < half_width; x += 16) {
+      const int rand = GetFilmGrainRandomNumber(8, &seed);
+      const int offset_x = rand >> 4;
+      const int offset_y = rand & 15;
+      const int plane_offset_x =
+          (subsampling_x != 0) ? 6 + offset_x : 9 + offset_x * 2;
+      const int plane_offset_y =
+          (subsampling_y != 0) ? 6 + offset_y : 9 + offset_y * 2;
+      int i = 0;
       do {
-        int grain = noise_stripe_[luma_num][plane][i * noise_block_width + x];
-        if (plane_sub_y == 0) {
-          if (i < 2 && luma_num > 0 && params_.overlap_flag) {
-            const int old = noise_stripe_[luma_num - 1][plane]
-                                         [(i + 32) * noise_block_width + x];
-            if (i == 0) {
-              grain = old * 27 + grain * 17;
-            } else {
-              grain = old * 17 + grain * 27;
-            }
-            grain =
-                Clip3(RightShiftWithRounding(grain, 5), grain_min_, grain_max_);
-          }
+        int j = 0;
+        int grain_sample =
+            grain[(plane_offset_y + i) * grain_width + plane_offset_x];
+        // The first pixel(s) of each segment of the noise_stripe are subject to
+        // the "overlap" computation.
+        if (subsampling_x == 0) {
+          // Corresponds to the line in the spec:
+          // if (j < 2 && x > 0)
+          // j = 0
+          int old = noise_stripe[i * plane_width + x * 2];
+          grain_sample = old * 27 + grain_sample * 17;
+          grain_sample =
+              Clip3(RightShiftWithRounding(grain_sample, 5),
+                    GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>());
+          noise_stripe[i * plane_width + x * 2] = grain_sample;
+
+          // This check prevents overwriting for the iteration j = 1. The
+          // continue applies to the i-loop.
+          if (x * 2 + 1 >= plane_width) continue;
+          // j = 1
+          grain_sample =
+              grain[(plane_offset_y + i) * grain_width + plane_offset_x + 1];
+          old = noise_stripe[i * plane_width + x * 2 + 1];
+          grain_sample = old * 17 + grain_sample * 27;
+          grain_sample =
+              Clip3(RightShiftWithRounding(grain_sample, 5),
+                    GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>());
+          noise_stripe[i * plane_width + x * 2 + 1] = grain_sample;
+          j = 2;
         } else {
-          if (i < 1 && luma_num > 0 && params_.overlap_flag) {
-            const int old = noise_stripe_[luma_num - 1][plane]
-                                         [(i + 16) * noise_block_width + x];
-            grain = old * 23 + grain * 22;
-            grain =
-                Clip3(RightShiftWithRounding(grain, 5), grain_min_, grain_max_);
-          }
+          // Corresponds to the line in the spec:
+          // if (j == 0 && x > 0)
+          const int old = noise_stripe[i * plane_width + x];
+          grain_sample = old * 23 + grain_sample * 22;
+          grain_sample =
+              Clip3(RightShiftWithRounding(grain_sample, 5),
+                    GetGrainMin<bitdepth>(), GetGrainMax<bitdepth>());
+          noise_stripe[i * plane_width + x] = grain_sample;
+          j = 1;
         }
-        noise_image_[plane][y][x] = grain;
-      } while (++x < noise_block_width);
-    } while (++y < ((height_ + plane_sub_y) >> plane_sub_y));
+        // The following covers the rest of the loop over j as described in the
+        // spec.
+        //
+        // Section 7.18.3.5 says:
+        //   noiseStripe[ lumaNum ][ 0 ] is 34 samples high and w samples
+        //   wide (a few additional samples across are actually written to
+        //   the array, but these are never read) ...
+        //
+        // Note: The warning in the parentheses also applies to
+        // noiseStripe[ lumaNum ][ 1 ] and noiseStripe[ lumaNum ][ 2 ].
+        //
+        // Writes beyond the width of each row could happen below. To
+        // prevent those writes, we clip the number of pixels to copy against
+        // the remaining width.
+        // TODO(petersonab): Allocate aligned stripes with extra width to cover
+        // the size of the final stripe block, then remove this call to min.
+        const int copy_size =
+            std::min(kNoiseStripeHeight >> subsampling_x,
+                     plane_width - (x << (1 - subsampling_x))) -
+            j;
+        memcpy(&noise_stripe[i * plane_width + (x << (1 - subsampling_x)) + j],
+               &grain[(plane_offset_y + i) * grain_width + plane_offset_x + j],
+               copy_size * sizeof(noise_stripe[0]));
+      } while (++i < (kNoiseStripeHeight >> subsampling_y));
+    }
+
+    ++luma_num;
+    y += 16;
+  } while (y < half_height);
+}
+
+template <int bitdepth, typename GrainType>
+inline void WriteOverlapLine_C(const GrainType* noise_stripe_row,
+                               const GrainType* noise_stripe_row_prev,
+                               int plane_width, int grain_coeff, int old_coeff,
+                               GrainType* noise_image_row) {
+  int x = 0;
+  do {
+    int grain = noise_stripe_row[x];
+    const int old = noise_stripe_row_prev[x];
+    grain = old * old_coeff + grain * grain_coeff;
+    grain = Clip3(RightShiftWithRounding(grain, 5), GetGrainMin<bitdepth>(),
+                  GetGrainMax<bitdepth>());
+    noise_image_row[x] = grain;
+  } while (++x < plane_width);
+}
+
+template <int bitdepth, typename GrainType>
+void ConstructNoiseImageOverlap_C(const void* noise_stripes_buffer, int width,
+                                  int height, int subsampling_x,
+                                  int subsampling_y, void* noise_image_buffer) {
+  const auto* noise_stripes =
+      static_cast<const Array2DView<GrainType>*>(noise_stripes_buffer);
+  auto* noise_image = static_cast<Array2D<GrainType>*>(noise_image_buffer);
+  const int plane_width = (width + subsampling_x) >> subsampling_x;
+  const int plane_height = (height + subsampling_y) >> subsampling_y;
+  const int stripe_height = 32 >> subsampling_y;
+  const int stripe_mask = stripe_height - 1;
+  int y = stripe_height;
+  int luma_num = 1;
+  if (subsampling_y == 0) {
+    // Begin complete stripes section. This is when we are guaranteed to have
+    // two overlap rows in each stripe.
+    for (; y < (plane_height & ~stripe_mask); ++luma_num, y += stripe_height) {
+      const GrainType* noise_stripe = (*noise_stripes)[luma_num];
+      const GrainType* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+      // First overlap row.
+      WriteOverlapLine_C<bitdepth>(noise_stripe,
+                                   &noise_stripe_prev[32 * plane_width],
+                                   plane_width, 17, 27, (*noise_image)[y]);
+      // Second overlap row.
+      WriteOverlapLine_C<bitdepth>(&noise_stripe[plane_width],
+                                   &noise_stripe_prev[(32 + 1) * plane_width],
+                                   plane_width, 27, 17, (*noise_image)[y + 1]);
+    }
+    // End complete stripes section.
+
+    const int remaining_height = plane_height - y;
+    // Either one partial stripe remains (remaining_height  > 0),
+    // OR image is less than one stripe high (remaining_height < 0),
+    // OR all stripes are completed (remaining_height == 0).
+    if (remaining_height <= 0) {
+      return;
+    }
+    const GrainType* noise_stripe = (*noise_stripes)[luma_num];
+    const GrainType* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+    WriteOverlapLine_C<bitdepth>(noise_stripe,
+                                 &noise_stripe_prev[32 * plane_width],
+                                 plane_width, 17, 27, (*noise_image)[y]);
+
+    // Check if second overlap row is in the image.
+    if (remaining_height > 1) {
+      WriteOverlapLine_C<bitdepth>(&noise_stripe[plane_width],
+                                   &noise_stripe_prev[(32 + 1) * plane_width],
+                                   plane_width, 27, 17, (*noise_image)[y + 1]);
+    }
+  } else {  // |subsampling_y| == 1
+    // No special checks needed for partial stripes, because if one exists, the
+    // first and only overlap row is guaranteed to exist.
+    for (; y < plane_height; ++luma_num, y += stripe_height) {
+      const GrainType* noise_stripe = (*noise_stripes)[luma_num];
+      const GrainType* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+      WriteOverlapLine_C<bitdepth>(noise_stripe,
+                                   &noise_stripe_prev[16 * plane_width],
+                                   plane_width, 22, 23, (*noise_image)[y]);
+    }
   }
 }
 
-template <int bitdepth>
-void FilmGrain<bitdepth>::BlendNoiseWithImage(
-    const void* source_plane_y, ptrdiff_t source_stride_y,
-    const void* source_plane_u, ptrdiff_t source_stride_u,
-    const void* source_plane_v, ptrdiff_t source_stride_v, void* dest_plane_y,
-    ptrdiff_t dest_stride_y, void* dest_plane_u, ptrdiff_t dest_stride_u,
-    void* dest_plane_v, ptrdiff_t dest_stride_v) const {
+template <int bitdepth, typename GrainType, typename Pixel>
+void BlendNoiseWithImageLuma_C(
+    const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift,
+    int width, int height, int start_height,
+    const uint8_t scaling_lut_y[kScalingLookupTableSize],
+    const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y,
+    ptrdiff_t dest_stride_y) {
+  const auto* noise_image =
+      static_cast<const Array2D<GrainType>*>(noise_image_ptr);
   const auto* in_y = static_cast<const Pixel*>(source_plane_y);
   source_stride_y /= sizeof(Pixel);
-  const auto* in_u = static_cast<const Pixel*>(source_plane_u);
-  source_stride_u /= sizeof(Pixel);
-  const auto* in_v = static_cast<const Pixel*>(source_plane_v);
-  source_stride_v /= sizeof(Pixel);
   auto* out_y = static_cast<Pixel*>(dest_plane_y);
   dest_stride_y /= sizeof(Pixel);
-  auto* out_u = static_cast<Pixel*>(dest_plane_u);
-  dest_stride_u /= sizeof(Pixel);
-  auto* out_v = static_cast<Pixel*>(dest_plane_v);
-  dest_stride_v /= sizeof(Pixel);
-  int min_value;
-  int max_luma;
-  int max_chroma;
-  if (params_.clip_to_restricted_range) {
-    min_value = 16 << (bitdepth - 8);
-    max_luma = 235 << (bitdepth - 8);
-    if (color_matrix_is_identity_) {
-      max_chroma = max_luma;
-    } else {
-      max_chroma = 240 << (bitdepth - 8);
-    }
-  } else {
-    min_value = 0;
-    max_luma = (256 << (bitdepth - 8)) - 1;
-    max_chroma = max_luma;
-  }
-  const int scaling_shift = params_.chroma_scaling;
+
   int y = 0;
   do {
     int x = 0;
     do {
-      const int luma_x = x << subsampling_x_;
-      const int luma_y = y << subsampling_y_;
-      const int luma_next_x = std::min(luma_x + 1, width_ - 1);
+      const int orig = in_y[y * source_stride_y + x];
+      int noise = noise_image[kPlaneY][y + start_height][x];
+      noise = RightShiftWithRounding(
+          ScaleLut<bitdepth>(scaling_lut_y, orig) * noise, scaling_shift);
+      out_y[y * dest_stride_y + x] = Clip3(orig + noise, min_value, max_luma);
+    } while (++x < width);
+  } while (++y < height);
+}
+
+// This function is for the case params_.chroma_scaling_from_luma == false.
+template <int bitdepth, typename GrainType, typename Pixel>
+void BlendNoiseWithImageChroma_C(
+    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
+    int min_value, int max_chroma, int width, int height, int start_height,
+    int subsampling_x, int subsampling_y,
+    const uint8_t scaling_lut_uv[kScalingLookupTableSize],
+    const void* source_plane_y, ptrdiff_t source_stride_y,
+    const void* source_plane_uv, ptrdiff_t source_stride_uv,
+    void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
+  const auto* noise_image =
+      static_cast<const Array2D<GrainType>*>(noise_image_ptr);
+
+  const int chroma_width = (width + subsampling_x) >> subsampling_x;
+  const int chroma_height = (height + subsampling_y) >> subsampling_y;
+
+  const auto* in_y = static_cast<const Pixel*>(source_plane_y);
+  source_stride_y /= sizeof(Pixel);
+  const auto* in_uv = static_cast<const Pixel*>(source_plane_uv);
+  source_stride_uv /= sizeof(Pixel);
+  auto* out_uv = static_cast<Pixel*>(dest_plane_uv);
+  dest_stride_uv /= sizeof(Pixel);
+
+  const int offset = (plane == kPlaneU) ? params.u_offset : params.v_offset;
+  const int luma_multiplier =
+      (plane == kPlaneU) ? params.u_luma_multiplier : params.v_luma_multiplier;
+  const int multiplier =
+      (plane == kPlaneU) ? params.u_multiplier : params.v_multiplier;
+
+  const int scaling_shift = params.chroma_scaling;
+  start_height >>= subsampling_y;
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      const int luma_x = x << subsampling_x;
+      const int luma_y = y << subsampling_y;
+      const int luma_next_x = std::min(luma_x + 1, width - 1);
       int average_luma;
-      if (subsampling_x_ != 0) {
+      if (subsampling_x != 0) {
         average_luma = RightShiftWithRounding(
             in_y[luma_y * source_stride_y + luma_x] +
                 in_y[luma_y * source_stride_y + luma_next_x],
@@ -806,111 +568,303 @@
       } else {
         average_luma = in_y[luma_y * source_stride_y + luma_x];
       }
-      if (params_.num_u_points > 0 || params_.chroma_scaling_from_luma) {
-        const int orig = in_u[y * source_stride_u + x];
-        int merged;
-        if (params_.chroma_scaling_from_luma) {
-          merged = average_luma;
-        } else {
-          const int combined =
-              average_luma * (params_.u_luma_multiplier - 128) +
-              orig * (params_.u_multiplier - 128);
-          merged = Clip3(
-              (combined >> 6) + LeftShift(params_.u_offset - 256, bitdepth - 8),
-              0, (1 << bitdepth) - 1);
-        }
-        int noise = noise_image_[kPlaneU][y][x];
-        noise = RightShiftWithRounding(
-            ScaleLut<bitdepth>(scaling_lut_u_, merged) * noise, scaling_shift);
-        out_u[y * dest_stride_u + x] =
-            Clip3(orig + noise, min_value, max_chroma);
-      } else {
-        out_u[y * dest_stride_u + x] = in_u[y * source_stride_u + x];
-      }
-      if (params_.num_v_points > 0 || params_.chroma_scaling_from_luma) {
-        const int orig = in_v[y * source_stride_v + x];
-        int merged;
-        if (params_.chroma_scaling_from_luma) {
-          merged = average_luma;
-        } else {
-          const int combined =
-              average_luma * (params_.v_luma_multiplier - 128) +
-              orig * (params_.v_multiplier - 128);
-          merged = Clip3(
-              (combined >> 6) + LeftShift(params_.v_offset - 256, bitdepth - 8),
-              0, (1 << bitdepth) - 1);
-        }
-        int noise = noise_image_[kPlaneV][y][x];
-        noise = RightShiftWithRounding(
-            ScaleLut<bitdepth>(scaling_lut_v_, merged) * noise, scaling_shift);
-        out_v[y * dest_stride_v + x] =
-            Clip3(orig + noise, min_value, max_chroma);
-      } else {
-        out_v[y * dest_stride_v + x] = in_v[y * source_stride_v + x];
-      }
-    } while (++x < ((width_ + subsampling_x_) >> subsampling_x_));
-  } while (++y < ((height_ + subsampling_y_) >> subsampling_y_));
-  if (params_.num_y_points > 0) {
-    int y = 0;
-    do {
-      int x = 0;
-      do {
-        const int orig = in_y[y * source_stride_y + x];
-        int noise = noise_image_[kPlaneY][y][x];
-        noise = RightShiftWithRounding(
-            ScaleLut<bitdepth>(scaling_lut_y_, orig) * noise, scaling_shift);
-        out_y[y * dest_stride_y + x] = Clip3(orig + noise, min_value, max_luma);
-      } while (++x < width_);
-    } while (++y < height_);
-  } else if (in_y != out_y) {  // If in_y and out_y point to the same buffer,
-                               // then do nothing.
-    const Pixel* in_y_row = in_y;
-    Pixel* out_y_row = out_y;
-    int y = 0;
-    do {
-      memcpy(out_y_row, in_y_row, width_ * sizeof(*out_y_row));
-      in_y_row += source_stride_y;
-      out_y_row += dest_stride_y;
-    } while (++y < height_);
-  }
+      const int orig = in_uv[y * source_stride_uv + x];
+      const int combined = average_luma * luma_multiplier + orig * multiplier;
+      const int merged =
+          Clip3((combined >> 6) + LeftShift(offset, bitdepth - 8), 0,
+                (1 << bitdepth) - 1);
+      int noise = noise_image[plane][y + start_height][x];
+      noise = RightShiftWithRounding(
+          ScaleLut<bitdepth>(scaling_lut_uv, merged) * noise, scaling_shift);
+      out_uv[y * dest_stride_uv + x] =
+          Clip3(orig + noise, min_value, max_chroma);
+    } while (++x < chroma_width);
+  } while (++y < chroma_height);
 }
 
-template <int bitdepth>
-bool FilmGrain<bitdepth>::AddNoise(
+// This function is for the case params_.chroma_scaling_from_luma == true.
+// This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y.
+template <int bitdepth, typename GrainType, typename Pixel>
+void BlendNoiseWithImageChromaWithCfl_C(
+    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
+    int min_value, int max_chroma, int width, int height, int start_height,
+    int subsampling_x, int subsampling_y,
+    const uint8_t scaling_lut[kScalingLookupTableSize],
     const void* source_plane_y, ptrdiff_t source_stride_y,
-    const void* source_plane_u, ptrdiff_t source_stride_u,
-    const void* source_plane_v, ptrdiff_t source_stride_v, void* dest_plane_y,
-    ptrdiff_t dest_stride_y, void* dest_plane_u, ptrdiff_t dest_stride_u,
-    void* dest_plane_v, ptrdiff_t dest_stride_v) {
-  if (!Init()) {
-    LIBGAV1_DLOG(ERROR, "Init() failed.");
-    return false;
-  }
-  if (!AllocateNoiseStripes()) {
-    LIBGAV1_DLOG(ERROR, "AllocateNoiseStripes() failed.");
-    return false;
-  }
-  ConstructNoiseStripes();
+    const void* source_plane_uv, ptrdiff_t source_stride_uv,
+    void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
+  const auto* noise_image =
+      static_cast<const Array2D<GrainType>*>(noise_image_ptr);
+  const auto* in_y = static_cast<const Pixel*>(source_plane_y);
+  source_stride_y /= sizeof(Pixel);
+  const auto* in_uv = static_cast<const Pixel*>(source_plane_uv);
+  source_stride_uv /= sizeof(Pixel);
+  auto* out_uv = static_cast<Pixel*>(dest_plane_uv);
+  dest_stride_uv /= sizeof(Pixel);
 
-  if (!AllocateNoiseImage()) {
-    LIBGAV1_DLOG(ERROR, "AllocateNoiseImage() failed.");
-    return false;
-  }
-  ConstructNoiseImage();
-
-  BlendNoiseWithImage(source_plane_y, source_stride_y, source_plane_u,
-                      source_stride_u, source_plane_v, source_stride_v,
-                      dest_plane_y, dest_stride_y, dest_plane_u, dest_stride_u,
-                      dest_plane_v, dest_stride_v);
-
-  return true;
+  const int chroma_width = (width + subsampling_x) >> subsampling_x;
+  const int chroma_height = (height + subsampling_y) >> subsampling_y;
+  const int scaling_shift = params.chroma_scaling;
+  start_height >>= subsampling_y;
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      const int luma_x = x << subsampling_x;
+      const int luma_y = y << subsampling_y;
+      const int luma_next_x = std::min(luma_x + 1, width - 1);
+      int average_luma;
+      if (subsampling_x != 0) {
+        average_luma = RightShiftWithRounding(
+            in_y[luma_y * source_stride_y + luma_x] +
+                in_y[luma_y * source_stride_y + luma_next_x],
+            1);
+      } else {
+        average_luma = in_y[luma_y * source_stride_y + luma_x];
+      }
+      const int orig_uv = in_uv[y * source_stride_uv + x];
+      int noise_uv = noise_image[plane][y + start_height][x];
+      noise_uv = RightShiftWithRounding(
+          ScaleLut<bitdepth>(scaling_lut, average_luma) * noise_uv,
+          scaling_shift);
+      out_uv[y * dest_stride_uv + x] =
+          Clip3(orig_uv + noise_uv, min_value, max_chroma);
+    } while (++x < chroma_width);
+  } while (++y < chroma_height);
 }
 
-// Explicit instantiations.
-template class FilmGrain<8>;
-#if LIBGAV1_MAX_BITDEPTH >= 10
-template class FilmGrain<10>;
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  // LumaAutoRegressionFunc
+  dsp->film_grain.luma_auto_regression[0] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+  dsp->film_grain.luma_auto_regression[1] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+  dsp->film_grain.luma_auto_regression[2] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+
+  // ChromaAutoRegressionFunc
+  // Chroma autoregression should never be called when lag is 0 and use_luma is
+  // false.
+  dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
+  dsp->film_grain.chroma_auto_regression[0][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, false>;
+  dsp->film_grain.chroma_auto_regression[0][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, false>;
+  dsp->film_grain.chroma_auto_regression[0][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, false>;
+  dsp->film_grain.chroma_auto_regression[1][0] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 0, true>;
+  dsp->film_grain.chroma_auto_regression[1][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, true>;
+  dsp->film_grain.chroma_auto_regression[1][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, true>;
+  dsp->film_grain.chroma_auto_regression[1][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, true>;
+
+  // ConstructNoiseStripesFunc
+  dsp->film_grain.construct_noise_stripes[0] =
+      ConstructNoiseStripes_C<8, int8_t>;
+  dsp->film_grain.construct_noise_stripes[1] =
+      ConstructNoiseStripesWithOverlap_C<8, int8_t>;
+
+  // ConstructNoiseImageOverlapFunc
+  dsp->film_grain.construct_noise_image_overlap =
+      ConstructNoiseImageOverlap_C<8, int8_t>;
+
+  // InitializeScalingLutFunc
+  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>;
+
+  // BlendNoiseWithImageLumaFunc
+  dsp->film_grain.blend_noise_luma =
+      BlendNoiseWithImageLuma_C<8, int8_t, uint8_t>;
+
+  // BlendNoiseWithImageChromaFunc
+  dsp->film_grain.blend_noise_chroma[0] =
+      BlendNoiseWithImageChroma_C<8, int8_t, uint8_t>;
+  dsp->film_grain.blend_noise_chroma[1] =
+      BlendNoiseWithImageChromaWithCfl_C<8, int8_t, uint8_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp8bpp_FilmGrainAutoregressionLuma
+  dsp->film_grain.luma_auto_regression[0] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+  dsp->film_grain.luma_auto_regression[1] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+  dsp->film_grain.luma_auto_regression[2] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
 #endif
+#ifndef LIBGAV1_Dsp8bpp_FilmGrainAutoregressionChroma
+  // Chroma autoregression should never be called when lag is 0 and use_luma is
+  // false.
+  dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
+  dsp->film_grain.chroma_auto_regression[0][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, false>;
+  dsp->film_grain.chroma_auto_regression[0][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, false>;
+  dsp->film_grain.chroma_auto_regression[0][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, false>;
+  dsp->film_grain.chroma_auto_regression[1][0] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 0, true>;
+  dsp->film_grain.chroma_auto_regression[1][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, true>;
+  dsp->film_grain.chroma_auto_regression[1][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, true>;
+  dsp->film_grain.chroma_auto_regression[1][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, true>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseStripes
+  dsp->film_grain.construct_noise_stripes[0] =
+      ConstructNoiseStripes_C<8, int8_t>;
+  dsp->film_grain.construct_noise_stripes[1] =
+      ConstructNoiseStripesWithOverlap_C<8, int8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseImageOverlap
+  dsp->film_grain.construct_noise_image_overlap =
+      ConstructNoiseImageOverlap_C<8, int8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_FilmGrainInitializeScalingLutFunc
+  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseLuma
+  dsp->film_grain.blend_noise_luma =
+      BlendNoiseWithImageLuma_C<8, int8_t, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChroma
+  dsp->film_grain.blend_noise_chroma[0] =
+      BlendNoiseWithImageChroma_C<8, int8_t, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChromaWithCfl
+  dsp->film_grain.blend_noise_chroma[1] =
+      BlendNoiseWithImageChromaWithCfl_C<8, int8_t, uint8_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
+  assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+
+  // LumaAutoRegressionFunc
+  dsp->film_grain.luma_auto_regression[0] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+  dsp->film_grain.luma_auto_regression[1] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+  dsp->film_grain.luma_auto_regression[2] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+
+  // ChromaAutoRegressionFunc
+  // Chroma autoregression should never be called when lag is 0 and use_luma is
+  // false.
+  dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
+  dsp->film_grain.chroma_auto_regression[0][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, false>;
+  dsp->film_grain.chroma_auto_regression[0][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, false>;
+  dsp->film_grain.chroma_auto_regression[0][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, false>;
+  dsp->film_grain.chroma_auto_regression[1][0] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 0, true>;
+  dsp->film_grain.chroma_auto_regression[1][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, true>;
+  dsp->film_grain.chroma_auto_regression[1][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, true>;
+  dsp->film_grain.chroma_auto_regression[1][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, true>;
+
+  // ConstructNoiseStripesFunc
+  dsp->film_grain.construct_noise_stripes[0] =
+      ConstructNoiseStripes_C<10, int16_t>;
+  dsp->film_grain.construct_noise_stripes[1] =
+      ConstructNoiseStripesWithOverlap_C<10, int16_t>;
+
+  // ConstructNoiseImageOverlapFunc
+  dsp->film_grain.construct_noise_image_overlap =
+      ConstructNoiseImageOverlap_C<10, int16_t>;
+
+  // InitializeScalingLutFunc
+  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>;
+
+  // BlendNoiseWithImageLumaFunc
+  dsp->film_grain.blend_noise_luma =
+      BlendNoiseWithImageLuma_C<10, int16_t, uint16_t>;
+
+  // BlendNoiseWithImageChromaFunc
+  dsp->film_grain.blend_noise_chroma[0] =
+      BlendNoiseWithImageChroma_C<10, int16_t, uint16_t>;
+  dsp->film_grain.blend_noise_chroma[1] =
+      BlendNoiseWithImageChromaWithCfl_C<10, int16_t, uint16_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp10bpp_FilmGrainAutoregressionLuma
+  dsp->film_grain.luma_auto_regression[0] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+  dsp->film_grain.luma_auto_regression[1] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+  dsp->film_grain.luma_auto_regression[2] =
+      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_FilmGrainAutoregressionChroma
+  // Chroma autoregression should never be called when lag is 0 and use_luma is
+  // false.
+  dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
+  dsp->film_grain.chroma_auto_regression[0][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, false>;
+  dsp->film_grain.chroma_auto_regression[0][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, false>;
+  dsp->film_grain.chroma_auto_regression[0][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, false>;
+  dsp->film_grain.chroma_auto_regression[1][0] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 0, true>;
+  dsp->film_grain.chroma_auto_regression[1][1] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, true>;
+  dsp->film_grain.chroma_auto_regression[1][2] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, true>;
+  dsp->film_grain.chroma_auto_regression[1][3] =
+      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, true>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_FilmGrainConstructNoiseStripes
+  dsp->film_grain.construct_noise_stripes[0] =
+      ConstructNoiseStripes_C<10, int16_t>;
+  dsp->film_grain.construct_noise_stripes[1] =
+      ConstructNoiseStripesWithOverlap_C<10, int16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_FilmGrainConstructNoiseImageOverlap
+  dsp->film_grain.construct_noise_image_overlap =
+      ConstructNoiseImageOverlap_C<10, int16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_FilmGrainInitializeScalingLutFunc
+  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma
+  dsp->film_grain.blend_noise_luma =
+      BlendNoiseWithImageLuma_C<10, int16_t, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChroma
+  dsp->film_grain.blend_noise_chroma[0] =
+      BlendNoiseWithImageChroma_C<10, int16_t, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChromaWithCfl
+  dsp->film_grain.blend_noise_chroma[1] =
+      BlendNoiseWithImageChromaWithCfl_C<10, int16_t, uint16_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+}
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+}  // namespace
+}  // namespace film_grain
+
+void FilmGrainInit_C() {
+  film_grain::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  film_grain::Init10bpp();
+#endif
+}
 
 }  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/film_grain.h b/libgav1/src/dsp/film_grain.h
index b8b28c4..fe93270 100644
--- a/libgav1/src/dsp/film_grain.h
+++ b/libgav1/src/dsp/film_grain.h
@@ -17,14 +17,15 @@
 #ifndef LIBGAV1_SRC_DSP_FILM_GRAIN_H_
 #define LIBGAV1_SRC_DSP_FILM_GRAIN_H_
 
-#include <cstddef>
-#include <cstdint>
-#include <memory>
-#include <type_traits>
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
 
-#include "src/dsp/common.h"
-#include "src/utils/array_2d.h"
-#include "src/utils/constants.h"
+// ARM:
+#include "src/dsp/arm/film_grain_neon.h"
+
+// IWYU pragma: end_exports
 
 namespace libgav1 {
 namespace dsp {
@@ -32,148 +33,6 @@
 // Initialize Dsp::film_grain_synthesis. This function is not thread-safe.
 void FilmGrainInit_C();
 
-// Section 7.18.3.5. Add noise synthesis process.
-template <int bitdepth>
-class FilmGrain {
- public:
-  // bitdepth  grain_min_  grain_max_
-  // --------------------------------
-  //     8        -128         127
-  //    10        -512         511
-  //    12       -2048        2047
-  //
-  // So int8_t is big enough for bitdepth 8, whereas bitdepths 10 and 12 need
-  // int16_t.
-  using GrainType =
-      typename std::conditional<bitdepth == 8, int8_t, int16_t>::type;
-
-  FilmGrain(const FilmGrainParams& params, bool is_monochrome,
-            bool color_matrix_is_identity, int subsampling_x, int subsampling_y,
-            int width, int height);
-
-  // Note: These static methods are declared public so that the unit tests can
-  // call them.
-
-  static int GetRandomNumber(int bits, uint16_t* seed);
-
-  static void GenerateLumaGrain(const FilmGrainParams& params,
-                                GrainType* luma_grain);
-
-  // Applies an auto-regressive filter to the white noise in luma_grain.
-  static void ApplyAutoRegressiveFilterToLumaGrain(
-      const FilmGrainParams& params, int grain_min, int grain_max,
-      GrainType* luma_grain);
-
-  // Generates white noise arrays u_grain and v_grain chroma_width samples wide
-  // and chroma_height samples high.
-  static void GenerateChromaGrains(const FilmGrainParams& params,
-                                   int chroma_width, int chroma_height,
-                                   GrainType* u_grain, GrainType* v_grain);
-
-  static void ApplyAutoRegressiveFilterToChromaGrains(
-      const FilmGrainParams& params, int grain_min, int grain_max,
-      const GrainType* luma_grain, int subsampling_x, int subsampling_y,
-      int chroma_width, int chroma_height, GrainType* u_grain,
-      GrainType* v_grain);
-
-  static void InitializeScalingLookupTable(int num_points,
-                                           const uint8_t point_value[],
-                                           const uint8_t point_scaling[],
-                                           uint8_t scaling_lut[256]);
-
-  // Combines the film grain with the image data.
-  bool AddNoise(const void* source_plane_y, ptrdiff_t source_stride_y,
-                const void* source_plane_u, ptrdiff_t source_stride_u,
-                const void* source_plane_v, ptrdiff_t source_stride_v,
-                void* dest_plane_y, ptrdiff_t dest_stride_y, void* dest_plane_u,
-                ptrdiff_t dest_stride_u, void* dest_plane_v,
-                ptrdiff_t dest_stride_v);
-
- private:
-  using Pixel =
-      typename std::conditional<bitdepth == 8, uint8_t, uint16_t>::type;
-
-  bool Init();
-
-  // Allocates noise_stripe_, which points to memory owned by noise_buffer_.
-  bool AllocateNoiseStripes();
-
-  void ConstructNoiseStripes();
-
-  bool AllocateNoiseImage();
-
-  // Blends the noise stripes together to form a noise image.
-  void ConstructNoiseImage();
-
-  // Blends the noise with the original image data.
-  void BlendNoiseWithImage(
-      const void* source_plane_y, ptrdiff_t source_stride_y,
-      const void* source_plane_u, ptrdiff_t source_stride_u,
-      const void* source_plane_v, ptrdiff_t source_stride_v, void* dest_plane_y,
-      ptrdiff_t dest_stride_y, void* dest_plane_u, ptrdiff_t dest_stride_u,
-      void* dest_plane_v, ptrdiff_t dest_stride_v) const;
-
-  // The width of the luma noise array.
-  static constexpr int kLumaWidth = 82;
-  // The height of the luma noise array.
-  static constexpr int kLumaHeight = 73;
-  // The two possible widths of the chroma noise array
-  static constexpr int kMinChromaWidth = 44;
-  static constexpr int kMaxChromaWidth = 82;
-  // The two possible heights of the chroma noise array.
-  static constexpr int kMinChromaHeight = 38;
-  static constexpr int kMaxChromaHeight = 73;
-
-  const FilmGrainParams& params_;
-  const bool is_monochrome_;
-  const bool color_matrix_is_identity_;
-  const int subsampling_x_;
-  const int subsampling_y_;
-  const int width_;
-  const int height_;
-  int grain_min_;
-  int grain_max_;
-  const int chroma_width_;
-  const int chroma_height_;
-  // The luma_grain array contains white noise generated for luma.
-  // The array size is fixed but subject to further optimization for SIMD.
-  GrainType luma_grain_[kLumaHeight * kLumaWidth];
-  // The maximum size of the u_grain and v_grain arrays is
-  // kMaxChromaHeight * kMaxChromaWidth. The actual size is
-  // chroma_height_ * chroma_width_.
-  GrainType u_grain_[kMaxChromaHeight * kMaxChromaWidth];
-  GrainType v_grain_[kMaxChromaHeight * kMaxChromaWidth];
-  // Scaling lookup tables.
-  uint8_t scaling_lut_y_[256];
-  uint8_t* scaling_lut_u_ = nullptr;
-  uint8_t* scaling_lut_v_ = nullptr;
-  // If allocated, this buffer is 256 * 2 bytes long and scaling_lut_u_ and
-  // scaling_lut_v_ point into this buffer. Otherwise, scaling_lut_u_ and
-  // scaling_lut_v_ point to scaling_lut_y_.
-  std::unique_ptr<uint8_t[]> scaling_lut_chroma_buffer_;
-
-  // A two-dimensional array of noise data. Generated for each 32 luma sample
-  // high stripe of the image. The first dimension is called luma_num. The
-  // second dimension is the plane.
-  //
-  // Each element of the noise_stripe_ array points to a conceptually
-  // two-dimensional array of int's. The two-dimensional array of int's is
-  // flattened into a one-dimensional buffer in this implementation.
-  //
-  // noise_stripe_[luma_num][kPlaneY] points to an array that has 34 rows and
-  // |width_| columns and contains noise for the luma component.
-  //
-  // noise_stripe_[luma_num][kPlaneU] or noise_stripe_[luma_num][kPlaneV]
-  // points to an array that has (34 >> subsampling_y_) rows and
-  // RightShiftWithRounding(width_, subsampling_x_) columns and contains noise
-  // for the chroma components.
-  Array2D<GrainType*> noise_stripe_;
-  // Owns the memory pointed to by the elements of noise_stripe_.
-  std::unique_ptr<GrainType[]> noise_buffer_;
-
-  Array2D<GrainType> noise_image_[kMaxPlanes];
-};
-
 }  // namespace dsp
 }  // namespace libgav1
 
diff --git a/libgav1/src/dsp/film_grain_common.h b/libgav1/src/dsp/film_grain_common.h
new file mode 100644
index 0000000..64e3e8e
--- /dev/null
+++ b/libgav1/src/dsp/film_grain_common.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_FILM_GRAIN_COMMON_H_
+#define LIBGAV1_SRC_DSP_FILM_GRAIN_COMMON_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <type_traits>
+
+#include "src/dsp/common.h"
+#include "src/utils/array_2d.h"
+#include "src/utils/constants.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+
+template <int bitdepth>
+int GetGrainMax() {
+  return (1 << (bitdepth - 1)) - 1;
+}
+
+template <int bitdepth>
+int GetGrainMin() {
+  return -(1 << (bitdepth - 1));
+}
+
+inline int GetFilmGrainRandomNumber(int bits, uint16_t* seed) {
+  uint16_t s = *seed;
+  uint16_t bit = (s ^ (s >> 1) ^ (s >> 3) ^ (s >> 12)) & 1;
+  s = (s >> 1) | (bit << 15);
+  *seed = s;
+  return s >> (16 - bits);
+}
+
+enum {
+  kAutoRegressionBorder = 3,
+  // The width of the luma noise array.
+  kLumaWidth = 82,
+  // The height of the luma noise array.
+  kLumaHeight = 73,
+  // The two possible widths of the chroma noise array.
+  kMinChromaWidth = 44,
+  kMaxChromaWidth = 82,
+  // The two possible heights of the chroma noise array.
+  kMinChromaHeight = 38,
+  kMaxChromaHeight = 73,
+  // The scaling lookup table maps bytes to bytes, so only uses 256 elements,
+  // plus one for overflow in 10bit lookups.
+  kScalingLookupTableSize = 257,
+  // Padding is added to the scaling lookup table to permit overwrites by
+  // InitializeScalingLookupTable_NEON.
+  kScalingLookupTablePadding = 6,
+  // Padding is added to each row of the noise image to permit overreads by
+  // BlendNoiseWithImageLuma_NEON and overwrites by WriteOverlapLine8bpp_NEON.
+  kNoiseImagePadding = 7,
+  // Padding is added to the end of the |noise_stripes_| buffer to permit
+  // overreads by WriteOverlapLine8bpp_NEON.
+  kNoiseStripePadding = 7,
+};  // anonymous enum
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_DSP_FILM_GRAIN_COMMON_H_
diff --git a/libgav1/src/dsp/intrapred.cc b/libgav1/src/dsp/intrapred.cc
index 8eeb037..4bcb580 100644
--- a/libgav1/src/dsp/intrapred.cc
+++ b/libgav1/src/dsp/intrapred.cc
@@ -757,7 +757,7 @@
   dsp->cfl_subsamplers[kTransformSize##W##x##H][kSubsamplingType422] = \
       CflSubsampler_C<W, H, BITDEPTH, PIXEL, 1, 0>;                    \
   dsp->cfl_subsamplers[kTransformSize##W##x##H][kSubsamplingType420] = \
-      CflSubsampler_C<W, H, BITDEPTH, PIXEL, 1, 1>;
+      CflSubsampler_C<W, H, BITDEPTH, PIXEL, 1, 1>
 
 #define INIT_CFL_INTRAPREDICTORS(BITDEPTH, PIXEL)       \
   INIT_CFL_INTRAPREDICTOR_WxH(4, 4, BITDEPTH, PIXEL);   \
diff --git a/libgav1/src/dsp/inverse_transform.cc b/libgav1/src/dsp/inverse_transform.cc
index 5f8548f..1c5a4a6 100644
--- a/libgav1/src/dsp/inverse_transform.cc
+++ b/libgav1/src/dsp/inverse_transform.cc
@@ -34,6 +34,10 @@
 
 constexpr uint8_t kTransformColumnShift = 4;
 
+#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION)
+#undef LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK
+#endif
+
 int32_t RangeCheckValue(int32_t value, int8_t range) {
 #if defined(LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK) && \
     LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK
@@ -69,6 +73,37 @@
 }
 
 template <typename Residual>
+void ButterflyRotationFirstIsZero_C(Residual* const dst, int a, int b,
+                                    int angle, bool flip, int8_t range) {
+  // Note that we multiply in 32 bits and then add/subtract the products in 64
+  // bits. The 32-bit multiplications do not overflow. Please see the comment
+  // and assert() in Cos128().
+  const auto x = static_cast<int64_t>(dst[b] * -Sin128(angle));
+  const auto y = static_cast<int64_t>(dst[b] * Cos128(angle));
+  // Section 7.13.2.1: It is a requirement of bitstream conformance that the
+  // values saved into the array T by this function are representable by a
+  // signed integer using |range| bits of precision.
+  dst[a] = RangeCheckValue(RightShiftWithRounding(flip ? y : x, 12), range);
+  dst[b] = RangeCheckValue(RightShiftWithRounding(flip ? x : y, 12), range);
+}
+
+template <typename Residual>
+void ButterflyRotationSecondIsZero_C(Residual* const dst, int a, int b,
+                                     int angle, bool flip, int8_t range) {
+  // Note that we multiply in 32 bits and then add/subtract the products in 64
+  // bits. The 32-bit multiplications do not overflow. Please see the comment
+  // and assert() in Cos128().
+  const auto x = static_cast<int64_t>(dst[a] * Cos128(angle));
+  const auto y = static_cast<int64_t>(dst[a] * Sin128(angle));
+
+  // Section 7.13.2.1: It is a requirement of bitstream conformance that the
+  // values saved into the array T by this function are representable by a
+  // signed integer using |range| bits of precision.
+  dst[a] = RangeCheckValue(RightShiftWithRounding(flip ? y : x, 12), range);
+  dst[b] = RangeCheckValue(RightShiftWithRounding(flip ? x : y, 12), range);
+}
+
+template <typename Residual>
 void HadamardRotation_C(Residual* const dst, int a, int b, bool flip,
                         int8_t range) {
   if (flip) std::swap(a, b);
@@ -83,6 +118,20 @@
   dst[b] = Clip3(y, min, max);
 }
 
+template <int bitdepth, typename Residual>
+void ClampIntermediate(Residual* const dst, int size) {
+  // If Residual is int16_t (which implies bitdepth is 8), we don't need to
+  // clip residual[i][j] to 16 bits.
+  if (sizeof(Residual) > 2) {
+    const Residual intermediate_clamp_max =
+        (1 << (std::max(bitdepth + 6, 16) - 1)) - 1;
+    const Residual intermediate_clamp_min = -intermediate_clamp_max - 1;
+    for (int j = 0; j < size; ++j) {
+      dst[j] = Clip3(dst[j], intermediate_clamp_min, intermediate_clamp_max);
+    }
+  }
+}
+
 //------------------------------------------------------------------------------
 // Discrete Cosine Transforms (DCT).
 
@@ -91,7 +140,7 @@
 // For e.g. index (2, 3) will be computed as follows:
 //   * bitreverse(3) = bitreverse(..000011) = 110000...
 //   * interpreting that as an integer with bit-length 2+2 = 4 will be 1100 = 12
-const uint8_t kBitReverseLookup[kNum1DTransformSizes][64] = {
+constexpr uint8_t kBitReverseLookup[kNum1DTransformSizes][64] = {
     {0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2,
      1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3,
      0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3},
@@ -346,6 +395,31 @@
   }
 }
 
+template <int bitdepth, typename Residual, int size_log2>
+void DctDcOnly_C(void* dest, const void* source, int8_t range,
+                 bool should_round, int row_shift, bool is_row) {
+  auto* const dst = static_cast<Residual*>(dest);
+  const auto* const src = static_cast<const Residual*>(source);
+
+  dst[0] = src[0];
+  if (is_row && should_round) {
+    dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12);
+  }
+
+  ButterflyRotationSecondIsZero_C(dst, 0, 1, 32, true, range);
+
+  if (is_row && row_shift > 0) {
+    dst[0] = RightShiftWithRounding(dst[0], row_shift);
+  }
+
+  ClampIntermediate<bitdepth, Residual>(dst, 1);
+
+  const int size = 1 << size_log2;
+  for (int i = 1; i < size; ++i) {
+    dst[i] = dst[0];
+  }
+}
+
 //------------------------------------------------------------------------------
 // Asymmetric Discrete Sine Transforms (ADST).
 
@@ -415,6 +489,57 @@
   dst[3] = dst_3;
 }
 
+template <int bitdepth, typename Residual>
+void Adst4DcOnly_C(void* dest, const void* source, int8_t range,
+                   bool should_round, int row_shift, bool is_row) {
+  auto* const dst = static_cast<Residual*>(dest);
+  const auto* const src = static_cast<const Residual*>(source);
+
+  dst[0] = src[0];
+  if (is_row && should_round) {
+    dst[0] = RightShiftWithRounding(src[0] * kTransformRowMultiplier, 12);
+  }
+
+  // stage 1.
+  // Section 7.13.2.6: It is a requirement of bitstream conformance that all
+  // values stored in the s and x arrays by this process are representable by
+  // a signed integer using range + 12 bits of precision.
+  int32_t s[3];
+  s[0] = RangeCheckValue(kAdst4Multiplier[0] * dst[0], range + 12);
+  s[1] = RangeCheckValue(kAdst4Multiplier[1] * dst[0], range + 12);
+  s[2] = RangeCheckValue(kAdst4Multiplier[2] * dst[0], range + 12);
+  // stage 3.
+  // stage 4.
+  // stages 5 and 6.
+  int32_t dst_0 = RightShiftWithRounding(s[0], 12);
+  int32_t dst_1 = RightShiftWithRounding(s[1], 12);
+  int32_t dst_2 = RightShiftWithRounding(s[2], 12);
+  int32_t dst_3 =
+      RightShiftWithRounding(RangeCheckValue(s[0] + s[1], range + 12), 12);
+  if (sizeof(Residual) == 2) {
+    // If the first argument to RightShiftWithRounding(..., 12) is only
+    // slightly smaller than 2^27 - 1 (e.g., 0x7fffe4e), adding 2^11 to it
+    // in RightShiftWithRounding(..., 12) will cause the function to return
+    // 0x8000, which cannot be represented as an int16_t. Change it to 0x7fff.
+    dst_0 -= (dst_0 == 0x8000);
+    dst_1 -= (dst_1 == 0x8000);
+    dst_3 -= (dst_3 == 0x8000);
+  }
+  dst[0] = dst_0;
+  dst[1] = dst_1;
+  dst[2] = dst_2;
+  dst[3] = dst_3;
+
+  const int size = 4;
+  if (is_row && row_shift > 0) {
+    for (int j = 0; j < size; ++j) {
+      dst[j] = RightShiftWithRounding(dst[j], row_shift);
+    }
+  }
+
+  ClampIntermediate<bitdepth, Residual>(dst, 4);
+}
+
 template <typename Residual>
 void AdstInputPermutation(int32_t* const dst, const Residual* const src,
                           int n) {
@@ -480,6 +605,54 @@
   AdstOutputPermutation(dst, temp, 8);
 }
 
+template <int bitdepth, typename Residual>
+void Adst8DcOnly_C(void* dest, const void* source, int8_t range,
+                   bool should_round, int row_shift, bool is_row) {
+  auto* const dst = static_cast<Residual*>(dest);
+  const auto* const src = static_cast<const Residual*>(source);
+
+  // stage 1.
+  int32_t temp[8];
+  // After the permutation, the dc value is in temp[1]. The remaining are zero.
+  AdstInputPermutation(temp, src, 8);
+
+  if (is_row && should_round) {
+    temp[1] = RightShiftWithRounding(temp[1] * kTransformRowMultiplier, 12);
+  }
+
+  // stage 2.
+  ButterflyRotationFirstIsZero_C(temp, 0, 1, 60, true, range);
+
+  // stage 3.
+  temp[4] = temp[0];
+  temp[5] = temp[1];
+
+  // stage 4.
+  ButterflyRotation_C(temp, 4, 5, 48, true, range);
+
+  // stage 5.
+  temp[2] = temp[0];
+  temp[3] = temp[1];
+  temp[6] = temp[4];
+  temp[7] = temp[5];
+
+  // stage 6.
+  ButterflyRotation_C(temp, 2, 3, 32, true, range);
+  ButterflyRotation_C(temp, 6, 7, 32, true, range);
+
+  // stage 7.
+  AdstOutputPermutation(dst, temp, 8);
+
+  const int size = 8;
+  if (is_row && row_shift > 0) {
+    for (int j = 0; j < size; ++j) {
+      dst[j] = RightShiftWithRounding(dst[j], row_shift);
+    }
+  }
+
+  ClampIntermediate<bitdepth, Residual>(dst, 8);
+}
+
 template <typename Residual>
 void Adst16_C(void* dest, const void* source, int8_t range) {
   auto* const dst = static_cast<Residual*>(dest);
@@ -533,6 +706,71 @@
   AdstOutputPermutation(dst, temp, 16);
 }
 
+template <int bitdepth, typename Residual>
+void Adst16DcOnly_C(void* dest, const void* source, int8_t range,
+                    bool should_round, int row_shift, bool is_row) {
+  auto* const dst = static_cast<Residual*>(dest);
+  const auto* const src = static_cast<const Residual*>(source);
+
+  // stage 1.
+  int32_t temp[16];
+  // After the permutation, the dc value is in temp[1].  The remaining are zero.
+  AdstInputPermutation(temp, src, 16);
+
+  if (is_row && should_round) {
+    temp[1] = RightShiftWithRounding(temp[1] * kTransformRowMultiplier, 12);
+  }
+
+  // stage 2.
+  ButterflyRotationFirstIsZero_C(temp, 0, 1, 62, true, range);
+
+  // stage 3.
+  temp[8] = temp[0];
+  temp[9] = temp[1];
+
+  // stage 4.
+  ButterflyRotation_C(temp, 8, 9, 56, true, range);
+
+  // stage 5.
+  temp[4] = temp[0];
+  temp[5] = temp[1];
+  temp[12] = temp[8];
+  temp[13] = temp[9];
+
+  // stage 6.
+  ButterflyRotation_C(temp, 4, 5, 48, true, range);
+  ButterflyRotation_C(temp, 12, 13, 48, true, range);
+
+  // stage 7.
+  temp[2] = temp[0];
+  temp[3] = temp[1];
+  temp[10] = temp[8];
+  temp[11] = temp[9];
+
+  temp[6] = temp[4];
+  temp[7] = temp[5];
+  temp[14] = temp[12];
+  temp[15] = temp[13];
+
+  // stage 8.
+  for (int i = 0; i < 4; ++i) {
+    ButterflyRotation_C(temp, MultiplyBy4(i) + 2, MultiplyBy4(i) + 3, 32, true,
+                        range);
+  }
+
+  // stage 9.
+  AdstOutputPermutation(dst, temp, 16);
+
+  const int size = 16;
+  if (is_row && row_shift > 0) {
+    for (int j = 0; j < size; ++j) {
+      dst[j] = RightShiftWithRounding(dst[j], row_shift);
+    }
+  }
+
+  ClampIntermediate<bitdepth, Residual>(dst, 16);
+}
+
 //------------------------------------------------------------------------------
 // Identity Transforms.
 //
@@ -648,6 +886,35 @@
   }
 }
 
+template <int bitdepth, typename Residual>
+void Identity4DcOnly_C(void* dest, const void* source, int8_t /*range*/,
+                       bool should_round, int row_shift, bool is_row) {
+  auto* const dst = static_cast<Residual*>(dest);
+  const auto* const src = static_cast<const Residual*>(source);
+
+  if (is_row) {
+    dst[0] = src[0];
+    if (should_round) {
+      dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12);
+    }
+
+    const int32_t rounding = (1 + (row_shift << 1)) << 11;
+    int32_t dst_i =
+        (dst[0] * kIdentity4Multiplier + rounding) >> (12 + row_shift);
+    if (sizeof(Residual) == 2) {
+      dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX);
+    }
+    dst[0] = static_cast<Residual>(dst_i);
+
+    ClampIntermediate<bitdepth, Residual>(dst, 1);
+    return;
+  }
+
+  const int32_t rounding = (1 + (1 << kTransformColumnShift)) << 11;
+  dst[0] = static_cast<Residual>((src[0] * kIdentity4Multiplier + rounding) >>
+                                 (12 + kTransformColumnShift));
+}
+
 template <typename Residual>
 void Identity8Row_C(void* dest, const void* source, int8_t shift) {
   assert(shift == 0 || shift == 1 || shift == 2);
@@ -672,6 +939,39 @@
   }
 }
 
+template <int bitdepth, typename Residual>
+void Identity8DcOnly_C(void* dest, const void* source, int8_t /*range*/,
+                       bool should_round, int row_shift, bool is_row) {
+  auto* const dst = static_cast<Residual*>(dest);
+  const auto* const src = static_cast<const Residual*>(source);
+
+  if (is_row) {
+    dst[0] = src[0];
+    if (should_round) {
+      dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12);
+    }
+
+    int32_t dst_i = RightShiftWithRounding(MultiplyBy2(dst[0]), row_shift);
+    if (sizeof(Residual) == 2) {
+      dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX);
+    }
+    dst[0] = static_cast<Residual>(dst_i);
+
+    // If Residual is int16_t (which implies bitdepth is 8), we don't need to
+    // clip residual[i][j] to 16 bits.
+    if (sizeof(Residual) > 2) {
+      const Residual intermediate_clamp_max =
+          (1 << (std::max(bitdepth + 6, 16) - 1)) - 1;
+      const Residual intermediate_clamp_min = -intermediate_clamp_max - 1;
+      dst[0] = Clip3(dst[0], intermediate_clamp_min, intermediate_clamp_max);
+    }
+    return;
+  }
+
+  dst[0] = static_cast<Residual>(
+      RightShiftWithRounding(src[0], kTransformColumnShift - 1));
+}
+
 template <typename Residual>
 void Identity16Row_C(void* dest, const void* source, int8_t shift) {
   assert(shift == 1 || shift == 2);
@@ -705,6 +1005,35 @@
   }
 }
 
+template <int bitdepth, typename Residual>
+void Identity16DcOnly_C(void* dest, const void* source, int8_t /*range*/,
+                        bool should_round, int row_shift, bool is_row) {
+  auto* const dst = static_cast<Residual*>(dest);
+  const auto* const src = static_cast<const Residual*>(source);
+
+  if (is_row) {
+    dst[0] = src[0];
+    if (should_round) {
+      dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12);
+    }
+
+    const int32_t rounding = (1 + (1 << row_shift)) << 11;
+    int32_t dst_i =
+        (dst[0] * kIdentity16Multiplier + rounding) >> (12 + row_shift);
+    if (sizeof(Residual) == 2) {
+      dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX);
+    }
+    dst[0] = static_cast<Residual>(dst_i);
+
+    ClampIntermediate<bitdepth, Residual>(dst, 1);
+    return;
+  }
+
+  const int32_t rounding = (1 + (1 << kTransformColumnShift)) << 11;
+  dst[0] = static_cast<Residual>((src[0] * kIdentity16Multiplier + rounding) >>
+                                 (12 + kTransformColumnShift));
+}
+
 template <typename Residual>
 void Identity32Row_C(void* dest, const void* source, int8_t shift) {
   assert(shift == 1 || shift == 2);
@@ -729,6 +1058,32 @@
   }
 }
 
+template <int bitdepth, typename Residual>
+void Identity32DcOnly_C(void* dest, const void* source, int8_t /*range*/,
+                        bool should_round, int row_shift, bool is_row) {
+  auto* const dst = static_cast<Residual*>(dest);
+  const auto* const src = static_cast<const Residual*>(source);
+
+  if (is_row) {
+    dst[0] = src[0];
+    if (should_round) {
+      dst[0] = RightShiftWithRounding(dst[0] * kTransformRowMultiplier, 12);
+    }
+
+    int32_t dst_i = RightShiftWithRounding(MultiplyBy4(dst[0]), row_shift);
+    if (sizeof(Residual) == 2) {
+      dst_i = Clip3(dst_i, INT16_MIN, INT16_MAX);
+    }
+    dst[0] = static_cast<Residual>(dst_i);
+
+    ClampIntermediate<bitdepth, Residual>(dst, 1);
+    return;
+  }
+
+  dst[0] = static_cast<Residual>(
+      RightShiftWithRounding(src[0], kTransformColumnShift - 2));
+}
+
 //------------------------------------------------------------------------------
 // Walsh Hadamard Transform.
 
@@ -751,14 +1106,36 @@
   dst[3] = temp[3] + dst[2];
 }
 
+template <int bitdepth, typename Residual>
+void Wht4DcOnly_C(void* dest, const void* source, int8_t range,
+                  bool /*should_round*/, int /*row_shift*/, bool /*is_row*/) {
+  auto* const dst = static_cast<Residual*>(dest);
+  const auto* const src = static_cast<const Residual*>(source);
+  const int shift = range;
+
+  Residual temp = src[0] >> shift;
+  // This signed right shift must be an arithmetic shift.
+  Residual e = temp >> 1;
+  dst[0] = temp - e;
+  dst[1] = e;
+  dst[2] = e;
+  dst[3] = e;
+
+  ClampIntermediate<bitdepth, Residual>(dst, 4);
+}
+
 //------------------------------------------------------------------------------
 // row/column transform loop
 
 using InverseTransform1DFunc = void (*)(void* dst, const void* src,
                                         int8_t range);
+using InverseTransformDcOnlyFunc = void (*)(void* dest, const void* source,
+                                            int8_t range, bool should_round,
+                                            int row_shift, bool is_row);
 
 template <int bitdepth, typename Residual, typename Pixel,
           Transform1D transform1d_type,
+          InverseTransformDcOnlyFunc dconly_transform1d,
           InverseTransform1DFunc row_transform1d_func,
           InverseTransform1DFunc column_transform1d_func = row_transform1d_func>
 void TransformLoop_C(TransformType tx_type, TransformSize tx_size,
@@ -773,7 +1150,7 @@
   const int tx_height = lossless ? 4 : kTransformHeight[tx_size];
   const int tx_width_log2 = kTransformWidthLog2[tx_size];
   const int tx_height_log2 = kTransformHeightLog2[tx_size];
-  auto* frame = reinterpret_cast<Array2DView<Pixel>*>(dst_frame);
+  auto* frame = static_cast<Array2DView<Pixel>*>(dst_frame);
 
   // Initially this points to the dequantized values. After the transforms are
   // applied, this buffer contains the residual.
@@ -781,13 +1158,6 @@
                                  static_cast<Residual*>(src_buffer));
 
   if (is_row) {
-    // Row transforms need to be done only up to 32 because the rest of the rows
-    // are always all zero if |tx_height| is 64.  Otherwise, only process the
-    // rows that have a non zero coefficients.
-    // TODO(slavarnway): Expand to include other possible non_zero_coeff_count
-    // values.
-    const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
     // Row transform.
     const uint8_t row_shift = lossless ? 0 : kTransformRowShift[tx_size];
     // This is the |range| parameter of the InverseTransform1DFunc.  For lossy
@@ -798,6 +1168,18 @@
     // the fraction 2896 / 2^12.
     const bool should_round = std::abs(tx_width_log2 - tx_height_log2) == 1;
 
+    if (non_zero_coeff_count == 1) {
+      dconly_transform1d(residual[0], residual[0], row_clamp_range,
+                         should_round, row_shift, true);
+      return;
+    }
+
+    // Row transforms need to be done only up to 32 because the rest of the rows
+    // are always all zero if |tx_height| is 64.  Otherwise, only process the
+    // rows that have a non zero coefficients.
+    // TODO(slavarnway): Expand to include other possible non_zero_coeff_count
+    // values.
+    const int num_rows = std::min(tx_height, 32);
     for (int i = 0; i < num_rows; ++i) {
       // If lossless, the transform size is 4x4, so should_round is false.
       if (!lossless && should_round) {
@@ -817,17 +1199,8 @@
           residual[i][j] = RightShiftWithRounding(residual[i][j], row_shift);
         }
       }
-      // If Residual is int16_t (which implies bitdepth is 8), we don't need to
-      // clip residual[i][j] to 16 bits.
-      if (sizeof(Residual) > 2) {
-        const Residual intermediate_clamp_max =
-            (1 << (std::max(bitdepth + 6, 16) - 1)) - 1;
-        const Residual intermediate_clamp_min = -intermediate_clamp_max - 1;
-        for (int j = 0; j < tx_width; ++j) {
-          residual[i][j] = Clip3(residual[i][j], intermediate_clamp_min,
-                                 intermediate_clamp_max);
-        }
-      }
+
+      ClampIntermediate<bitdepth, Residual>(residual[i], tx_width);
     }
     return;
   }
@@ -851,10 +1224,15 @@
     for (int i = 0; i < tx_height; ++i) {
       tx_buffer[i] = residual[i][flipped_j];
     }
-    // For identity transform, |column_transform1d_func| also performs the
-    // Round2(T[i], colShift) call in the spec.
-    column_transform1d_func(tx_buffer, tx_buffer,
-                            is_identity ? column_shift : column_clamp_range);
+    if (non_zero_coeff_count == 1) {
+      dconly_transform1d(tx_buffer, tx_buffer, column_clamp_range, false, 0,
+                         false);
+    } else {
+      // For identity transform, |column_transform1d_func| also performs the
+      // Round2(T[i], colShift) call in the spec.
+      column_transform1d_func(tx_buffer, tx_buffer,
+                              is_identity ? column_shift : column_clamp_range);
+    }
     const int x = start_x + j;
     for (int i = 0; i < tx_height; ++i) {
       const int y = start_y + i;
@@ -876,49 +1254,53 @@
   // Maximum transform size for Dct is 64.
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformDct] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
-                      Dct_C<Residual, 2>>;
+                      DctDcOnly_C<bitdepth, Residual, 2>, Dct_C<Residual, 2>>;
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformDct] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
-                      Dct_C<Residual, 3>>;
+                      DctDcOnly_C<bitdepth, Residual, 3>, Dct_C<Residual, 3>>;
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformDct] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
-                      Dct_C<Residual, 4>>;
+                      DctDcOnly_C<bitdepth, Residual, 4>, Dct_C<Residual, 4>>;
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformDct] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
-                      Dct_C<Residual, 5>>;
+                      DctDcOnly_C<bitdepth, Residual, 5>, Dct_C<Residual, 5>>;
   dsp->inverse_transforms[k1DTransformSize64][k1DTransformDct] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
-                      Dct_C<Residual, 6>>;
+                      DctDcOnly_C<bitdepth, Residual, 6>, Dct_C<Residual, 6>>;
 
   // Maximum transform size for Adst is 16.
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformAdst] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
-                      Adst4_C<Residual>>;
+                      Adst4DcOnly_C<bitdepth, Residual>, Adst4_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformAdst] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
-                      Adst8_C<Residual>>;
+                      Adst8DcOnly_C<bitdepth, Residual>, Adst8_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformAdst] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
-                      Adst16_C<Residual>>;
+                      Adst16DcOnly_C<bitdepth, Residual>, Adst16_C<Residual>>;
 
   // Maximum transform size for Identity transform is 32.
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformIdentity] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+                      Identity4DcOnly_C<bitdepth, Residual>,
                       Identity4Row_C<Residual>, Identity4Column_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformIdentity] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+                      Identity8DcOnly_C<bitdepth, Residual>,
                       Identity8Row_C<Residual>, Identity8Column_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformIdentity] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+                      Identity16DcOnly_C<bitdepth, Residual>,
                       Identity16Row_C<Residual>, Identity16Column_C<Residual>>;
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformIdentity] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+                      Identity32DcOnly_C<bitdepth, Residual>,
                       Identity32Row_C<Residual>, Identity32Column_C<Residual>>;
 
   // Maximum transform size for Wht is 4.
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformWht] =
       TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformWht,
-                      Wht4_C<Residual>>;
+                      Wht4DcOnly_C<bitdepth, Residual>, Wht4_C<Residual>>;
 }
 
 void Init8bpp() {
@@ -934,59 +1316,72 @@
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformDct] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 2>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+                      DctDcOnly_C<8, int16_t, 2>, Dct_C<int16_t, 2>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformDct] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 3>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+                      DctDcOnly_C<8, int16_t, 3>, Dct_C<int16_t, 3>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformDct] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 4>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+                      DctDcOnly_C<8, int16_t, 4>, Dct_C<int16_t, 4>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformDct] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 5>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+                      DctDcOnly_C<8, int16_t, 5>, Dct_C<int16_t, 5>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize64][k1DTransformDct] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct, Dct_C<int16_t, 6>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+                      DctDcOnly_C<8, int16_t, 6>, Dct_C<int16_t, 6>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformAdst] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, Adst4_C<int16_t>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst,
+                      Adst4DcOnly_C<8, int16_t>, Adst4_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformAdst] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, Adst8_C<int16_t>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst,
+                      Adst8DcOnly_C<8, int16_t>, Adst8_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformAdst] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst, Adst16_C<int16_t>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst,
+                      Adst16DcOnly_C<8, int16_t>, Adst16_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformIdentity] =
       TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
-                      Identity4Row_C<int16_t>, Identity4Column_C<int16_t>>;
+                      Identity4DcOnly_C<8, int16_t>, Identity4Row_C<int16_t>,
+                      Identity4Column_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformIdentity] =
       TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
-                      Identity8Row_C<int16_t>, Identity8Column_C<int16_t>>;
+                      Identity8DcOnly_C<8, int16_t>, Identity8Row_C<int16_t>,
+                      Identity8Column_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformIdentity] =
       TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
-                      Identity16Row_C<int16_t>, Identity16Column_C<int16_t>>;
+                      Identity16DcOnly_C<8, int16_t>, Identity16Row_C<int16_t>,
+                      Identity16Column_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformIdentity] =
       TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
-                      Identity32Row_C<int16_t>, Identity32Column_C<int16_t>>;
+                      Identity32DcOnly_C<8, int16_t>, Identity32Row_C<int16_t>,
+                      Identity32Column_C<int16_t>>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformWht] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformWht, Wht4_C<int16_t>>;
+      TransformLoop_C<8, int16_t, uint8_t, k1DTransformWht,
+                      Wht4DcOnly_C<8, int16_t>, Wht4_C<int16_t>>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -1006,66 +1401,71 @@
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformDct] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
-                      Dct_C<int32_t, 2>>;
+                      DctDcOnly_C<10, int32_t, 2>, Dct_C<int32_t, 2>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformDct] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
-                      Dct_C<int32_t, 3>>;
+                      DctDcOnly_C<10, int32_t, 3>, Dct_C<int32_t, 3>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformDct] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
-                      Dct_C<int32_t, 4>>;
+                      DctDcOnly_C<10, int32_t, 4>, Dct_C<int32_t, 4>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize32_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformDct] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
-                      Dct_C<int32_t, 5>>;
+                      DctDcOnly_C<10, int32_t, 5>, Dct_C<int32_t, 5>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize64_1DTransformDct
   dsp->inverse_transforms[k1DTransformSize64][k1DTransformDct] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
-                      Dct_C<int32_t, 6>>;
+                      DctDcOnly_C<10, int32_t, 6>, Dct_C<int32_t, 6>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformAdst] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
-                      Adst4_C<int32_t>>;
+                      Adst4DcOnly_C<10, int32_t>, Adst4_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformAdst] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
-                      Adst8_C<int32_t>>;
+                      Adst8DcOnly_C<10, int32_t>, Adst8_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformAdst
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformAdst] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
-                      Adst16_C<int32_t>>;
+                      Adst16DcOnly_C<10, int32_t>, Adst16_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformIdentity] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
-                      Identity4Row_C<int32_t>, Identity4Column_C<int32_t>>;
+                      Identity4DcOnly_C<10, int32_t>, Identity4Row_C<int32_t>,
+                      Identity4Column_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize8][k1DTransformIdentity] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
-                      Identity8Row_C<int32_t>, Identity8Column_C<int32_t>>;
+                      Identity8DcOnly_C<10, int32_t>, Identity8Row_C<int32_t>,
+                      Identity8Column_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize16][k1DTransformIdentity] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
-                      Identity16Row_C<int32_t>, Identity16Column_C<int32_t>>;
+                      Identity16DcOnly_C<10, int32_t>, Identity16Row_C<int32_t>,
+                      Identity16Column_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize32_1DTransformIdentity
   dsp->inverse_transforms[k1DTransformSize32][k1DTransformIdentity] =
       TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
-                      Identity32Row_C<int32_t>, Identity32Column_C<int32_t>>;
+                      Identity32DcOnly_C<10, int32_t>, Identity32Row_C<int32_t>,
+                      Identity32Column_C<int32_t>>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformWht
   dsp->inverse_transforms[k1DTransformSize4][k1DTransformWht] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformWht, Wht4_C<int32_t>>;
+      TransformLoop_C<10, int32_t, uint16_t, k1DTransformWht,
+                      Wht4DcOnly_C<10, int32_t>, Wht4_C<int32_t>>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
diff --git a/libgav1/src/dsp/inverse_transform.inc b/libgav1/src/dsp/inverse_transform.inc
index 55e68b6..1893884 100644
--- a/libgav1/src/dsp/inverse_transform.inc
+++ b/libgav1/src/dsp/inverse_transform.inc
@@ -46,6 +46,84 @@
 
 inline int16_t Sin128(int angle) { return Cos128(angle - 64); }
 
+template <int tx_width>
+LIBGAV1_ALWAYS_INLINE int GetNumRows(TransformType tx_type, int tx_height,
+                                     int non_zero_coeff_count) {
+  const TransformClass tx_class = GetTransformClass(tx_type);
+  // The transform loops process either 4 or a multiple of 8 rows.  Use tx_class
+  // to determine the scan order.  Then return the number of rows based on the
+  // non_zero_coeff_count.
+  if (tx_height > 4) {
+    if (tx_class == kTransformClass2D) {
+      if (tx_width == 4) {
+        if (non_zero_coeff_count <= 10) return 4;
+        if (non_zero_coeff_count <= 29) return 8;
+        return tx_height;
+      }
+      if (tx_width == 8) {
+        if (non_zero_coeff_count <= 10) return 4;
+        if (non_zero_coeff_count <= 43) return 8;
+        if ((non_zero_coeff_count <= 107) & (tx_height > 16)) return 16;
+        if ((non_zero_coeff_count <= 171) & (tx_height > 16)) return 24;
+        return tx_height;
+      }
+      if (tx_width == 16) {
+        if (non_zero_coeff_count <= 10) return 4;
+        if (non_zero_coeff_count <= 36) return 8;
+        if ((non_zero_coeff_count <= 151) & (tx_height > 16)) return 16;
+        if ((non_zero_coeff_count <= 279) & (tx_height > 16)) return 24;
+        return tx_height;
+      }
+      if (tx_width == 32) {
+        if (non_zero_coeff_count <= 10) return 4;
+        if (non_zero_coeff_count <= 36) return 8;
+        if ((non_zero_coeff_count <= 136) & (tx_height > 16)) return 16;
+        if ((non_zero_coeff_count <= 300) & (tx_height > 16)) return 24;
+        return tx_height;
+      }
+    }
+
+    if (tx_class == kTransformClassHorizontal) {
+      if (non_zero_coeff_count <= 4) return 4;
+      if (non_zero_coeff_count <= 8) return 8;
+      if ((non_zero_coeff_count <= 16) & (tx_height > 16)) return 16;
+      if ((non_zero_coeff_count <= 24) & (tx_height > 16)) return 24;
+      return tx_height;
+    }
+
+    if (tx_class == kTransformClassVertical) {
+      if (tx_width == 4) {
+        if (non_zero_coeff_count <= 16) return 4;
+        if (non_zero_coeff_count <= 32) return 8;
+        return tx_height;
+      }
+      if (tx_width == 8) {
+        if (non_zero_coeff_count <= 32) return 4;
+        if (non_zero_coeff_count <= 64) return 8;
+        if ((non_zero_coeff_count <= 128) & (tx_height > 16)) return 16;
+        if ((non_zero_coeff_count <= 192) & (tx_height > 16)) return 24;
+        return tx_height;
+      }
+
+      if (tx_width == 16) {
+        if (non_zero_coeff_count <= 64) return 4;
+        if (non_zero_coeff_count <= 128) return 8;
+        if ((non_zero_coeff_count <= 256) & (tx_height > 16)) return 16;
+        if ((non_zero_coeff_count <= 384) & (tx_height > 16)) return 24;
+        return tx_height;
+      }
+      if (tx_width == 32) {
+        if (non_zero_coeff_count <= 128) return 4;
+        if (non_zero_coeff_count <= 256) return 8;
+        if ((non_zero_coeff_count <= 512) & (tx_height > 16)) return 16;
+        if ((non_zero_coeff_count <= 768) & (tx_height > 16)) return 24;
+        return tx_height;
+      }
+    }
+  }
+  return tx_height;
+}
+
 // The value for index i is derived as:
 // round(sqrt(2) * sin(i * pi / 9) * 2 / 3 * (1 << 12)).
 constexpr int16_t kAdst4Multiplier[4] = {1321, 2482, 3344, 3803};
diff --git a/libgav1/src/dsp/libgav1_dsp.cmake b/libgav1/src/dsp/libgav1_dsp.cmake
new file mode 100644
index 0000000..00574fa
--- /dev/null
+++ b/libgav1/src/dsp/libgav1_dsp.cmake
@@ -0,0 +1,165 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_SRC_DSP_LIBGAV1_DSP_CMAKE_)
+  return()
+endif() # LIBGAV1_SRC_DSP_LIBGAV1_DSP_CMAKE_
+set(LIBGAV1_SRC_DSP_LIBGAV1_DSP_CMAKE_ 1)
+
+include("${libgav1_root}/cmake/libgav1_targets.cmake")
+
+list(APPEND libgav1_dsp_sources
+            "${libgav1_source}/dsp/average_blend.cc"
+            "${libgav1_source}/dsp/average_blend.h"
+            "${libgav1_source}/dsp/cdef.cc"
+            "${libgav1_source}/dsp/cdef.h"
+            "${libgav1_source}/dsp/cdef.inc"
+            "${libgav1_source}/dsp/common.h"
+            "${libgav1_source}/dsp/constants.cc"
+            "${libgav1_source}/dsp/constants.h"
+            "${libgav1_source}/dsp/convolve.cc"
+            "${libgav1_source}/dsp/convolve.h"
+            "${libgav1_source}/dsp/distance_weighted_blend.cc"
+            "${libgav1_source}/dsp/distance_weighted_blend.h"
+            "${libgav1_source}/dsp/dsp.cc"
+            "${libgav1_source}/dsp/dsp.h"
+            "${libgav1_source}/dsp/film_grain.cc"
+            "${libgav1_source}/dsp/film_grain.h"
+            "${libgav1_source}/dsp/film_grain_common.h"
+            "${libgav1_source}/dsp/intra_edge.cc"
+            "${libgav1_source}/dsp/intra_edge.h"
+            "${libgav1_source}/dsp/intrapred.cc"
+            "${libgav1_source}/dsp/intrapred.h"
+            "${libgav1_source}/dsp/inverse_transform.cc"
+            "${libgav1_source}/dsp/inverse_transform.h"
+            "${libgav1_source}/dsp/inverse_transform.inc"
+            "${libgav1_source}/dsp/loop_filter.cc"
+            "${libgav1_source}/dsp/loop_filter.h"
+            "${libgav1_source}/dsp/loop_restoration.cc"
+            "${libgav1_source}/dsp/loop_restoration.h"
+            "${libgav1_source}/dsp/mask_blend.cc"
+            "${libgav1_source}/dsp/mask_blend.h"
+            "${libgav1_source}/dsp/motion_field_projection.cc"
+            "${libgav1_source}/dsp/motion_field_projection.h"
+            "${libgav1_source}/dsp/motion_vector_search.cc"
+            "${libgav1_source}/dsp/motion_vector_search.h"
+            "${libgav1_source}/dsp/obmc.cc"
+            "${libgav1_source}/dsp/obmc.h"
+            "${libgav1_source}/dsp/obmc.inc"
+            "${libgav1_source}/dsp/super_res.cc"
+            "${libgav1_source}/dsp/super_res.h"
+            "${libgav1_source}/dsp/warp.cc"
+            "${libgav1_source}/dsp/warp.h"
+            "${libgav1_source}/dsp/weight_mask.cc"
+            "${libgav1_source}/dsp/weight_mask.h")
+
+list(APPEND libgav1_dsp_sources_neon
+            ${libgav1_dsp_sources_neon}
+            "${libgav1_source}/dsp/arm/average_blend_neon.cc"
+            "${libgav1_source}/dsp/arm/average_blend_neon.h"
+            "${libgav1_source}/dsp/arm/cdef_neon.cc"
+            "${libgav1_source}/dsp/arm/cdef_neon.h"
+            "${libgav1_source}/dsp/arm/common_neon.h"
+            "${libgav1_source}/dsp/arm/convolve_neon.cc"
+            "${libgav1_source}/dsp/arm/convolve_neon.h"
+            "${libgav1_source}/dsp/arm/distance_weighted_blend_neon.cc"
+            "${libgav1_source}/dsp/arm/distance_weighted_blend_neon.h"
+            "${libgav1_source}/dsp/arm/film_grain_neon.cc"
+            "${libgav1_source}/dsp/arm/film_grain_neon.h"
+            "${libgav1_source}/dsp/arm/intra_edge_neon.cc"
+            "${libgav1_source}/dsp/arm/intra_edge_neon.h"
+            "${libgav1_source}/dsp/arm/intrapred_cfl_neon.cc"
+            "${libgav1_source}/dsp/arm/intrapred_directional_neon.cc"
+            "${libgav1_source}/dsp/arm/intrapred_filter_intra_neon.cc"
+            "${libgav1_source}/dsp/arm/intrapred_neon.cc"
+            "${libgav1_source}/dsp/arm/intrapred_neon.h"
+            "${libgav1_source}/dsp/arm/intrapred_smooth_neon.cc"
+            "${libgav1_source}/dsp/arm/inverse_transform_neon.cc"
+            "${libgav1_source}/dsp/arm/inverse_transform_neon.h"
+            "${libgav1_source}/dsp/arm/loop_filter_neon.cc"
+            "${libgav1_source}/dsp/arm/loop_filter_neon.h"
+            "${libgav1_source}/dsp/arm/loop_restoration_neon.cc"
+            "${libgav1_source}/dsp/arm/loop_restoration_neon.h"
+            "${libgav1_source}/dsp/arm/mask_blend_neon.cc"
+            "${libgav1_source}/dsp/arm/mask_blend_neon.h"
+            "${libgav1_source}/dsp/arm/motion_field_projection_neon.cc"
+            "${libgav1_source}/dsp/arm/motion_field_projection_neon.h"
+            "${libgav1_source}/dsp/arm/motion_vector_search_neon.cc"
+            "${libgav1_source}/dsp/arm/motion_vector_search_neon.h"
+            "${libgav1_source}/dsp/arm/obmc_neon.cc"
+            "${libgav1_source}/dsp/arm/obmc_neon.h"
+            "${libgav1_source}/dsp/arm/super_res_neon.cc"
+            "${libgav1_source}/dsp/arm/super_res_neon.h"
+            "${libgav1_source}/dsp/arm/warp_neon.cc"
+            "${libgav1_source}/dsp/arm/warp_neon.h"
+            "${libgav1_source}/dsp/arm/weight_mask_neon.cc"
+            "${libgav1_source}/dsp/arm/weight_mask_neon.h")
+
+list(APPEND libgav1_dsp_sources_sse4
+            ${libgav1_dsp_sources_sse4}
+            "${libgav1_source}/dsp/x86/average_blend_sse4.cc"
+            "${libgav1_source}/dsp/x86/average_blend_sse4.h"
+            "${libgav1_source}/dsp/x86/common_sse4.h"
+            "${libgav1_source}/dsp/x86/cdef_sse4.cc"
+            "${libgav1_source}/dsp/x86/cdef_sse4.h"
+            "${libgav1_source}/dsp/x86/convolve_sse4.cc"
+            "${libgav1_source}/dsp/x86/convolve_sse4.h"
+            "${libgav1_source}/dsp/x86/distance_weighted_blend_sse4.cc"
+            "${libgav1_source}/dsp/x86/distance_weighted_blend_sse4.h"
+            "${libgav1_source}/dsp/x86/intra_edge_sse4.cc"
+            "${libgav1_source}/dsp/x86/intra_edge_sse4.h"
+            "${libgav1_source}/dsp/x86/intrapred_sse4.cc"
+            "${libgav1_source}/dsp/x86/intrapred_sse4.h"
+            "${libgav1_source}/dsp/x86/intrapred_cfl_sse4.cc"
+            "${libgav1_source}/dsp/x86/intrapred_smooth_sse4.cc"
+            "${libgav1_source}/dsp/x86/inverse_transform_sse4.cc"
+            "${libgav1_source}/dsp/x86/inverse_transform_sse4.h"
+            "${libgav1_source}/dsp/x86/loop_filter_sse4.cc"
+            "${libgav1_source}/dsp/x86/loop_filter_sse4.h"
+            "${libgav1_source}/dsp/x86/loop_restoration_sse4.cc"
+            "${libgav1_source}/dsp/x86/loop_restoration_sse4.h"
+            "${libgav1_source}/dsp/x86/mask_blend_sse4.cc"
+            "${libgav1_source}/dsp/x86/mask_blend_sse4.h"
+            "${libgav1_source}/dsp/x86/motion_field_projection_sse4.cc"
+            "${libgav1_source}/dsp/x86/motion_field_projection_sse4.h"
+            "${libgav1_source}/dsp/x86/motion_vector_search_sse4.cc"
+            "${libgav1_source}/dsp/x86/motion_vector_search_sse4.h"
+            "${libgav1_source}/dsp/x86/obmc_sse4.cc"
+            "${libgav1_source}/dsp/x86/obmc_sse4.h"
+            "${libgav1_source}/dsp/x86/super_res_sse4.cc"
+            "${libgav1_source}/dsp/x86/super_res_sse4.h"
+            "${libgav1_source}/dsp/x86/transpose_sse4.h"
+            "${libgav1_source}/dsp/x86/warp_sse4.cc"
+            "${libgav1_source}/dsp/x86/warp_sse4.h"
+            "${libgav1_source}/dsp/x86/weight_mask_sse4.cc"
+            "${libgav1_source}/dsp/x86/weight_mask_sse4.h"
+            )
+
+macro(libgav1_add_dsp_targets)
+  unset(dsp_sources)
+  list(APPEND dsp_sources ${libgav1_dsp_sources} ${libgav1_dsp_sources_neon}
+              ${libgav1_dsp_sources_sse4})
+
+  libgav1_add_library(NAME
+                      libgav1_dsp
+                      TYPE
+                      OBJECT
+                      SOURCES
+                      ${dsp_sources}
+                      DEFINES
+                      ${libgav1_defines}
+                      $<$<CONFIG:Debug>:LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS>
+                      INCLUDES
+                      ${libgav1_include_paths})
+endmacro()
diff --git a/libgav1/src/dsp/loop_filter.cc b/libgav1/src/dsp/loop_filter.cc
index 946952b..6cad97d 100644
--- a/libgav1/src/dsp/loop_filter.cc
+++ b/libgav1/src/dsp/loop_filter.cc
@@ -31,10 +31,10 @@
 struct LoopFilterFuncs_C {
   LoopFilterFuncs_C() = delete;
 
-  static const int kMaxPixel = (1 << bitdepth) - 1;
-  static const int kMinSignedPixel = -(1 << (bitdepth - 1));
-  static const int kMaxSignedPixel = (1 << (bitdepth - 1)) - 1;
-  static const int kFlatThresh = 1 << (bitdepth - 8);
+  static constexpr int kMaxPixel = (1 << bitdepth) - 1;
+  static constexpr int kMinSignedPixel = -(1 << (bitdepth - 1));
+  static constexpr int kMaxSignedPixel = (1 << (bitdepth - 1)) - 1;
+  static constexpr int kFlatThresh = 1 << (bitdepth - 8);
 
   static void Vertical4(void* dest, ptrdiff_t stride, int outer_thresh,
                         int inner_thresh, int hev_thresh);
diff --git a/libgav1/src/dsp/loop_restoration.cc b/libgav1/src/dsp/loop_restoration.cc
index 5e6f9a0..fce54f2 100644
--- a/libgav1/src/dsp/loop_restoration.cc
+++ b/libgav1/src/dsp/loop_restoration.cc
@@ -14,7 +14,7 @@
 
 #include "src/dsp/loop_restoration.h"
 
-#include <algorithm>  // std::max
+#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -26,15 +26,6 @@
 
 namespace libgav1 {
 namespace dsp {
-namespace {
-
-// Precision of a division table (mtable)
-constexpr int kSgrProjScaleBits = 20;
-constexpr int kSgrProjReciprocalBits = 12;
-// Core selfguided restoration precision bits.
-constexpr int kSgrProjSgrBits = 8;
-// Precision bits of generated values higher than source before projection.
-constexpr int kSgrProjRestoreBits = 4;
 
 // Section 7.17.3.
 // a2: range [1, 256].
@@ -44,75 +35,85 @@
 //   a2 = 1;
 // else
 //   a2 = ((z << kSgrProjSgrBits) + (z >> 1)) / (z + 1);
-constexpr int kXByXPlus1[256] = {
-    1,   128, 171, 192, 205, 213, 219, 224, 228, 230, 233, 235, 236, 238, 239,
-    240, 241, 242, 243, 243, 244, 244, 245, 245, 246, 246, 247, 247, 247, 247,
-    248, 248, 248, 248, 249, 249, 249, 249, 249, 250, 250, 250, 250, 250, 250,
-    250, 251, 251, 251, 251, 251, 251, 251, 251, 251, 251, 252, 252, 252, 252,
-    252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 253, 253,
-    253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253,
-    253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    256};
+// ma = 256 - a2;
+const uint8_t kSgrMaLookup[256] = {
+    255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16, 15, 14,
+    13,  13,  12, 12, 11, 11, 10, 10, 9,  9,  9,  9,  8,  8,  8,  8,  7,  7,
+    7,   7,   7,  6,  6,  6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  5,  5,  5,
+    5,   5,   4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
+    4,   3,   3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
+    3,   3,   3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  2,
+    2,   2,   2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,   2,   2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,   2,   2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,   2,   2,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,   1,   1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,   1,   1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,   1,   1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,   1,   1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,   1,   1,  0};
 
-constexpr int kOneByX[25] = {
-    4096, 2048, 1365, 1024, 819, 683, 585, 512, 455, 410, 372, 341, 315,
-    293,  273,  256,  241,  228, 216, 205, 195, 186, 178, 171, 164,
-};
+namespace {
 
 template <int bitdepth, typename Pixel>
-struct LoopRestorationFuncs_C {
-  LoopRestorationFuncs_C() = delete;
+inline void WienerHorizontal(const Pixel* source, const ptrdiff_t source_stride,
+                             const int width, const int height,
+                             const int16_t* const filter,
+                             const int number_zero_coefficients,
+                             int16_t** wiener_buffer) {
+  constexpr int kCenterTap = kWienerFilterTaps / 2;
+  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
+                                           ? kInterRoundBitsHorizontal12bpp
+                                           : kInterRoundBitsHorizontal;
+  constexpr int offset =
+      1 << (bitdepth + kWienerFilterBits - kRoundBitsHorizontal - 1);
+  constexpr int limit = (offset << 2) - 1;
+  int y = height;
+  do {
+    int x = 0;
+    do {
+      // sum fits into 16 bits only when bitdepth = 8.
+      int sum = 0;
+      for (int k = number_zero_coefficients; k < kCenterTap; ++k) {
+        sum +=
+            filter[k] * (source[x + k] + source[x + kWienerFilterTaps - 1 - k]);
+      }
+      sum += filter[kCenterTap] * source[x + kCenterTap];
+      const int rounded_sum = RightShiftWithRounding(sum, kRoundBitsHorizontal);
+      (*wiener_buffer)[x] = Clip3(rounded_sum, -offset, limit - offset);
+    } while (++x != width);
+    source += source_stride;
+    *wiener_buffer += width;
+  } while (--y != 0);
+}
 
-  // |stride| for SelfGuidedFilter and WienerFilter is given in bytes.
-  static void SelfGuidedFilter(const void* source, void* dest,
-                               const RestorationUnitInfo& restoration_info,
-                               ptrdiff_t source_stride, ptrdiff_t dest_stride,
-                               int width, int height,
-                               RestorationBuffer* buffer);
-  static void WienerFilter(const void* source, void* dest,
-                           const RestorationUnitInfo& restoration_info,
-                           ptrdiff_t source_stride, ptrdiff_t dest_stride,
-                           int width, int height, RestorationBuffer* buffer);
-  // |stride| for box filter processing is in Pixels.
-  static void BoxFilterPreProcess(const RestorationUnitInfo& restoration_info,
-                                  const Pixel* src, ptrdiff_t stride, int width,
-                                  int height, int pass,
-                                  RestorationBuffer* buffer);
-  static void BoxFilterProcess(const RestorationUnitInfo& restoration_info,
-                               const Pixel* src, ptrdiff_t stride, int width,
-                               int height, RestorationBuffer* buffer);
-};
-
-// Note: range of wiener filter coefficients.
-// Wiener filter coefficients are symmetric, and their sum is 1 (128).
-// The range of each coefficient:
-// filter[0] = filter[6], 4 bits, min = -5, max = 10.
-// filter[1] = filter[5], 5 bits, min = -23, max = 8.
-// filter[2] = filter[4], 6 bits, min = -17, max = 46.
-// filter[3] = 128 - (filter[0] + filter[1] + filter[2]) * 2.
-// The difference from libaom is that in libaom:
-// filter[3] = 0 - (filter[0] + filter[1] + filter[2]) * 2.
-// Thus in libaom's computation, an offset of 128 is needed for filter[3].
-inline void PopulateWienerCoefficients(
-    const RestorationUnitInfo& restoration_info, int direction,
-    int16_t* const filter) {
-  filter[3] = 128;
-  for (int i = 0; i < 3; ++i) {
-    const int16_t coeff = restoration_info.wiener_info.filter[direction][i];
-    filter[i] = coeff;
-    filter[6 - i] = coeff;
-    filter[3] -= MultiplyBy2(coeff);
-  }
+template <int bitdepth, typename Pixel>
+inline void WienerVertical(const int16_t* wiener_buffer, const int width,
+                           const int height, const int16_t* const filter,
+                           const int number_zero_coefficients, void* const dest,
+                           const ptrdiff_t dest_stride) {
+  constexpr int kCenterTap = kWienerFilterTaps / 2;
+  constexpr int kRoundBitsVertical =
+      (bitdepth == 12) ? kInterRoundBitsVertical12bpp : kInterRoundBitsVertical;
+  auto* dst = static_cast<Pixel*>(dest);
+  int y = height;
+  do {
+    int x = 0;
+    do {
+      // sum needs 32 bits.
+      int sum = 0;
+      for (int k = number_zero_coefficients; k < kCenterTap; ++k) {
+        sum += filter[k] *
+               (wiener_buffer[k * width + x] +
+                wiener_buffer[(kWienerFilterTaps - 1 - k) * width + x]);
+      }
+      sum += filter[kCenterTap] * wiener_buffer[kCenterTap * width + x];
+      const int rounded_sum = RightShiftWithRounding(sum, kRoundBitsVertical);
+      dst[x] = static_cast<Pixel>(Clip3(rounded_sum, 0, (1 << bitdepth) - 1));
+    } while (++x != width);
+    wiener_buffer += width;
+    dst += dest_stride;
+  } while (--y != 0);
 }
 
 // Note: bit range for wiener filter.
@@ -131,264 +132,621 @@
 // than 16 bit and smaller than 32 bits.
 // The accumulator of the vertical filter is larger than 16 bits and smaller
 // than 32 bits.
+// Note: range of wiener filter coefficients.
+// Wiener filter coefficients are symmetric, and their sum is 1 (128).
+// The range of each coefficient:
+// filter[0] = filter[6], 4 bits, min = -5, max = 10.
+// filter[1] = filter[5], 5 bits, min = -23, max = 8.
+// filter[2] = filter[4], 6 bits, min = -17, max = 46.
+// filter[3] = 128 - 2 * (filter[0] + filter[1] + filter[2]).
+// The difference from libaom is that in libaom:
+// filter[3] = 0 - 2 * (filter[0] + filter[1] + filter[2]).
+// Thus in libaom's computation, an offset of 128 is needed for filter[3].
 template <int bitdepth, typename Pixel>
-void LoopRestorationFuncs_C<bitdepth, Pixel>::WienerFilter(
-    const void* const source, void* const dest,
-    const RestorationUnitInfo& restoration_info, ptrdiff_t source_stride,
-    ptrdiff_t dest_stride, int width, int height,
-    RestorationBuffer* const buffer) {
-  constexpr int kRoundBitsHorizontal = (bitdepth == 12)
-                                           ? kInterRoundBitsHorizontal12bpp
-                                           : kInterRoundBitsHorizontal;
-  constexpr int kRoundBitsVertical =
-      (bitdepth == 12) ? kInterRoundBitsVertical12bpp : kInterRoundBitsVertical;
-  int16_t filter[kSubPixelTaps - 1];
-  const int limit =
-      (1 << (bitdepth + 1 + kWienerFilterBits - kRoundBitsHorizontal)) - 1;
-  const auto* src = static_cast<const Pixel*>(source);
-  auto* dst = static_cast<Pixel*>(dest);
-  source_stride /= sizeof(Pixel);
-  dest_stride /= sizeof(Pixel);
-  const ptrdiff_t buffer_stride = buffer->wiener_buffer_stride;
-  auto* wiener_buffer = buffer->wiener_buffer;
+void WienerFilter_C(const void* const source, void* const dest,
+                    const RestorationUnitInfo& restoration_info,
+                    ptrdiff_t source_stride, ptrdiff_t dest_stride, int width,
+                    int height, RestorationBuffer* const restoration_buffer) {
+  constexpr int kCenterTap = kWienerFilterTaps / 2;
+  const int16_t* const number_leading_zero_coefficients =
+      restoration_info.wiener_info.number_leading_zero_coefficients;
+  const int number_rows_to_skip = std::max(
+      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
+      1);
+  int16_t* const wiener_buffer_org = restoration_buffer->wiener_buffer;
+
   // horizontal filtering.
-  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal, filter);
-  const int center_tap = 3;
-  src -= center_tap * source_stride + center_tap;
-  const int horizontal_rounding = 1 << (bitdepth + kWienerFilterBits - 1);
-  for (int y = 0; y < height + kSubPixelTaps - 2; ++y) {
-    for (int x = 0; x < width; ++x) {
-      // sum fits into 16 bits only when bitdepth = 8.
-      int sum = horizontal_rounding;
-      for (int k = 0; k < kSubPixelTaps - 1; ++k) {
-        sum += filter[k] * src[x + k];
-      }
-      const int rounded_sum = RightShiftWithRounding(sum, kRoundBitsHorizontal);
-      wiener_buffer[x] = static_cast<uint16_t>(Clip3(rounded_sum, 0, limit));
-    }
-    src += source_stride;
-    wiener_buffer += buffer_stride;
+  const int height_horizontal =
+      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
+  const int16_t* const filter_horizontal =
+      restoration_info.wiener_info.filter[WienerInfo::kHorizontal];
+  const auto* src = static_cast<const Pixel*>(source);
+  src -= (kCenterTap - number_rows_to_skip) * source_stride + kCenterTap;
+  auto* wiener_buffer = wiener_buffer_org + number_rows_to_skip * width;
+
+  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
+    WienerHorizontal<bitdepth, Pixel>(src, source_stride, width,
+                                      height_horizontal, filter_horizontal, 0,
+                                      &wiener_buffer);
+  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
+    WienerHorizontal<bitdepth, Pixel>(src, source_stride, width,
+                                      height_horizontal, filter_horizontal, 1,
+                                      &wiener_buffer);
+  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
+    WienerHorizontal<bitdepth, Pixel>(src, source_stride, width,
+                                      height_horizontal, filter_horizontal, 2,
+                                      &wiener_buffer);
+  } else {
+    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
+    WienerHorizontal<bitdepth, Pixel>(src, source_stride, width,
+                                      height_horizontal, filter_horizontal, 3,
+                                      &wiener_buffer);
   }
-  wiener_buffer = buffer->wiener_buffer;
+
   // vertical filtering.
-  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical, filter);
-  const int vertical_rounding = -(1 << (bitdepth + kRoundBitsVertical - 1));
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
-      // sum needs 32 bits.
-      int sum = vertical_rounding;
-      for (int k = 0; k < kSubPixelTaps - 1; ++k) {
-        sum += filter[k] * wiener_buffer[k * buffer_stride + x];
-      }
-      const int rounded_sum = RightShiftWithRounding(sum, kRoundBitsVertical);
-      dst[x] = static_cast<Pixel>(Clip3(rounded_sum, 0, (1 << bitdepth) - 1));
+  const int16_t* const filter_vertical =
+      restoration_info.wiener_info.filter[WienerInfo::kVertical];
+  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
+    // Because the top row of |source| is a duplicate of the second row, and the
+    // bottom row of |source| is a duplicate of its above row, we can duplicate
+    // the top and bottom row of |wiener_buffer| accordingly.
+    memcpy(wiener_buffer, wiener_buffer - width,
+           sizeof(*wiener_buffer) * width);
+    memcpy(wiener_buffer_org, wiener_buffer_org + width,
+           sizeof(*wiener_buffer) * width);
+    WienerVertical<bitdepth, Pixel>(wiener_buffer_org, width, height,
+                                    filter_vertical, 0, dest, dest_stride);
+  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
+    WienerVertical<bitdepth, Pixel>(wiener_buffer_org, width, height,
+                                    filter_vertical, 1, dest, dest_stride);
+  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
+    WienerVertical<bitdepth, Pixel>(wiener_buffer_org, width, height,
+                                    filter_vertical, 2, dest, dest_stride);
+  } else {
+    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
+    WienerVertical<bitdepth, Pixel>(wiener_buffer_org, width, height,
+                                    filter_vertical, 3, dest, dest_stride);
+  }
+}
+
+//------------------------------------------------------------------------------
+// SGR
+
+template <typename Pixel, int size>
+LIBGAV1_ALWAYS_INLINE void BoxSum(const Pixel* src, const ptrdiff_t src_stride,
+                                  const int height, const int width,
+                                  uint16_t* sums, uint32_t* square_sums,
+                                  const ptrdiff_t sum_stride) {
+  int y = height;
+  do {
+    uint32_t sum = 0;
+    uint32_t square_sum = 0;
+    for (int dx = 0; dx < size; ++dx) {
+      const Pixel source = src[dx];
+      sum += source;
+      square_sum += source * source;
     }
-    dst += dest_stride;
-    wiener_buffer += buffer_stride;
+    sums[0] = sum;
+    square_sums[0] = square_sum;
+    int x = 1;
+    do {
+      const Pixel source0 = src[x - 1];
+      const Pixel source1 = src[x - 1 + size];
+      sum -= source0;
+      sum += source1;
+      square_sum -= source0 * source0;
+      square_sum += source1 * source1;
+      sums[x] = sum;
+      square_sums[x] = square_sum;
+    } while (++x != width);
+    src += src_stride;
+    sums += sum_stride;
+    square_sums += sum_stride;
+  } while (--y != 0);
+}
+
+template <typename Pixel>
+LIBGAV1_ALWAYS_INLINE void BoxSum(const Pixel* src, const ptrdiff_t src_stride,
+                                  const int height, const int width,
+                                  uint16_t* sum3, uint16_t* sum5,
+                                  uint32_t* square_sum3, uint32_t* square_sum5,
+                                  const ptrdiff_t sum_stride) {
+  int y = height;
+  do {
+    uint32_t sum = 0;
+    uint32_t square_sum = 0;
+    for (int dx = 0; dx < 4; ++dx) {
+      const Pixel source = src[dx];
+      sum += source;
+      square_sum += source * source;
+    }
+    int x = 0;
+    do {
+      const Pixel source0 = src[x];
+      const Pixel source1 = src[x + 4];
+      sum -= source0;
+      square_sum -= source0 * source0;
+      sum3[x] = sum;
+      square_sum3[x] = square_sum;
+      sum += source1;
+      square_sum += source1 * source1;
+      sum5[x] = sum + source0;
+      square_sum5[x] = square_sum + source0 * source0;
+    } while (++x != width);
+    src += src_stride;
+    sum3 += sum_stride;
+    sum5 += sum_stride;
+    square_sum3 += sum_stride;
+    square_sum5 += sum_stride;
+  } while (--y != 0);
+}
+
+template <int bitdepth, int n>
+inline void CalculateIntermediate(const uint32_t s, uint32_t a,
+                                  const uint32_t b, uint8_t* const ma_ptr,
+                                  uint32_t* const b_ptr) {
+  // a: before shift, max is 25 * (2^(bitdepth) - 1) * (2^(bitdepth) - 1).
+  // since max bitdepth = 12, max < 2^31.
+  // after shift, a < 2^16 * n < 2^22 regardless of bitdepth
+  a = RightShiftWithRounding(a, (bitdepth - 8) << 1);
+  // b: max is 25 * (2^(bitdepth) - 1). If bitdepth = 12, max < 2^19.
+  // d < 2^8 * n < 2^14 regardless of bitdepth
+  const uint32_t d = RightShiftWithRounding(b, bitdepth - 8);
+  // p: Each term in calculating p = a * n - b * b is < 2^16 * n^2 < 2^28,
+  // and p itself satisfies p < 2^14 * n^2 < 2^26.
+  // This bound on p is due to:
+  // https://en.wikipedia.org/wiki/Popoviciu's_inequality_on_variances
+  // Note: Sometimes, in high bitdepth, we can end up with a*n < b*b.
+  // This is an artifact of rounding, and can only happen if all pixels
+  // are (almost) identical, so in this case we saturate to p=0.
+  const uint32_t p = (a * n < d * d) ? 0 : a * n - d * d;
+  // p * s < (2^14 * n^2) * round(2^20 / (n^2 * scale)) < 2^34 / scale <
+  // 2^32 as long as scale >= 4. So p * s fits into a uint32_t, and z < 2^12
+  // (this holds even after accounting for the rounding in s)
+  const uint32_t z = RightShiftWithRounding(p * s, kSgrProjScaleBits);
+  // ma: range [0, 255].
+  const uint32_t ma = kSgrMaLookup[std::min(z, 255u)];
+  const uint32_t one_over_n = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
+  // ma < 2^8, b < 2^(bitdepth) * n,
+  // one_over_n = round(2^12 / n)
+  // => the product here is < 2^(20 + bitdepth) <= 2^32,
+  // and b is set to a value < 2^(8 + bitdepth).
+  // This holds even with the rounding in one_over_n and in the overall result,
+  // as long as ma is strictly less than 2^8.
+  const uint32_t b2 = ma * b * one_over_n;
+  *ma_ptr = ma;
+  *b_ptr = RightShiftWithRounding(b2, kSgrProjReciprocalBits);
+}
+
+template <typename T>
+inline uint32_t Sum343(const T* const src) {
+  return 3 * (src[0] + src[2]) + 4 * src[1];
+}
+
+template <typename T>
+inline uint32_t Sum444(const T* const src) {
+  return 4 * (src[0] + src[1] + src[2]);
+}
+
+template <typename T>
+inline uint32_t Sum565(const T* const src) {
+  return 5 * (src[0] + src[2]) + 6 * src[1];
+}
+
+template <int bitdepth>
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
+    const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
+    const int width, const uint32_t s, SgrBuffer* const sgr_buffer,
+    uint16_t* const ma565, uint32_t* const b565) {
+  int x = 0;
+  do {
+    uint32_t a = 0;
+    uint32_t b = 0;
+    for (int dy = 0; dy < 5; ++dy) {
+      a += square_sum5[dy][x];
+      b += sum5[dy][x];
+    }
+    CalculateIntermediate<bitdepth, 25>(s, a, b, sgr_buffer->ma + x,
+                                        sgr_buffer->b + x);
+  } while (++x != width + 2);
+  x = 0;
+  do {
+    ma565[x] = Sum565(sgr_buffer->ma + x);
+    b565[x] = Sum565(sgr_buffer->b + x);
+  } while (++x != width);
+}
+
+template <int bitdepth>
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
+    const uint16_t* const sum3[3], const uint32_t* const square_sum3[3],
+    const int width, const uint32_t s, const bool calculate444,
+    SgrBuffer* const sgr_buffer, uint16_t* const ma343, uint32_t* const b343,
+    uint16_t* const ma444, uint32_t* const b444) {
+  int x = 0;
+  do {
+    uint32_t a = 0;
+    uint32_t b = 0;
+    for (int dy = 0; dy < 3; ++dy) {
+      a += square_sum3[dy][x];
+      b += sum3[dy][x];
+    }
+    CalculateIntermediate<bitdepth, 9>(s, a, b, sgr_buffer->ma + x,
+                                       sgr_buffer->b + x);
+  } while (++x != width + 2);
+  x = 0;
+  do {
+    ma343[x] = Sum343(sgr_buffer->ma + x);
+    b343[x] = Sum343(sgr_buffer->b + x);
+  } while (++x != width);
+  if (calculate444) {
+    x = 0;
+    do {
+      ma444[x] = Sum444(sgr_buffer->ma + x);
+      b444[x] = Sum444(sgr_buffer->b + x);
+    } while (++x != width);
+  }
+}
+
+template <typename Pixel>
+inline int CalculateFilteredOutput(const Pixel src, const uint32_t ma,
+                                   const uint32_t b, const int shift) {
+  const int32_t v = b - ma * src;
+  return RightShiftWithRounding(v,
+                                kSgrProjSgrBits + shift - kSgrProjRestoreBits);
+}
+
+template <typename Pixel>
+inline void BoxFilterPass(const Pixel src0, const Pixel src1,
+                          const uint16_t* const ma565[2],
+                          const uint32_t* const b565[2], const ptrdiff_t x,
+                          int p[2]) {
+  p[0] = CalculateFilteredOutput<Pixel>(src0, ma565[0][x] + ma565[1][x],
+                                        b565[0][x] + b565[1][x], 5);
+  p[1] = CalculateFilteredOutput<Pixel>(src1, ma565[1][x], b565[1][x], 4);
+}
+
+template <typename Pixel>
+inline int BoxFilterPass2(const Pixel src, const uint16_t* const ma343[3],
+                          const uint16_t* const ma444,
+                          const uint32_t* const b343[3],
+                          const uint32_t* const b444, const ptrdiff_t x) {
+  const uint32_t ma = ma343[0][x] + ma444[x] + ma343[2][x];
+  const uint32_t b = b343[0][x] + b444[x] + b343[2][x];
+  return CalculateFilteredOutput<Pixel>(src, ma, b, 5);
+}
+
+template <int bitdepth, typename Pixel>
+inline Pixel SelfGuidedFinal(const int src, const int v) {
+  // if radius_pass_0 == 0 and radius_pass_1 == 0, the range of v is:
+  // bits(u) + bits(w0/w1/w2) + 2 = bitdepth + 13.
+  // Then, range of s is bitdepth + 2. This is a rough estimation, taking the
+  // maximum value of each element.
+  const int s = src + RightShiftWithRounding(
+                          v, kSgrProjRestoreBits + kSgrProjPrecisionBits);
+  return static_cast<Pixel>(Clip3(s, 0, (1 << bitdepth) - 1));
+}
+
+template <int bitdepth, typename Pixel>
+inline Pixel SelfGuidedDoubleMultiplier(const int src, const int filter0,
+                                        const int filter1, const int16_t w0,
+                                        const int16_t w2) {
+  const int v = w0 * filter0 + w2 * filter1;
+  return SelfGuidedFinal<bitdepth, Pixel>(src, v);
+}
+
+template <int bitdepth, typename Pixel>
+inline Pixel SelfGuidedSingleMultiplier(const int src, const int filter,
+                                        const int16_t w0) {
+  const int v = w0 * filter;
+  return SelfGuidedFinal<bitdepth, Pixel>(src, v);
+}
+
+template <typename T>
+void Circulate3PointersBy1(T* p[3]) {
+  T* const p0 = p[0];
+  p[0] = p[1];
+  p[1] = p[2];
+  p[2] = p0;
+}
+
+template <typename T>
+void Circulate4PointersBy2(T* p[4]) {
+  std::swap(p[0], p[2]);
+  std::swap(p[1], p[3]);
+}
+
+template <typename T>
+void Circulate5PointersBy2(T* p[5]) {
+  T* const p0 = p[0];
+  T* const p1 = p[1];
+  p[0] = p[2];
+  p[1] = p[3];
+  p[2] = p[4];
+  p[3] = p0;
+  p[4] = p1;
+}
+
+template <int bitdepth, typename Pixel>
+inline void BoxFilterProcess(const RestorationUnitInfo& restoration_info,
+                             const Pixel* src, const ptrdiff_t src_stride,
+                             const int width, const int height,
+                             SgrBuffer* const sgr_buffer, Pixel* dst,
+                             const ptrdiff_t dst_stride) {
+  const auto temp_stride = Align<ptrdiff_t>(width, 8);
+  const ptrdiff_t sum_stride = temp_stride + 8;
+  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+  const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
+  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
+  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
+  const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
+  uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
+  uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
+  sum3[0] = sgr_buffer->sum3;
+  square_sum3[0] = sgr_buffer->square_sum3;
+  ma343[0] = sgr_buffer->ma343;
+  b343[0] = sgr_buffer->b343;
+  for (int i = 1; i <= 3; ++i) {
+    sum3[i] = sum3[i - 1] + sum_stride;
+    square_sum3[i] = square_sum3[i - 1] + sum_stride;
+    ma343[i] = ma343[i - 1] + temp_stride;
+    b343[i] = b343[i - 1] + temp_stride;
+  }
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+  for (int i = 1; i <= 4; ++i) {
+    sum5[i] = sum5[i - 1] + sum_stride;
+    square_sum5[i] = square_sum5[i - 1] + sum_stride;
+  }
+  ma444[0] = sgr_buffer->ma444;
+  b444[0] = sgr_buffer->b444;
+  for (int i = 1; i <= 2; ++i) {
+    ma444[i] = ma444[i - 1] + temp_stride;
+    b444[i] = b444[i - 1] + temp_stride;
+  }
+  ma565[0] = sgr_buffer->ma565;
+  ma565[1] = ma565[0] + temp_stride;
+  b565[0] = sgr_buffer->b565;
+  b565[1] = b565[0] + temp_stride;
+  assert(scales[0] != 0);
+  assert(scales[1] != 0);
+  BoxSum<Pixel>(src - 2 * src_stride - 3, src_stride, 4, width + 2, sum3[0],
+                sum5[1], square_sum3[0], square_sum5[1], sum_stride);
+  memcpy(sum5[0], sum5[1], sizeof(**sum5) * sum_stride);
+  memcpy(square_sum5[0], square_sum5[1], sizeof(**square_sum5) * sum_stride);
+  BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, scales[0],
+                                 sgr_buffer, ma565[0], b565[0]);
+  BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, scales[1], false,
+                                 sgr_buffer, ma343[0], b343[0], nullptr,
+                                 nullptr);
+  BoxFilterPreProcess3<bitdepth>(sum3 + 1, square_sum3 + 1, width, scales[1],
+                                 true, sgr_buffer, ma343[1], b343[1], ma444[0],
+                                 b444[0]);
+  for (int y = height >> 1; y != 0; --y) {
+    Circulate4PointersBy2<uint16_t>(sum3);
+    Circulate4PointersBy2<uint32_t>(square_sum3);
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxSum<Pixel>(src + 2 * src_stride - 3, src_stride, 1, width + 2, sum3[2],
+                  sum5[3], square_sum3[2], square_sum5[3], sum_stride);
+    BoxSum<Pixel>(src + 3 * src_stride - 3, src_stride, 1, width + 2, sum3[3],
+                  sum5[4], square_sum3[3], square_sum5[4], sum_stride);
+    BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, scales[0],
+                                   sgr_buffer, ma565[1], b565[1]);
+    BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, scales[1], true,
+                                   sgr_buffer, ma343[2], b343[2], ma444[1],
+                                   b444[1]);
+    BoxFilterPreProcess3<bitdepth>(sum3 + 1, square_sum3 + 1, width, scales[1],
+                                   true, sgr_buffer, ma343[3], b343[3],
+                                   ma444[2], b444[2]);
+    int x = 0;
+    do {
+      int p[2][2];
+      BoxFilterPass<Pixel>(src[x], src[src_stride + x], ma565, b565, x, p[0]);
+      p[1][0] =
+          BoxFilterPass2<Pixel>(src[x], ma343, ma444[0], b343, b444[0], x);
+      p[1][1] = BoxFilterPass2<Pixel>(src[src_stride + x], ma343 + 1, ma444[1],
+                                      b343 + 1, b444[1], x);
+      dst[x] = SelfGuidedDoubleMultiplier<bitdepth, Pixel>(src[x], p[0][0],
+                                                           p[1][0], w0, w2);
+      dst[dst_stride + x] = SelfGuidedDoubleMultiplier<bitdepth, Pixel>(
+          src[src_stride + x], p[0][1], p[1][1], w0, w2);
+    } while (++x != width);
+    src += 2 * src_stride;
+    dst += 2 * dst_stride;
+    Circulate4PointersBy2<uint16_t>(ma343);
+    Circulate4PointersBy2<uint32_t>(b343);
+    std::swap(ma444[0], ma444[2]);
+    std::swap(b444[0], b444[2]);
+    std::swap(ma565[0], ma565[1]);
+    std::swap(b565[0], b565[1]);
+  }
+  if ((height & 1) != 0) {
+    Circulate4PointersBy2<uint16_t>(sum3);
+    Circulate4PointersBy2<uint32_t>(square_sum3);
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxSum<Pixel>(src + 2 * src_stride - 3, src_stride, 1, width + 2, sum3[2],
+                  sum5[3], square_sum3[2], square_sum5[3], sum_stride);
+    memcpy(sum5[4], sum5[3], sizeof(**sum5) * sum_stride);
+    memcpy(square_sum5[4], square_sum5[3], sizeof(**square_sum5) * sum_stride);
+    BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, scales[0],
+                                   sgr_buffer, ma565[1], b565[1]);
+    BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, scales[1], false,
+                                   sgr_buffer, ma343[2], b343[2], nullptr,
+                                   nullptr);
+    int x = 0;
+    do {
+      const int p0 = CalculateFilteredOutput<Pixel>(
+          src[x], ma565[0][x] + ma565[1][x], b565[0][x] + b565[1][x], 5);
+      const int p1 =
+          BoxFilterPass2<Pixel>(src[x], ma343, ma444[0], b343, b444[0], x);
+      dst[x] =
+          SelfGuidedDoubleMultiplier<bitdepth, Pixel>(src[x], p0, p1, w0, w2);
+    } while (++x != width);
   }
 }
 
 template <int bitdepth, typename Pixel>
-void LoopRestorationFuncs_C<bitdepth, Pixel>::BoxFilterPreProcess(
-    const RestorationUnitInfo& restoration_info, const Pixel* const src,
-    ptrdiff_t stride, int width, int height, int pass,
-    RestorationBuffer* const buffer) {
+inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
+                                  const Pixel* src, const ptrdiff_t src_stride,
+                                  const int width, const int height,
+                                  SgrBuffer* const sgr_buffer, Pixel* dst,
+                                  const ptrdiff_t dst_stride) {
+  const auto temp_stride = Align<ptrdiff_t>(width, 8);
+  const ptrdiff_t sum_stride = temp_stride + 8;
   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
-  const uint8_t radius = kSgrProjParams[sgr_proj_index][pass * 2];
-  assert(radius != 0);
-  const uint32_t n = (2 * radius + 1) * (2 * radius + 1);
-  // const uint8_t scale = kSgrProjParams[sgr_proj_index][pass * 2 + 1];
-  // n2_with_scale: max value < 2^16. min value is 4.
-  // const uint32_t n2_with_scale = n * n * scale;
-  // s: max value < 2^12.
-  // const uint32_t s =
-  // ((1 << kSgrProjScaleBits) + (n2_with_scale >> 1)) / n2_with_scale;
-  const uint32_t s = kSgrScaleParameter[sgr_proj_index][pass];
+  const uint32_t s = kSgrScaleParameter[sgr_proj_index][0];  // s < 2^12.
+  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
+  uint16_t *sum5[5], *ma565[2];
+  uint32_t *square_sum5[5], *b565[2];
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+  for (int i = 1; i <= 4; ++i) {
+    sum5[i] = sum5[i - 1] + sum_stride;
+    square_sum5[i] = square_sum5[i - 1] + sum_stride;
+  }
+  ma565[0] = sgr_buffer->ma565;
+  ma565[1] = ma565[0] + temp_stride;
+  b565[0] = sgr_buffer->b565;
+  b565[1] = b565[0] + temp_stride;
   assert(s != 0);
-  const ptrdiff_t array_stride = buffer->box_filter_process_intermediate_stride;
-  // The size of the intermediate result buffer is the size of the filter area
-  // plus horizontal (3) and vertical (3) padding. The processing start point
-  // is the filter area start point -1 row and -1 column. Therefore we need to
-  // set offset and use the intermediate_result as the start point for
-  // processing.
-  const ptrdiff_t intermediate_buffer_offset =
-      kRestorationBorder * array_stride + kRestorationBorder;
-  uint32_t* intermediate_result[2] = {
-      buffer->box_filter_process_intermediate[0] + intermediate_buffer_offset -
-          array_stride,
-      buffer->box_filter_process_intermediate[1] + intermediate_buffer_offset -
-          array_stride};
-  // Calculate intermediate results, including one-pixel border, for example,
-  // if unit size is 64x64, we calculate 66x66 pixels.
-  for (int y = -1; y <= height; ++y) {
-    if (pass == 0 && ((y & 1) == 0)) {
-      intermediate_result[0] += array_stride;
-      intermediate_result[1] += array_stride;
-      continue;
-    }
-    for (int x = -1; x <= width; ++x) {
-      uint32_t a = 0;
-      uint32_t b = 0;
-      for (int dy = -radius; dy <= radius; ++dy) {
-        for (int dx = -radius; dx <= radius; ++dx) {
-          const Pixel source = src[(y + dy) * stride + (x + dx)];
-          // TODO(chengchen): Use boxsum for fast calculation.
-          a += source * source;
-          b += source;
-        }
-      }
-      // a: before shift, max is 25 * (2^(bitdepth) - 1) * (2^(bitdepth) - 1).
-      // since max bitdepth = 12, max < 2^31.
-      // after shift, a < 2^16 * n < 2^22 regardless of bitdepth
-      a = RightShiftWithRounding(a, (bitdepth - 8) << 1);
-      // b: max is 25 * (2^(bitdepth) - 1). If bitdepth = 12, max < 2^19.
-      // d < 2^8 * n < 2^14 regardless of bitdepth
-      const uint32_t d = RightShiftWithRounding(b, bitdepth - 8);
-      // p: Each term in calculating p = a * n - b * b is < 2^16 * n^2 < 2^28,
-      // and p itself satisfies p < 2^14 * n^2 < 2^26.
-      // This bound on p is due to:
-      // https://en.wikipedia.org/wiki/Popoviciu's_inequality_on_variances
-      // Note: Sometimes, in high bitdepth, we can end up with a*n < b*b.
-      // This is an artifact of rounding, and can only happen if all pixels
-      // are (almost) identical, so in this case we saturate to p=0.
-      const uint32_t p = (a * n < d * d) ? 0 : a * n - d * d;
-      // p * s < (2^14 * n^2) * round(2^20 / (n^2 * scale)) < 2^34 / scale <
-      // 2^32 as long as scale >= 4. So p * s fits into a uint32_t, and z < 2^12
-      // (this holds even after accounting for the rounding in s)
-      const uint32_t z = RightShiftWithRounding(p * s, kSgrProjScaleBits);
-      // a2: range [1, 256].
-      uint32_t a2 = kXByXPlus1[std::min(z, 255u)];
-      const uint32_t one_over_n = kOneByX[n - 1];
-      // (kSgrProjSgrBits - a2) < 2^8, b < 2^(bitdepth) * n,
-      // one_over_n = round(2^12 / n)
-      // => the product here is < 2^(20 + bitdepth) <= 2^32,
-      // and b is set to a value < 2^(8 + bitdepth).
-      // This holds even with the rounding in one_over_n and in the overall
-      // result, as long as (kSgrProjSgrBits - a2) is strictly less than 2^8.
-      const uint32_t b2 = ((1 << kSgrProjSgrBits) - a2) * b * one_over_n;
-      intermediate_result[0][x] = a2;
-      intermediate_result[1][x] =
-          RightShiftWithRounding(b2, kSgrProjReciprocalBits);
-    }
-    intermediate_result[0] += array_stride;
-    intermediate_result[1] += array_stride;
+  BoxSum<Pixel, 5>(src - 2 * src_stride - 3, src_stride, 4, width + 2, sum5[1],
+                   square_sum5[1], sum_stride);
+  memcpy(sum5[0], sum5[1], sizeof(**sum5) * sum_stride);
+  memcpy(square_sum5[0], square_sum5[1], sizeof(**square_sum5) * sum_stride);
+  BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, s, sgr_buffer,
+                                 ma565[0], b565[0]);
+  for (int y = height >> 1; y != 0; --y) {
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxSum<Pixel, 5>(src + 2 * src_stride - 3, src_stride, 1, width + 2,
+                     sum5[3], square_sum5[3], sum_stride);
+    BoxSum<Pixel, 5>(src + 3 * src_stride - 3, src_stride, 1, width + 2,
+                     sum5[4], square_sum5[4], sum_stride);
+    BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, s, sgr_buffer,
+                                   ma565[1], b565[1]);
+    int x = 0;
+    do {
+      int p[2];
+      BoxFilterPass<Pixel>(src[x], src[src_stride + x], ma565, b565, x, p);
+      dst[x] = SelfGuidedSingleMultiplier<bitdepth, Pixel>(src[x], p[0], w0);
+      dst[dst_stride + x] = SelfGuidedSingleMultiplier<bitdepth, Pixel>(
+          src[src_stride + x], p[1], w0);
+    } while (++x != width);
+    src += 2 * src_stride;
+    dst += 2 * dst_stride;
+    std::swap(ma565[0], ma565[1]);
+    std::swap(b565[0], b565[1]);
+  }
+  if ((height & 1) != 0) {
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxSum<Pixel, 5>(src + 2 * src_stride - 3, src_stride, 1, width + 2,
+                     sum5[3], square_sum5[3], sum_stride);
+    memcpy(sum5[4], sum5[3], sizeof(**sum5) * sum_stride);
+    memcpy(square_sum5[4], square_sum5[3], sizeof(**square_sum5) * sum_stride);
+    BoxFilterPreProcess5<bitdepth>(sum5, square_sum5, width, s, sgr_buffer,
+                                   ma565[1], b565[1]);
+    int x = 0;
+    do {
+      const int p = CalculateFilteredOutput<Pixel>(
+          src[x], ma565[0][x] + ma565[1][x], b565[0][x] + b565[1][x], 5);
+      dst[x] = SelfGuidedSingleMultiplier<bitdepth, Pixel>(src[x], p, w0);
+    } while (++x != width);
   }
 }
 
 template <int bitdepth, typename Pixel>
-void LoopRestorationFuncs_C<bitdepth, Pixel>::BoxFilterProcess(
-    const RestorationUnitInfo& restoration_info, const Pixel* src,
-    ptrdiff_t stride, int width, int height, RestorationBuffer* const buffer) {
+inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
+                                  const Pixel* src, const ptrdiff_t src_stride,
+                                  const int width, const int height,
+                                  SgrBuffer* const sgr_buffer, Pixel* dst,
+                                  const ptrdiff_t dst_stride) {
+  assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
+  const auto temp_stride = Align<ptrdiff_t>(width, 8);
+  const ptrdiff_t sum_stride = temp_stride + 8;
+  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
+  const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
-  for (int pass = 0; pass < 2; ++pass) {
-    const uint8_t radius = kSgrProjParams[sgr_proj_index][pass * 2];
-    const Pixel* src_ptr = src;
-    if (radius == 0) continue;
-    LoopRestorationFuncs_C<bitdepth, Pixel>::BoxFilterPreProcess(
-        restoration_info, src_ptr, stride, width, height, pass, buffer);
-
-    int* filtered_output = buffer->box_filter_process_output[pass];
-    const ptrdiff_t filtered_output_stride =
-        buffer->box_filter_process_output_stride;
-    const ptrdiff_t intermediate_stride =
-        buffer->box_filter_process_intermediate_stride;
-    // Set intermediate buffer start point to the actual start point of
-    // filtering.
-    const ptrdiff_t intermediate_buffer_offset =
-        kRestorationBorder * intermediate_stride + kRestorationBorder;
-    for (int y = 0; y < height; ++y) {
-      const int shift = (pass == 0 && (y & 1) != 0) ? 4 : 5;
-      uint32_t* const array_start[2] = {
-          buffer->box_filter_process_intermediate[0] +
-              intermediate_buffer_offset + y * intermediate_stride,
-          buffer->box_filter_process_intermediate[1] +
-              intermediate_buffer_offset + y * intermediate_stride};
-      for (int x = 0; x < width; ++x) {
-        uint32_t a = 0;
-        uint32_t b = 0;
-        uint32_t* intermediate_result[2] = {
-            array_start[0] - intermediate_stride,
-            array_start[1] - intermediate_stride};
-        for (int dy = -1; dy <= 1; ++dy) {
-          for (int dx = -1; dx <= 1; ++dx) {
-            int weight;
-            if (pass == 0) {
-              if (((y + dy) & 1) != 0) {
-                weight = (dx == 0) ? 6 : 5;
-              } else {
-                continue;
-              }
-            } else {
-              weight = ((dx & dy) == 0) ? 4 : 3;
-            }
-            // intermediate_result[0]: range [1, 256].
-            // intermediate_result[1] < 2^20.
-            a += weight * intermediate_result[0][x + dx];
-            b += weight * intermediate_result[1][x + dx];
-          }
-          intermediate_result[0] += intermediate_stride;
-          intermediate_result[1] += intermediate_stride;
-        }
-        // v < 2^32. All intermediate calculations are positive.
-        const uint32_t v = a * src_ptr[x] + b;
-        filtered_output[x] = RightShiftWithRounding(
-            v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
-      }
-      src_ptr += stride;
-      filtered_output += filtered_output_stride;
-    }
+  const uint32_t s = kSgrScaleParameter[sgr_proj_index][1];  // s < 2^12.
+  uint16_t *sum3[3], *ma343[3], *ma444[2];
+  uint32_t *square_sum3[3], *b343[3], *b444[2];
+  sum3[0] = sgr_buffer->sum3;
+  square_sum3[0] = sgr_buffer->square_sum3;
+  ma343[0] = sgr_buffer->ma343;
+  b343[0] = sgr_buffer->b343;
+  for (int i = 1; i <= 2; ++i) {
+    sum3[i] = sum3[i - 1] + sum_stride;
+    square_sum3[i] = square_sum3[i - 1] + sum_stride;
+    ma343[i] = ma343[i - 1] + temp_stride;
+    b343[i] = b343[i - 1] + temp_stride;
   }
+  ma444[0] = sgr_buffer->ma444;
+  ma444[1] = ma444[0] + temp_stride;
+  b444[0] = sgr_buffer->b444;
+  b444[1] = b444[0] + temp_stride;
+  assert(s != 0);
+  BoxSum<Pixel, 3>(src - 2 * src_stride - 2, src_stride, 3, width + 2, sum3[0],
+                   square_sum3[0], sum_stride);
+  BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, s, false, sgr_buffer,
+                                 ma343[0], b343[0], nullptr, nullptr);
+  Circulate3PointersBy1<uint16_t>(sum3);
+  Circulate3PointersBy1<uint32_t>(square_sum3);
+  BoxSum<Pixel, 3>(src + src_stride - 2, src_stride, 1, width + 2, sum3[2],
+                   square_sum3[2], sum_stride);
+  BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, s, true, sgr_buffer,
+                                 ma343[1], b343[1], ma444[0], b444[0]);
+  int y = height;
+  do {
+    Circulate3PointersBy1<uint16_t>(sum3);
+    Circulate3PointersBy1<uint32_t>(square_sum3);
+    BoxSum<Pixel, 3>(src + 2 * src_stride - 2, src_stride, 1, width + 2,
+                     sum3[2], square_sum3[2], sum_stride);
+    BoxFilterPreProcess3<bitdepth>(sum3, square_sum3, width, s, true,
+                                   sgr_buffer, ma343[2], b343[2], ma444[1],
+                                   b444[1]);
+    int x = 0;
+    do {
+      const int p =
+          BoxFilterPass2<Pixel>(src[x], ma343, ma444[0], b343, b444[0], x);
+      dst[x] = SelfGuidedSingleMultiplier<bitdepth, Pixel>(src[x], p, w0);
+    } while (++x != width);
+    src += src_stride;
+    dst += dst_stride;
+    Circulate3PointersBy1<uint16_t>(ma343);
+    Circulate3PointersBy1<uint32_t>(b343);
+    std::swap(ma444[0], ma444[1]);
+    std::swap(b444[0], b444[1]);
+  } while (--y != 0);
 }
 
-// Assume box_filter_process_output[2] are allocated before calling
-// this function. Their sizes are width * height, stride equals width.
 template <int bitdepth, typename Pixel>
-void LoopRestorationFuncs_C<bitdepth, Pixel>::SelfGuidedFilter(
-    const void* const source, void* const dest,
-    const RestorationUnitInfo& restoration_info, ptrdiff_t source_stride,
-    ptrdiff_t dest_stride, int width, int height,
-    RestorationBuffer* const buffer) {
-  const int w0 = restoration_info.sgr_proj_info.multiplier[0];
-  const int w1 = restoration_info.sgr_proj_info.multiplier[1];
-  const int w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
+void SelfGuidedFilter_C(const void* const source, void* const dest,
+                        const RestorationUnitInfo& restoration_info,
+                        ptrdiff_t source_stride, ptrdiff_t dest_stride,
+                        int width, int height,
+                        RestorationBuffer* const restoration_buffer) {
   const int index = restoration_info.sgr_proj_info.index;
-  const uint8_t r0 = kSgrProjParams[index][0];
-  const uint8_t r1 = kSgrProjParams[index][2];
-  const ptrdiff_t array_stride = buffer->box_filter_process_output_stride;
-  int* box_filter_process_output[2] = {buffer->box_filter_process_output[0],
-                                       buffer->box_filter_process_output[1]};
+  const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
+  const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
   const auto* src = static_cast<const Pixel*>(source);
   auto* dst = static_cast<Pixel*>(dest);
-  source_stride /= sizeof(Pixel);
-  dest_stride /= sizeof(Pixel);
-  LoopRestorationFuncs_C<bitdepth, Pixel>::BoxFilterProcess(
-      restoration_info, src, source_stride, width, height, buffer);
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
-      const int u = src[x] << kSgrProjRestoreBits;
-      int v = w1 * u;
-      if (r0 != 0) {
-        v += w0 * box_filter_process_output[0][x];
-      } else {
-        v += w0 * u;
-      }
-      if (r1 != 0) {
-        v += w2 * box_filter_process_output[1][x];
-      } else {
-        v += w2 * u;
-      }
-      // if r0 == 0 and r1 == 0, the range of v is:
-      // bits(u) + bits(w0/w1/w2) + 2 = bitdepth + 13.
-      // Then, range of s is bitdepth + 2. This is a rough estimation, taking
-      // the maximum value of each element.
-      const int s = RightShiftWithRounding(
-          v, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-      dst[x] = static_cast<Pixel>(Clip3(s, 0, (1 << bitdepth) - 1));
-    }
-    src += source_stride;
-    dst += dest_stride;
-    box_filter_process_output[0] += array_stride;
-    box_filter_process_output[1] += array_stride;
+  SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
+  if (radius_pass_1 == 0) {
+    // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
+    // following assertion.
+    assert(radius_pass_0 != 0);
+    BoxFilterProcessPass1<bitdepth, Pixel>(restoration_info, src, source_stride,
+                                           width, height, sgr_buffer, dst,
+                                           dest_stride);
+  } else if (radius_pass_0 == 0) {
+    BoxFilterProcessPass2<bitdepth, Pixel>(restoration_info, src, source_stride,
+                                           width, height, sgr_buffer, dst,
+                                           dest_stride);
+  } else {
+    BoxFilterProcess<bitdepth, Pixel>(restoration_info, src, source_stride,
+                                      width, height, sgr_buffer, dst,
+                                      dest_stride);
   }
 }
 
@@ -396,17 +754,15 @@
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  dsp->loop_restorations[0] = LoopRestorationFuncs_C<8, uint8_t>::WienerFilter;
-  dsp->loop_restorations[1] =
-      LoopRestorationFuncs_C<8, uint8_t>::SelfGuidedFilter;
+  dsp->loop_restorations[0] = WienerFilter_C<8, uint8_t>;
+  dsp->loop_restorations[1] = SelfGuidedFilter_C<8, uint8_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp8bpp_WienerFilter
-  dsp->loop_restorations[0] = LoopRestorationFuncs_C<8, uint8_t>::WienerFilter;
+  dsp->loop_restorations[0] = WienerFilter_C<8, uint8_t>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_SelfGuidedFilter
-  dsp->loop_restorations[1] =
-      LoopRestorationFuncs_C<8, uint8_t>::SelfGuidedFilter;
+  dsp->loop_restorations[1] = SelfGuidedFilter_C<8, uint8_t>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -417,19 +773,15 @@
   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
   assert(dsp != nullptr);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  dsp->loop_restorations[0] =
-      LoopRestorationFuncs_C<10, uint16_t>::WienerFilter;
-  dsp->loop_restorations[1] =
-      LoopRestorationFuncs_C<10, uint16_t>::SelfGuidedFilter;
+  dsp->loop_restorations[0] = WienerFilter_C<10, uint16_t>;
+  dsp->loop_restorations[1] = SelfGuidedFilter_C<10, uint16_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp10bpp_WienerFilter
-  dsp->loop_restorations[0] =
-      LoopRestorationFuncs_C<10, uint16_t>::WienerFilter;
+  dsp->loop_restorations[0] = WienerFilter_C<10, uint16_t>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_SelfGuidedFilter
-  dsp->loop_restorations[1] =
-      LoopRestorationFuncs_C<10, uint16_t>::SelfGuidedFilter;
+  dsp->loop_restorations[1] = SelfGuidedFilter_C<10, uint16_t>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -442,9 +794,6 @@
 #if LIBGAV1_MAX_BITDEPTH >= 10
   Init10bpp();
 #endif
-  // Local functions that may be unused depending on the optimizations
-  // available.
-  static_cast<void>(PopulateWienerCoefficients);
 }
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/loop_restoration.h b/libgav1/src/dsp/loop_restoration.h
index 663639c..a902e9b 100644
--- a/libgav1/src/dsp/loop_restoration.h
+++ b/libgav1/src/dsp/loop_restoration.h
@@ -38,6 +38,18 @@
 namespace libgav1 {
 namespace dsp {
 
+enum {
+  // Precision of a division table (mtable)
+  kSgrProjScaleBits = 20,
+  kSgrProjReciprocalBits = 12,
+  // Core self-guided restoration precision bits.
+  kSgrProjSgrBits = 8,
+  // Precision bits of generated values higher than source before projection.
+  kSgrProjRestoreBits = 4
+};  // anonymous enum
+
+extern const uint8_t kSgrMaLookup[256];
+
 // Initializes Dsp::loop_restorations. This function is not thread-safe.
 void LoopRestorationInit_C();
 
diff --git a/libgav1/src/dsp/mask_blend.cc b/libgav1/src/dsp/mask_blend.cc
index b011a4b..101c410 100644
--- a/libgav1/src/dsp/mask_blend.cc
+++ b/libgav1/src/dsp/mask_blend.cc
@@ -25,62 +25,52 @@
 namespace dsp {
 namespace {
 
+template <int subsampling_x, int subsampling_y>
+uint8_t GetMaskValue(const uint8_t* mask, const uint8_t* mask_next_row, int x) {
+  if ((subsampling_x | subsampling_y) == 0) {
+    return mask[x];
+  }
+  if (subsampling_x == 1 && subsampling_y == 0) {
+    return static_cast<uint8_t>(RightShiftWithRounding(
+        mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1], 1));
+  }
+  assert(subsampling_x == 1 && subsampling_y == 1);
+  return static_cast<uint8_t>(RightShiftWithRounding(
+      mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1] +
+          mask_next_row[MultiplyBy2(x)] + mask_next_row[MultiplyBy2(x) + 1],
+      2));
+}
+
 template <int bitdepth, typename Pixel, bool is_inter_intra, int subsampling_x,
           int subsampling_y>
-void MaskBlend_C(const uint16_t* prediction_0,
-                 const ptrdiff_t prediction_stride_0,
-                 const uint16_t* prediction_1,
+void MaskBlend_C(const void* prediction_0, const void* prediction_1,
                  const ptrdiff_t prediction_stride_1, const uint8_t* mask,
                  const ptrdiff_t mask_stride, const int width, const int height,
                  void* dest, const ptrdiff_t dest_stride) {
+  static_assert(!(bitdepth == 8 && is_inter_intra), "");
   assert(mask != nullptr);
+  using PredType =
+      typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
+  const auto* pred_0 = static_cast<const PredType*>(prediction_0);
+  const auto* pred_1 = static_cast<const PredType*>(prediction_1);
   auto* dst = static_cast<Pixel*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(Pixel);
   constexpr int step_y = subsampling_y ? 2 : 1;
   const uint8_t* mask_next_row = mask + mask_stride;
-  // An offset to cancel offsets used in single predictor generation that
-  // make intermediate computations non negative.
-  const int single_round_offset = (1 << bitdepth) + (1 << (bitdepth - 1));
-  // An offset to cancel offsets used in compound predictor generation that
-  // make intermediate computations non negative.
-  const int compound_round_offset =
-      (1 << (bitdepth + 4)) + (1 << (bitdepth + 3));
   // 7.11.3.2 Rounding variables derivation process
   //   2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7))
   constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
   for (int y = 0; y < height; ++y) {
     for (int x = 0; x < width; ++x) {
-      uint8_t mask_value;
-      if ((subsampling_x | subsampling_y) == 0) {
-        mask_value = mask[x];
-      } else if (subsampling_x == 1 && subsampling_y == 0) {
-        mask_value = static_cast<uint8_t>(RightShiftWithRounding(
-            mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1], 1));
-      } else {
-        assert(subsampling_x == 1 && subsampling_y == 1);
-        mask_value = static_cast<uint8_t>(RightShiftWithRounding(
-            mask[MultiplyBy2(x)] + mask[MultiplyBy2(x) + 1] +
-                mask_next_row[MultiplyBy2(x)] +
-                mask_next_row[MultiplyBy2(x) + 1],
-            2));
-      }
-
+      const uint8_t mask_value =
+          GetMaskValue<subsampling_x, subsampling_y>(mask, mask_next_row, x);
       if (is_inter_intra) {
-        // In inter intra prediction mode, the intra prediction (prediction_1)
-        // values are valid pixel values: [0, (1 << bitdepth) - 1].
-        // While the inter prediction values come from subpixel prediction
-        // from another frame, which involves interpolation and rounding.
-        // Therefore prediction_0 has to be clipped.
         dst[x] = static_cast<Pixel>(RightShiftWithRounding(
-            mask_value * prediction_1[x] +
-                (64 - mask_value) * Clip3(prediction_0[x] - single_round_offset,
-                                          0, (1 << bitdepth) - 1),
-            6));
+            mask_value * pred_1[x] + (64 - mask_value) * pred_0[x], 6));
       } else {
-        int res = (mask_value * prediction_0[x] +
-                   (64 - mask_value) * prediction_1[x]) >>
-                  6;
-        res -= compound_round_offset;
+        assert(prediction_stride_1 == width);
+        int res = (mask_value * pred_0[x] + (64 - mask_value) * pred_1[x]) >> 6;
+        res -= (bitdepth == 8) ? 0 : kCompoundOffset;
         dst[x] = static_cast<Pixel>(
             Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
                   (1 << bitdepth) - 1));
@@ -89,7 +79,31 @@
     dst += dst_stride;
     mask += mask_stride * step_y;
     mask_next_row += mask_stride * step_y;
-    prediction_0 += prediction_stride_0;
+    pred_0 += width;
+    pred_1 += prediction_stride_1;
+  }
+}
+
+template <int subsampling_x, int subsampling_y>
+void InterIntraMaskBlend8bpp_C(const uint8_t* prediction_0,
+                               uint8_t* prediction_1,
+                               const ptrdiff_t prediction_stride_1,
+                               const uint8_t* mask, const ptrdiff_t mask_stride,
+                               const int width, const int height) {
+  assert(mask != nullptr);
+  constexpr int step_y = subsampling_y ? 2 : 1;
+  const uint8_t* mask_next_row = mask + mask_stride;
+  for (int y = 0; y < height; ++y) {
+    for (int x = 0; x < width; ++x) {
+      const uint8_t mask_value =
+          GetMaskValue<subsampling_x, subsampling_y>(mask, mask_next_row, x);
+      prediction_1[x] = static_cast<uint8_t>(RightShiftWithRounding(
+          mask_value * prediction_1[x] + (64 - mask_value) * prediction_0[x],
+          6));
+    }
+    mask += mask_stride * step_y;
+    mask_next_row += mask_stride * step_y;
+    prediction_0 += width;
     prediction_1 += prediction_stride_1;
   }
 }
@@ -101,9 +115,14 @@
   dsp->mask_blend[0][0] = MaskBlend_C<8, uint8_t, false, 0, 0>;
   dsp->mask_blend[1][0] = MaskBlend_C<8, uint8_t, false, 1, 0>;
   dsp->mask_blend[2][0] = MaskBlend_C<8, uint8_t, false, 1, 1>;
-  dsp->mask_blend[0][1] = MaskBlend_C<8, uint8_t, true, 0, 0>;
-  dsp->mask_blend[1][1] = MaskBlend_C<8, uint8_t, true, 1, 0>;
-  dsp->mask_blend[2][1] = MaskBlend_C<8, uint8_t, true, 1, 1>;
+  // The is_inter_intra index of mask_blend[][] is replaced by
+  // inter_intra_mask_blend_8bpp[] in 8-bit.
+  dsp->mask_blend[0][1] = nullptr;
+  dsp->mask_blend[1][1] = nullptr;
+  dsp->mask_blend[2][1] = nullptr;
+  dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_C<0, 0>;
+  dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_C<1, 0>;
+  dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_C<1, 1>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp8bpp_MaskBlend444
@@ -115,14 +134,19 @@
 #ifndef LIBGAV1_Dsp8bpp_MaskBlend420
   dsp->mask_blend[2][0] = MaskBlend_C<8, uint8_t, false, 1, 1>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_MaskBlendInterIntra444
-  dsp->mask_blend[0][1] = MaskBlend_C<8, uint8_t, true, 0, 0>;
+  // The is_inter_intra index of mask_blend[][] is replaced by
+  // inter_intra_mask_blend_8bpp[] in 8-bit.
+  dsp->mask_blend[0][1] = nullptr;
+  dsp->mask_blend[1][1] = nullptr;
+  dsp->mask_blend[2][1] = nullptr;
+#ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444
+  dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_C<0, 0>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_MaskBlendInterIntra422
-  dsp->mask_blend[1][1] = MaskBlend_C<8, uint8_t, true, 1, 0>;
+#ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422
+  dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_C<1, 0>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_MaskBlendInterIntra420
-  dsp->mask_blend[2][1] = MaskBlend_C<8, uint8_t, true, 1, 1>;
+#ifndef LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420
+  dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_C<1, 1>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -138,6 +162,10 @@
   dsp->mask_blend[0][1] = MaskBlend_C<10, uint16_t, true, 0, 0>;
   dsp->mask_blend[1][1] = MaskBlend_C<10, uint16_t, true, 1, 0>;
   dsp->mask_blend[2][1] = MaskBlend_C<10, uint16_t, true, 1, 1>;
+  // These are only used with 8-bit.
+  dsp->inter_intra_mask_blend_8bpp[0] = nullptr;
+  dsp->inter_intra_mask_blend_8bpp[1] = nullptr;
+  dsp->inter_intra_mask_blend_8bpp[2] = nullptr;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp10bpp_MaskBlend444
@@ -158,6 +186,10 @@
 #ifndef LIBGAV1_Dsp10bpp_MaskBlendInterIntra420
   dsp->mask_blend[2][1] = MaskBlend_C<10, uint16_t, true, 1, 1>;
 #endif
+  // These are only used with 8-bit.
+  dsp->inter_intra_mask_blend_8bpp[0] = nullptr;
+  dsp->inter_intra_mask_blend_8bpp[1] = nullptr;
+  dsp->inter_intra_mask_blend_8bpp[2] = nullptr;
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
 #endif
diff --git a/libgav1/src/dsp/mask_blend.h b/libgav1/src/dsp/mask_blend.h
index c0e77dd..41f5e5b 100644
--- a/libgav1/src/dsp/mask_blend.h
+++ b/libgav1/src/dsp/mask_blend.h
@@ -25,12 +25,22 @@
 // ARM:
 #include "src/dsp/arm/mask_blend_neon.h"
 
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+// SSE4_1
+#include "src/dsp/x86/mask_blend_sse4.h"
+// clang-format on
+
 // IWYU pragma: end_exports
 
 namespace libgav1 {
 namespace dsp {
 
-// Initializes Dsp::mask_blend. This function is not thread-safe.
+// Initializes Dsp::mask_blend and Dsp::inter_intra_mask_blend_8bpp. This
+// function is not thread-safe.
 void MaskBlendInit_C();
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/motion_field_projection.cc b/libgav1/src/dsp/motion_field_projection.cc
new file mode 100644
index 0000000..b51ec8f
--- /dev/null
+++ b/libgav1/src/dsp/motion_field_projection.cc
@@ -0,0 +1,138 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/motion_field_projection.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+#include "src/utils/reference_info.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+// Silence unused function warnings when MotionFieldProjectionKernel_C is
+// not used.
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||                      \
+    !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel) || \
+    (LIBGAV1_MAX_BITDEPTH >= 10 &&                           \
+     !defined(LIBGAV1_Dsp10bpp_MotionFieldProjectionKernel))
+
+// 7.9.2.
+void MotionFieldProjectionKernel_C(const ReferenceInfo& reference_info,
+                                   int reference_to_current_with_sign,
+                                   int dst_sign, int y8_start, int y8_end,
+                                   int x8_start, int x8_end,
+                                   TemporalMotionField* motion_field) {
+  const ptrdiff_t stride = motion_field->mv.columns();
+  // The column range has to be offset by kProjectionMvMaxHorizontalOffset since
+  // coordinates in that range could end up being position_x8 because of
+  // projection.
+  const int adjusted_x8_start =
+      std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0);
+  const int adjusted_x8_end = std::min(
+      x8_end + kProjectionMvMaxHorizontalOffset, static_cast<int>(stride));
+  const int8_t* const reference_offsets =
+      reference_info.relative_distance_to.data();
+  const bool* const skip_references = reference_info.skip_references.data();
+  const int16_t* const projection_divisions =
+      reference_info.projection_divisions.data();
+  const ReferenceFrameType* source_reference_types =
+      &reference_info.motion_field_reference_frame[y8_start][0];
+  const MotionVector* mv = &reference_info.motion_field_mv[y8_start][0];
+  int8_t* dst_reference_offset = motion_field->reference_offset[y8_start];
+  MotionVector* dst_mv = motion_field->mv[y8_start];
+  assert(stride == motion_field->reference_offset.columns());
+  assert((y8_start & 7) == 0);
+
+  int y8 = y8_start;
+  do {
+    const int y8_floor = (y8 & ~7) - y8;
+    const int y8_ceiling = std::min(y8_end - y8, y8_floor + 8);
+    int x8 = adjusted_x8_start;
+    do {
+      const int source_reference_type = source_reference_types[x8];
+      if (skip_references[source_reference_type]) continue;
+      MotionVector projection_mv;
+      // reference_to_current_with_sign could be 0.
+      GetMvProjection(mv[x8], reference_to_current_with_sign,
+                      projection_divisions[source_reference_type],
+                      &projection_mv);
+      // Do not update the motion vector if the block position is not valid or
+      // if position_x8 is outside the current range of x8_start and x8_end.
+      // Note that position_y8 will always be within the range of y8_start and
+      // y8_end.
+      const int position_y8 = Project(0, projection_mv.mv[0], dst_sign);
+      if (position_y8 < y8_floor || position_y8 >= y8_ceiling) continue;
+      const int x8_base = x8 & ~7;
+      const int x8_floor =
+          std::max(x8_start, x8_base - kProjectionMvMaxHorizontalOffset);
+      const int x8_ceiling =
+          std::min(x8_end, x8_base + 8 + kProjectionMvMaxHorizontalOffset);
+      const int position_x8 = Project(x8, projection_mv.mv[1], dst_sign);
+      if (position_x8 < x8_floor || position_x8 >= x8_ceiling) continue;
+      dst_mv[position_y8 * stride + position_x8] = mv[x8];
+      dst_reference_offset[position_y8 * stride + position_x8] =
+          reference_offsets[source_reference_type];
+    } while (++x8 < adjusted_x8_end);
+    source_reference_types += stride;
+    mv += stride;
+    dst_reference_offset += stride;
+    dst_mv += stride;
+  } while (++y8 < y8_end);
+}
+
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||
+        // !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel) ||
+        // (LIBGAV1_MAX_BITDEPTH >= 10 &&
+        //  !defined(LIBGAV1_Dsp10bpp_MotionFieldProjectionKernel))
+
+void Init8bpp() {
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
+    !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel)
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_C;
+#endif
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
+    !defined(LIBGAV1_Dsp10bpp_MotionFieldProjectionKernel)
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
+  assert(dsp != nullptr);
+  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_C;
+#endif
+}
+#endif
+
+}  // namespace
+
+void MotionFieldProjectionInit_C() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  Init10bpp();
+#endif
+}
+
+}  // namespace dsp
+}  // namespace libgav1
diff --git a/libgav1/src/dsp/motion_field_projection.h b/libgav1/src/dsp/motion_field_projection.h
new file mode 100644
index 0000000..36de459
--- /dev/null
+++ b/libgav1/src/dsp/motion_field_projection.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_MOTION_FIELD_PROJECTION_H_
+#define LIBGAV1_SRC_DSP_MOTION_FIELD_PROJECTION_H_
+
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/motion_field_projection_neon.h"
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+// SSE4_1
+#include "src/dsp/x86/motion_field_projection_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::motion_field_projection_kernel. This function is not
+// thread-safe.
+void MotionFieldProjectionInit_C();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_DSP_MOTION_FIELD_PROJECTION_H_
diff --git a/libgav1/src/dsp/motion_vector_search.cc b/libgav1/src/dsp/motion_vector_search.cc
new file mode 100644
index 0000000..9402302
--- /dev/null
+++ b/libgav1/src/dsp/motion_vector_search.cc
@@ -0,0 +1,211 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/motion_vector_search.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+// Silence unused function warnings when the C functions are not used.
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||             \
+    !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch) || \
+    (LIBGAV1_MAX_BITDEPTH >= 10 &&                  \
+     !defined(LIBGAV1_Dsp10bpp_MotionVectorSearch))
+
+void MvProjectionCompoundLowPrecision_C(
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets,
+    const int reference_offsets[2], const int count,
+    CompoundMotionVector* const candidate_mvs) {
+  // To facilitate the compilers, make a local copy of |reference_offsets|.
+  const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+  int index = 0;
+  do {
+    candidate_mvs[index].mv64 = 0;
+    for (int i = 0; i < 2; ++i) {
+      // |offsets| non-zero check usually equals true and could be ignored.
+      if (offsets[i] != 0) {
+        GetMvProjection(
+            temporal_mvs[index], offsets[i],
+            kProjectionMvDivisionLookup[temporal_reference_offsets[index]],
+            &candidate_mvs[index].mv[i]);
+        for (auto& mv : candidate_mvs[index].mv[i].mv) {
+          // The next line is equivalent to:
+          // if ((mv & 1) != 0) mv += (mv > 0) ? -1 : 1;
+          mv = (mv - (mv >> 15)) & ~1;
+        }
+      }
+    }
+  } while (++index < count);
+}
+
+void MvProjectionCompoundForceInteger_C(
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets,
+    const int reference_offsets[2], const int count,
+    CompoundMotionVector* const candidate_mvs) {
+  // To facilitate the compilers, make a local copy of |reference_offsets|.
+  const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+  int index = 0;
+  do {
+    candidate_mvs[index].mv64 = 0;
+    for (int i = 0; i < 2; ++i) {
+      // |offsets| non-zero check usually equals true and could be ignored.
+      if (offsets[i] != 0) {
+        GetMvProjection(
+            temporal_mvs[index], offsets[i],
+            kProjectionMvDivisionLookup[temporal_reference_offsets[index]],
+            &candidate_mvs[index].mv[i]);
+        for (auto& mv : candidate_mvs[index].mv[i].mv) {
+          // The next line is equivalent to:
+          // const int value = (std::abs(static_cast<int>(mv)) + 3) & ~7;
+          // const int sign = mv >> 15;
+          // mv = ApplySign(value, sign);
+          mv = (mv + 3 - (mv >> 15)) & ~7;
+        }
+      }
+    }
+  } while (++index < count);
+}
+
+void MvProjectionCompoundHighPrecision_C(
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets,
+    const int reference_offsets[2], const int count,
+    CompoundMotionVector* const candidate_mvs) {
+  // To facilitate the compilers, make a local copy of |reference_offsets|.
+  const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+  int index = 0;
+  do {
+    candidate_mvs[index].mv64 = 0;
+    for (int i = 0; i < 2; ++i) {
+      // |offsets| non-zero check usually equals true and could be ignored.
+      if (offsets[i] != 0) {
+        GetMvProjection(
+            temporal_mvs[index], offsets[i],
+            kProjectionMvDivisionLookup[temporal_reference_offsets[index]],
+            &candidate_mvs[index].mv[i]);
+      }
+    }
+  } while (++index < count);
+}
+
+void MvProjectionSingleLowPrecision_C(
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets, const int reference_offset,
+    const int count, MotionVector* const candidate_mvs) {
+  int index = 0;
+  do {
+    GetMvProjection(
+        temporal_mvs[index], reference_offset,
+        kProjectionMvDivisionLookup[temporal_reference_offsets[index]],
+        &candidate_mvs[index]);
+    for (auto& mv : candidate_mvs[index].mv) {
+      // The next line is equivalent to:
+      // if ((mv & 1) != 0) mv += (mv > 0) ? -1 : 1;
+      mv = (mv - (mv >> 15)) & ~1;
+    }
+  } while (++index < count);
+}
+
+void MvProjectionSingleForceInteger_C(
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets, const int reference_offset,
+    const int count, MotionVector* const candidate_mvs) {
+  int index = 0;
+  do {
+    GetMvProjection(
+        temporal_mvs[index], reference_offset,
+        kProjectionMvDivisionLookup[temporal_reference_offsets[index]],
+        &candidate_mvs[index]);
+    for (auto& mv : candidate_mvs[index].mv) {
+      // The next line is equivalent to:
+      // const int value = (std::abs(static_cast<int>(mv)) + 3) & ~7;
+      // const int sign = mv >> 15;
+      // mv = ApplySign(value, sign);
+      mv = (mv + 3 - (mv >> 15)) & ~7;
+    }
+  } while (++index < count);
+}
+
+void MvProjectionSingleHighPrecision_C(
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets, const int reference_offset,
+    const int count, MotionVector* const candidate_mvs) {
+  int index = 0;
+  do {
+    GetMvProjection(
+        temporal_mvs[index], reference_offset,
+        kProjectionMvDivisionLookup[temporal_reference_offsets[index]],
+        &candidate_mvs[index]);
+  } while (++index < count);
+}
+
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||
+        // !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch) ||
+        // (LIBGAV1_MAX_BITDEPTH >= 10 &&
+        //  !defined(LIBGAV1_Dsp10bpp_MotionVectorSearch))
+
+void Init8bpp() {
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
+    !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch)
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_C;
+  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_C;
+  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_C;
+  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_C;
+  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_C;
+  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_C;
+#endif
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
+    !defined(LIBGAV1_Dsp10bpp_MotionVectorSearch)
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
+  assert(dsp != nullptr);
+  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_C;
+  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_C;
+  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_C;
+  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_C;
+  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_C;
+  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_C;
+#endif
+}
+#endif
+
+}  // namespace
+
+void MotionVectorSearchInit_C() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  Init10bpp();
+#endif
+}
+
+}  // namespace dsp
+}  // namespace libgav1
diff --git a/libgav1/src/dsp/motion_vector_search.h b/libgav1/src/dsp/motion_vector_search.h
new file mode 100644
index 0000000..ae16726
--- /dev/null
+++ b/libgav1/src/dsp/motion_vector_search.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_MOTION_VECTOR_SEARCH_H_
+#define LIBGAV1_SRC_DSP_MOTION_VECTOR_SEARCH_H_
+
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/motion_vector_search_neon.h"
+
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+// SSE4_1
+#include "src/dsp/x86/motion_vector_search_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::mv_projection_compound and Dsp::mv_projection_single. This
+// function is not thread-safe.
+void MotionVectorSearchInit_C();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_DSP_MOTION_VECTOR_SEARCH_H_
diff --git a/libgav1/src/dsp/super_res.cc b/libgav1/src/dsp/super_res.cc
new file mode 100644
index 0000000..9379f46
--- /dev/null
+++ b/libgav1/src/dsp/super_res.cc
@@ -0,0 +1,98 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/super_res.h"
+
+#include <cassert>
+
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+template <int bitdepth, typename Pixel>
+void ComputeSuperRes(const void* source, const int upscaled_width,
+                     const int initial_subpixel_x, const int step,
+                     void* const dest) {
+  // If (original) upscaled_width is <= 9, the downscaled_width may be
+  // upscaled_width - 1 (i.e. 8, 9), and become the same (i.e. 4) when
+  // subsampled via RightShiftWithRounding. This leads to an edge case where
+  // |step| == 1 << 14.
+  assert(step <= kSuperResScaleMask || upscaled_width <= 4);
+  const auto* src = static_cast<const Pixel*>(source);
+  auto* dst = static_cast<Pixel*>(dest);
+  src -= DivideBy2(kSuperResFilterTaps);
+  int subpixel_x = initial_subpixel_x;
+  for (int x = 0; x < upscaled_width; ++x) {
+    int sum = 0;
+    const Pixel* const src_x = &src[subpixel_x >> kSuperResScaleBits];
+    const int src_x_subpixel =
+        (subpixel_x & kSuperResScaleMask) >> kSuperResExtraBits;
+    // The sign of each tap is: - + - + + - + -
+    sum -= src_x[0] * kUpscaleFilterUnsigned[src_x_subpixel][0];
+    sum += src_x[1] * kUpscaleFilterUnsigned[src_x_subpixel][1];
+    sum -= src_x[2] * kUpscaleFilterUnsigned[src_x_subpixel][2];
+    sum += src_x[3] * kUpscaleFilterUnsigned[src_x_subpixel][3];
+    sum += src_x[4] * kUpscaleFilterUnsigned[src_x_subpixel][4];
+    sum -= src_x[5] * kUpscaleFilterUnsigned[src_x_subpixel][5];
+    sum += src_x[6] * kUpscaleFilterUnsigned[src_x_subpixel][6];
+    sum -= src_x[7] * kUpscaleFilterUnsigned[src_x_subpixel][7];
+    dst[x] =
+        Clip3(RightShiftWithRounding(sum, kFilterBits), 0, (1 << bitdepth) - 1);
+    subpixel_x += step;
+  }
+}
+
+void Init8bpp() {
+  Dsp* dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  dsp->super_res_row = ComputeSuperRes<8, uint8_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp8bpp_SuperRes
+  dsp->super_res_row = ComputeSuperRes<8, uint8_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+  Dsp* dsp = dsp_internal::GetWritableDspTable(10);
+  assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  dsp->super_res_row = ComputeSuperRes<10, uint16_t>;
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp10bpp_SuperRes
+  dsp->super_res_row = ComputeSuperRes<10, uint16_t>;
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+}
+#endif
+
+}  // namespace
+
+void SuperResInit_C() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  Init10bpp();
+#endif
+}
+
+}  // namespace dsp
+}  // namespace libgav1
diff --git a/libgav1/src/dsp/super_res.h b/libgav1/src/dsp/super_res.h
new file mode 100644
index 0000000..cd69474
--- /dev/null
+++ b/libgav1/src/dsp/super_res.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_SUPER_RES_H_
+#define LIBGAV1_SRC_DSP_SUPER_RES_H_
+
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/super_res_neon.h"
+
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/super_res_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::super_res_row. This function is not thread-safe.
+void SuperResInit_C();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_DSP_SUPER_RES_H_
diff --git a/libgav1/src/dsp/warp.cc b/libgav1/src/dsp/warp.cc
index aae3be1..fbde65a 100644
--- a/libgav1/src/dsp/warp.cc
+++ b/libgav1/src/dsp/warp.cc
@@ -19,6 +19,7 @@
 #include <cstddef>
 #include <cstdint>
 #include <cstdlib>
+#include <type_traits>
 
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
@@ -33,26 +34,72 @@
 // Number of extra bits of precision in warped filtering.
 constexpr int kWarpedDiffPrecisionBits = 10;
 
-template <int bitdepth, typename Pixel>
+// Warp prediction output ranges from WarpTest.ShowRange.
+// Bitdepth:  8 Input range:            [       0,      255]
+//   8bpp intermediate offset: 16384.
+//   intermediate range:                [    4399,    61009]
+//   first pass output range:           [     550,     7626]
+//   8bpp intermediate offset removal: 262144.
+//   intermediate range:                [ -620566,  1072406]
+//   second pass output range:          [       0,      255]
+//   compound second pass output range: [   -4848,     8378]
+//
+// Bitdepth: 10 Input range:            [       0,     1023]
+//   intermediate range:                [  -48081,   179025]
+//   first pass output range:           [   -6010,    22378]
+//   intermediate range:                [-2103516,  4198620]
+//   second pass output range:          [       0,     1023]
+//   compound second pass output range: [    8142,    57378]
+//
+// Bitdepth: 12 Input range:            [       0,     4095]
+//   intermediate range:                [ -192465,   716625]
+//   first pass output range:           [   -6015,    22395]
+//   intermediate range:                [-2105190,  4201830]
+//   second pass output range:          [       0,     4095]
+//   compound second pass output range: [    8129,    57403]
+
+template <bool is_compound, int bitdepth, typename Pixel>
 void Warp_C(const void* const source, ptrdiff_t source_stride,
             const int source_width, const int source_height,
             const int* const warp_params, const int subsampling_x,
-            const int subsampling_y, const int inter_round_bits_vertical,
-            const int block_start_x, const int block_start_y,
-            const int block_width, const int block_height, const int16_t alpha,
-            const int16_t beta, const int16_t gamma, const int16_t delta,
-            uint16_t* dest, const ptrdiff_t dest_stride) {
+            const int subsampling_y, const int block_start_x,
+            const int block_start_y, const int block_width,
+            const int block_height, const int16_t alpha, const int16_t beta,
+            const int16_t gamma, const int16_t delta, void* dest,
+            ptrdiff_t dest_stride) {
+  assert(block_width >= 8 && block_height >= 8);
+  if (is_compound) {
+    assert(dest_stride == block_width);
+  }
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
                                            : kInterRoundBitsHorizontal;
-  // Intermediate_result is the output of the horizontal filtering and rounding.
-  // The range is within 16 bits (unsigned).
-  uint16_t intermediate_result[15][8];  // 15 rows, 8 columns.
-  const int horizontal_offset = 1 << (bitdepth + kFilterBits - 1);
-  const int vertical_offset =
-      1 << (bitdepth + 2 * kFilterBits - kRoundBitsHorizontal);
+  constexpr int kRoundBitsVertical =
+      is_compound        ? kInterRoundBitsCompoundVertical
+      : (bitdepth == 12) ? kInterRoundBitsVertical12bpp
+                         : kInterRoundBitsVertical;
+
+  // Only used for 8bpp. Allows for keeping the first pass intermediates within
+  // uint16_t. With 10/12bpp the intermediate value will always require int32_t.
+  constexpr int first_pass_offset = (bitdepth == 8) ? 1 << 14 : 0;
+  constexpr int offset_removal =
+      (first_pass_offset >> kRoundBitsHorizontal) * 128;
+
+  constexpr int kMaxPixel = (1 << bitdepth) - 1;
+  union {
+    // |intermediate_result| is the output of the horizontal filtering and
+    // rounding. The range is within int16_t.
+    int16_t intermediate_result[15][8];  // 15 rows, 8 columns.
+    // In the simple special cases where the samples in each row are all the
+    // same, store one sample per row in a column vector.
+    int16_t intermediate_result_column[15];
+  };
   const auto* const src = static_cast<const Pixel*>(source);
   source_stride /= sizeof(Pixel);
+  using DestType =
+      typename std::conditional<is_compound, uint16_t, Pixel>::type;
+  auto* dst = static_cast<DestType*>(dest);
+  if (!is_compound) dest_stride /= sizeof(dst[0]);
 
   assert(block_width >= 8);
   assert(block_height >= 8);
@@ -73,81 +120,253 @@
       const int ix4 = x4 >> kWarpedModelPrecisionBits;
       const int iy4 = y4 >> kWarpedModelPrecisionBits;
 
-      // Horizontal filter.
-      int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
-      for (int y = -7; y < 8; ++y) {
-        // TODO(chengchen):
-        // Because of warping, the index could be out of frame boundary. Thus
-        // clip is needed. However, can we remove or reduce usage of clip?
-        // Besides, special cases exist, for example,
-        // if iy4 - 7 >= source_height or iy4 + 7 < 0, there's no need to do the
-        // filtering.
-        const int row = Clip3(iy4 + y, 0, source_height - 1);
-        const Pixel* const src_row = src + row * source_stride;
-        // Check for two simple special cases.
-        if (ix4 - 7 >= source_width - 1) {
-          // Every sample is equal to src_row[source_width - 1]. Since the sum
-          // of the warped filter coefficients is 128 (= 2^7), the filtering is
-          // equivalent to multiplying src_row[source_width - 1] by 128.
-          const int s =
-              (horizontal_offset >> kInterRoundBitsHorizontal) +
-              (src_row[source_width - 1] << (7 - kInterRoundBitsHorizontal));
-          Memset(intermediate_result[y + 7], s, 8);
-          sx4 += beta;
-          continue;
-        }
-        if (ix4 + 7 <= 0) {
-          // Every sample is equal to src_row[0]. Since the sum of the warped
-          // filter coefficients is 128 (= 2^7), the filtering is equivalent to
-          // multiplying src_row[0] by 128.
-          const int s = (horizontal_offset >> kInterRoundBitsHorizontal) +
-                        (src_row[0] << (7 - kInterRoundBitsHorizontal));
-          Memset(intermediate_result[y + 7], s, 8);
-          sx4 += beta;
-          continue;
-        }
-        // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
-        // It follows that -6 <= ix4 <= source_width + 5. This inequality is
-        // used below.
-        int sx = sx4 - MultiplyBy4(alpha);
-        for (int x = -4; x < 4; ++x) {
-          const int offset =
-              RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
-              kWarpedPixelPrecisionShifts;
-          // Since alpha and beta have been validated by SetupShear(), one can
-          // prove that 0 <= offset <= 3 * 2^6.
-          assert(offset >= 0);
-          assert(offset < 3 * kWarpedPixelPrecisionShifts + 1);
-          // For SIMD optimization:
-          // For 8 bit, the range of sum is within uint16_t, if we add an
-          // horizontal offset:
-          int sum = horizontal_offset;
-          // Horizontal_offset guarantees sum is non negative.
-          // If horizontal_offset is used, intermediate_result needs to be
-          // uint16_t.
-          // For 10/12 bit, the range of sum is within 32 bits.
-          for (int k = 0; k < 8; ++k) {
-            // We assume the source frame has left and right borders of at
-            // least 13 pixels that extend the frame boundary pixels.
-            //
-            // Since -4 <= x <= 3 and 0 <= k <= 7, using the inequality on ix4
-            // above, we have -13 <= ix4 + x + k - 3 <= source_width + 12, or
-            // -13 <= column <= (source_width - 1) + 13. Therefore we may
-            // over-read up to 13 pixels before the source row, or up to 13
-            // pixels after the source row.
-            const int column = ix4 + x + k - 3;
-            sum += kWarpedFilters[offset][k] * src_row[column];
+      // A prediction block may fall outside the frame's boundaries. If a
+      // prediction block is calculated using only samples outside the frame's
+      // boundary, the filtering can be simplified. We can divide the plane
+      // into several regions and handle them differently.
+      //
+      //                |           |
+      //            1   |     3     |   1
+      //                |           |
+      //         -------+-----------+-------
+      //                |***********|
+      //            2   |*****4*****|   2
+      //                |***********|
+      //         -------+-----------+-------
+      //                |           |
+      //            1   |     3     |   1
+      //                |           |
+      //
+      // At the center, region 4 represents the frame and is the general case.
+      //
+      // In regions 1 and 2, the prediction block is outside the frame's
+      // boundary horizontally. Therefore the horizontal filtering can be
+      // simplified. Furthermore, in the region 1 (at the four corners), the
+      // prediction is outside the frame's boundary both horizontally and
+      // vertically, so we get a constant prediction block.
+      //
+      // In region 3, the prediction block is outside the frame's boundary
+      // vertically. Unfortunately because we apply the horizontal filters
+      // first, by the time we apply the vertical filters, they no longer see
+      // simple inputs. So the only simplification is that all the rows are
+      // the same, but we still need to apply all the horizontal and vertical
+      // filters.
+
+      // Check for two simple special cases, where the horizontal filter can
+      // be significantly simplified.
+      //
+      // In general, for each row, the horizontal filter is calculated as
+      // follows:
+      //   for (int x = -4; x < 4; ++x) {
+      //     const int offset = ...;
+      //     int sum = first_pass_offset;
+      //     for (int k = 0; k < 8; ++k) {
+      //       const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
+      //       sum += kWarpedFilters[offset][k] * src_row[column];
+      //     }
+      //     ...
+      //   }
+      // The column index before clipping, ix4 + x + k - 3, varies in the range
+      // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1
+      // or ix4 + 7 <= 0, then all the column indexes are clipped to the same
+      // border index (source_width - 1 or 0, respectively). Then for each x,
+      // the inner for loop of the horizontal filter is reduced to multiplying
+      // the border pixel by the sum of the filter coefficients.
+      if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) {
+        // Regions 1 and 2.
+        // Points to the left or right border of the first row of |src|.
+        const Pixel* first_row_border =
+            (ix4 + 7 <= 0) ? src : src + source_width - 1;
+        // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+        //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+        // In two special cases, iy4 + y is clipped to either 0 or
+        // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+        // bounded and we can avoid clipping iy4 + y by relying on a reference
+        // frame's boundary extension on the top and bottom.
+        if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+          // Region 1.
+          // Every sample used to calculate the prediction block has the same
+          // value. So the whole prediction block has the same value.
+          const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+          const Pixel row_border_pixel = first_row_border[row * source_stride];
+          DestType* dst_row = dst + start_x - block_start_x;
+          if (is_compound) {
+            int sum = row_border_pixel
+                      << ((14 - kRoundBitsHorizontal) - kRoundBitsVertical);
+            sum += (bitdepth == 8) ? 0 : kCompoundOffset;
+            Memset(dst_row, sum, 8);
+          } else {
+            Memset(dst_row, row_border_pixel, 8);
           }
-          assert(sum >= 0 && sum < (horizontal_offset << 2));
-          intermediate_result[y + 7][x + 4] = static_cast<uint16_t>(
-              RightShiftWithRounding(sum, kRoundBitsHorizontal));
-          sx += alpha;
+          const DestType* const first_dst_row = dst_row;
+          dst_row += dest_stride;
+          for (int y = 1; y < 8; ++y) {
+            memcpy(dst_row, first_dst_row, 8 * sizeof(*dst_row));
+            dst_row += dest_stride;
+          }
+          // End of region 1. Continue the |start_x| for loop.
+          continue;
         }
-        sx4 += beta;
+
+        // Region 2.
+        // Horizontal filter.
+        // The input values in this region are generated by extending the border
+        // which makes them identical in the horizontal direction. This
+        // computation could be inlined in the vertical pass but most
+        // implementations will need a transpose of some sort.
+        // It is not necessary to use the offset values here because the
+        // horizontal pass is a simple shift and the vertical pass will always
+        // require using 32 bits.
+        for (int y = -7; y < 8; ++y) {
+          // We may over-read up to 13 pixels above the top source row, or up
+          // to 13 pixels below the bottom source row. This is proved below.
+          const int row = iy4 + y;
+          int sum = first_row_border[row * source_stride];
+          sum <<= kFilterBits - kRoundBitsHorizontal;
+          intermediate_result_column[y + 7] = sum;
+        }
+        // Vertical filter.
+        DestType* dst_row = dst + start_x - block_start_x;
+        int sy4 =
+            (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+        for (int y = 0; y < 8; ++y) {
+          int sy = sy4 - MultiplyBy4(gamma);
+          for (int x = 0; x < 8; ++x) {
+            const int offset =
+                RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+                kWarpedPixelPrecisionShifts;
+            assert(offset >= 0);
+            assert(offset < 3 * kWarpedPixelPrecisionShifts + 1);
+            int sum = 0;
+            for (int k = 0; k < 8; ++k) {
+              sum +=
+                  kWarpedFilters[offset][k] * intermediate_result_column[y + k];
+            }
+            sum = RightShiftWithRounding(sum, kRoundBitsVertical);
+            if (is_compound) {
+              sum += (bitdepth == 8) ? 0 : kCompoundOffset;
+              dst_row[x] = static_cast<DestType>(sum);
+            } else {
+              dst_row[x] = static_cast<DestType>(Clip3(sum, 0, kMaxPixel));
+            }
+            sy += gamma;
+          }
+          dst_row += dest_stride;
+          sy4 += delta;
+        }
+        // End of region 2. Continue the |start_x| for loop.
+        continue;
       }
 
+      // Regions 3 and 4.
+      // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
+      // It follows that -6 <= ix4 <= source_width + 5. This inequality is
+      // used below.
+
+      // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+      //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+      // In two special cases, iy4 + y is clipped to either 0 or
+      // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+      // bounded and we can avoid clipping iy4 + y by relying on a reference
+      // frame's boundary extension on the top and bottom.
+      if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+        // Region 3.
+        // Horizontal filter.
+        const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+        const Pixel* const src_row = src + row * source_stride;
+        int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+        for (int y = -7; y < 8; ++y) {
+          int sx = sx4 - MultiplyBy4(alpha);
+          for (int x = -4; x < 4; ++x) {
+            const int offset =
+                RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
+                kWarpedPixelPrecisionShifts;
+            // Since alpha and beta have been validated by SetupShear(), one
+            // can prove that 0 <= offset <= 3 * 2^6.
+            assert(offset >= 0);
+            assert(offset < 3 * kWarpedPixelPrecisionShifts + 1);
+            // For SIMD optimization:
+            // |first_pass_offset| guarantees the sum fits in uint16_t for 8bpp.
+            // For 10/12 bit, the range of sum requires 32 bits.
+            int sum = first_pass_offset;
+            for (int k = 0; k < 8; ++k) {
+              // We assume the source frame has left and right borders of at
+              // least 13 pixels that extend the frame boundary pixels.
+              //
+              // Since -4 <= x <= 3 and 0 <= k <= 7, using the inequality on
+              // ix4 above, we have
+              //   -13 <= ix4 + x + k - 3 <= source_width + 12,
+              // or
+              //   -13 <= column <= (source_width - 1) + 13.
+              // Therefore we may over-read up to 13 pixels before the source
+              // row, or up to 13 pixels after the source row.
+              const int column = ix4 + x + k - 3;
+              sum += kWarpedFilters[offset][k] * src_row[column];
+            }
+            intermediate_result[y + 7][x + 4] =
+                RightShiftWithRounding(sum, kRoundBitsHorizontal);
+            sx += alpha;
+          }
+          sx4 += beta;
+        }
+      } else {
+        // Region 4.
+        // Horizontal filter.
+        // At this point, we know iy4 - 7 < source_height - 1 and iy4 + 7 > 0.
+        // It follows that -6 <= iy4 <= source_height + 5. This inequality is
+        // used below.
+        int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+        for (int y = -7; y < 8; ++y) {
+          // We assume the source frame has top and bottom borders of at least
+          // 13 pixels that extend the frame boundary pixels.
+          //
+          // Since -7 <= y <= 7, using the inequality on iy4 above, we have
+          //   -13 <= iy4 + y <= source_height + 12,
+          // or
+          //   -13 <= row <= (source_height - 1) + 13.
+          // Therefore we may over-read up to 13 pixels above the top source
+          // row, or up to 13 pixels below the bottom source row.
+          const int row = iy4 + y;
+          const Pixel* const src_row = src + row * source_stride;
+          int sx = sx4 - MultiplyBy4(alpha);
+          for (int x = -4; x < 4; ++x) {
+            const int offset =
+                RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
+                kWarpedPixelPrecisionShifts;
+            // Since alpha and beta have been validated by SetupShear(), one
+            // can prove that 0 <= offset <= 3 * 2^6.
+            assert(offset >= 0);
+            assert(offset < 3 * kWarpedPixelPrecisionShifts + 1);
+            // For SIMD optimization:
+            // |first_pass_offset| guarantees the sum fits in uint16_t for 8bpp.
+            // For 10/12 bit, the range of sum requires 32 bits.
+            int sum = first_pass_offset;
+            for (int k = 0; k < 8; ++k) {
+              // We assume the source frame has left and right borders of at
+              // least 13 pixels that extend the frame boundary pixels.
+              //
+              // Since -4 <= x <= 3 and 0 <= k <= 7, using the inequality on
+              // ix4 above, we have
+              //   -13 <= ix4 + x + k - 3 <= source_width + 12,
+              // or
+              //   -13 <= column <= (source_width - 1) + 13.
+              // Therefore we may over-read up to 13 pixels before the source
+              // row, or up to 13 pixels after the source row.
+              const int column = ix4 + x + k - 3;
+              sum += kWarpedFilters[offset][k] * src_row[column];
+            }
+            intermediate_result[y + 7][x + 4] =
+                RightShiftWithRounding(sum, kRoundBitsHorizontal) -
+                offset_removal;
+            sx += alpha;
+          }
+          sx4 += beta;
+        }
+      }
+
+      // Regions 3 and 4.
       // Vertical filter.
-      uint16_t* dst_row = dest + start_x - block_start_x;
+      DestType* dst_row = dst + start_x - block_start_x;
       int sy4 =
           (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
       // The spec says we should use the following loop condition:
@@ -156,22 +375,28 @@
       // implies std::min(4, block_start_y + block_height - start_y - 4) = 4.
       // So the loop condition is simply y < 4.
       //
-      // Proof:
-      //    start_y < block_start_y + block_height
-      // => block_start_y + block_height - start_y > 0
-      // => block_height - (start_y - block_start_y) > 0
+      //   Proof:
+      //      start_y < block_start_y + block_height
+      //   => block_start_y + block_height - start_y > 0
+      //   => block_height - (start_y - block_start_y) > 0
       //
-      // Since block_height >= 8 and is a power of 2, it follows that
-      // block_height is a multiple of 8. start_y - block_start_y is also a
-      // multiple of 8. Therefore their difference is a multiple of 8. Since
-      // their difference is > 0, their difference must be >= 8.
-      for (int y = -4; y < 4; ++y) {
+      //   Since block_height >= 8 and is a power of 2, it follows that
+      //   block_height is a multiple of 8. start_y - block_start_y is also a
+      //   multiple of 8. Therefore their difference is a multiple of 8. Since
+      //   their difference is > 0, their difference must be >= 8.
+      //
+      // We then add an offset of 4 to y so that the loop starts with y = 0
+      // and continues if y < 8.
+      for (int y = 0; y < 8; ++y) {
         int sy = sy4 - MultiplyBy4(gamma);
         // The spec says we should use the following loop condition:
         //   x < std::min(4, block_start_x + block_width - start_x - 4);
         // Similar to the above, we can prove that the loop condition can be
         // simplified to x < 4.
-        for (int x = -4; x < 4; ++x) {
+        //
+        // We then add an offset of 4 to x so that the loop starts with x = 0
+        // and continues if x < 8.
+        for (int x = 0; x < 8; ++x) {
           const int offset =
               RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
               kWarpedPixelPrecisionShifts;
@@ -179,26 +404,25 @@
           // prove that 0 <= offset <= 3 * 2^6.
           assert(offset >= 0);
           assert(offset < 3 * kWarpedPixelPrecisionShifts + 1);
-          // Similar to horizontal_offset, vertical_offset guarantees sum
-          // before shifting is non negative:
-          int sum = vertical_offset;
+          int sum = 0;
           for (int k = 0; k < 8; ++k) {
-            sum += kWarpedFilters[offset][k] *
-                   intermediate_result[y + 4 + k][x + 4];
+            sum += kWarpedFilters[offset][k] * intermediate_result[y + k][x];
           }
-          assert(sum >= 0 && sum < (vertical_offset << 2));
-          sum = RightShiftWithRounding(sum, inter_round_bits_vertical);
-          // Warp output is a predictor, whose type is uint16_t.
-          // Do not clip it here. The clipping is applied at the stage of
-          // final pixel value output.
-          dst_row[x + 4] = static_cast<uint16_t>(sum);
+          sum -= offset_removal;
+          sum = RightShiftWithRounding(sum, kRoundBitsVertical);
+          if (is_compound) {
+            sum += (bitdepth == 8) ? 0 : kCompoundOffset;
+            dst_row[x] = static_cast<DestType>(sum);
+          } else {
+            dst_row[x] = static_cast<DestType>(Clip3(sum, 0, kMaxPixel));
+          }
           sy += gamma;
         }
         dst_row += dest_stride;
         sy4 += delta;
       }
     }
-    dest += 8 * dest_stride;
+    dst += 8 * dest_stride;
   }
 }
 
@@ -206,11 +430,15 @@
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  dsp->warp = Warp_C<8, uint8_t>;
+  dsp->warp = Warp_C</*is_compound=*/false, 8, uint8_t>;
+  dsp->warp_compound = Warp_C</*is_compound=*/true, 8, uint8_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp8bpp_Warp
-  dsp->warp = Warp_C<8, uint8_t>;
+  dsp->warp = Warp_C</*is_compound=*/false, 8, uint8_t>;
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WarpCompound
+  dsp->warp_compound = Warp_C</*is_compound=*/true, 8, uint8_t>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -220,11 +448,15 @@
   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
   assert(dsp != nullptr);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-  dsp->warp = Warp_C<10, uint16_t>;
+  dsp->warp = Warp_C</*is_compound=*/false, 10, uint16_t>;
+  dsp->warp_compound = Warp_C</*is_compound=*/true, 10, uint16_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp10bpp_Warp
-  dsp->warp = Warp_C<10, uint16_t>;
+  dsp->warp = Warp_C</*is_compound=*/false, 10, uint16_t>;
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WarpCompound
+  dsp->warp_compound = Warp_C</*is_compound=*/true, 10, uint16_t>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
diff --git a/libgav1/src/dsp/warp.h b/libgav1/src/dsp/warp.h
index 3e5b9e0..7367a9b 100644
--- a/libgav1/src/dsp/warp.h
+++ b/libgav1/src/dsp/warp.h
@@ -25,6 +25,14 @@
 // ARM:
 #include "src/dsp/arm/warp_neon.h"
 
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/warp_sse4.h"
+// clang-format on
+
 // IWYU pragma: end_exports
 
 namespace libgav1 {
diff --git a/libgav1/src/dsp/weight_mask.cc b/libgav1/src/dsp/weight_mask.cc
new file mode 100644
index 0000000..15d6bc6
--- /dev/null
+++ b/libgav1/src/dsp/weight_mask.cc
@@ -0,0 +1,227 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/weight_mask.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+template <int width, int height, int bitdepth, bool mask_is_inverse>
+void WeightMask_C(const void* prediction_0, const void* prediction_1,
+                  uint8_t* mask, ptrdiff_t mask_stride) {
+  using PredType =
+      typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
+  const auto* pred_0 = static_cast<const PredType*>(prediction_0);
+  const auto* pred_1 = static_cast<const PredType*>(prediction_1);
+  static_assert(width >= 8, "");
+  static_assert(height >= 8, "");
+  constexpr int rounding_bits = bitdepth - 8 + ((bitdepth == 12) ? 2 : 4);
+  for (int y = 0; y < height; ++y) {
+    for (int x = 0; x < width; ++x) {
+      const int difference = RightShiftWithRounding(
+          std::abs(pred_0[x] - pred_1[x]), rounding_bits);
+      const auto mask_value =
+          static_cast<uint8_t>(std::min(DivideBy16(difference) + 38, 64));
+      mask[x] = mask_is_inverse ? 64 - mask_value : mask_value;
+    }
+    pred_0 += width;
+    pred_1 += width;
+    mask += mask_stride;
+  }
+}
+
+#define INIT_WEIGHT_MASK(width, height, bitdepth, w_index, h_index) \
+  dsp->weight_mask[w_index][h_index][0] =                           \
+      WeightMask_C<width, height, bitdepth, 0>;                     \
+  dsp->weight_mask[w_index][h_index][1] =                           \
+      WeightMask_C<width, height, bitdepth, 1>
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  INIT_WEIGHT_MASK(8, 8, 8, 0, 0);
+  INIT_WEIGHT_MASK(8, 16, 8, 0, 1);
+  INIT_WEIGHT_MASK(8, 32, 8, 0, 2);
+  INIT_WEIGHT_MASK(16, 8, 8, 1, 0);
+  INIT_WEIGHT_MASK(16, 16, 8, 1, 1);
+  INIT_WEIGHT_MASK(16, 32, 8, 1, 2);
+  INIT_WEIGHT_MASK(16, 64, 8, 1, 3);
+  INIT_WEIGHT_MASK(32, 8, 8, 2, 0);
+  INIT_WEIGHT_MASK(32, 16, 8, 2, 1);
+  INIT_WEIGHT_MASK(32, 32, 8, 2, 2);
+  INIT_WEIGHT_MASK(32, 64, 8, 2, 3);
+  INIT_WEIGHT_MASK(64, 16, 8, 3, 1);
+  INIT_WEIGHT_MASK(64, 32, 8, 3, 2);
+  INIT_WEIGHT_MASK(64, 64, 8, 3, 3);
+  INIT_WEIGHT_MASK(64, 128, 8, 3, 4);
+  INIT_WEIGHT_MASK(128, 64, 8, 4, 3);
+  INIT_WEIGHT_MASK(128, 128, 8, 4, 4);
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_8x8
+  INIT_WEIGHT_MASK(8, 8, 8, 0, 0);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_8x16
+  INIT_WEIGHT_MASK(8, 16, 8, 0, 1);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_8x32
+  INIT_WEIGHT_MASK(8, 32, 8, 0, 2);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x8
+  INIT_WEIGHT_MASK(16, 8, 8, 1, 0);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x16
+  INIT_WEIGHT_MASK(16, 16, 8, 1, 1);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x32
+  INIT_WEIGHT_MASK(16, 32, 8, 1, 2);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_16x64
+  INIT_WEIGHT_MASK(16, 64, 8, 1, 3);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x8
+  INIT_WEIGHT_MASK(32, 8, 8, 2, 0);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x16
+  INIT_WEIGHT_MASK(32, 16, 8, 2, 1);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x32
+  INIT_WEIGHT_MASK(32, 32, 8, 2, 2);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_32x64
+  INIT_WEIGHT_MASK(32, 64, 8, 2, 3);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x16
+  INIT_WEIGHT_MASK(64, 16, 8, 3, 1);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x32
+  INIT_WEIGHT_MASK(64, 32, 8, 3, 2);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x64
+  INIT_WEIGHT_MASK(64, 64, 8, 3, 3);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_64x128
+  INIT_WEIGHT_MASK(64, 128, 8, 3, 4);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_128x64
+  INIT_WEIGHT_MASK(128, 64, 8, 4, 3);
+#endif
+#ifndef LIBGAV1_Dsp8bpp_WeightMask_128x128
+  INIT_WEIGHT_MASK(128, 128, 8, 4, 4);
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
+  assert(dsp != nullptr);
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  INIT_WEIGHT_MASK(8, 8, 10, 0, 0);
+  INIT_WEIGHT_MASK(8, 16, 10, 0, 1);
+  INIT_WEIGHT_MASK(8, 32, 10, 0, 2);
+  INIT_WEIGHT_MASK(16, 8, 10, 1, 0);
+  INIT_WEIGHT_MASK(16, 16, 10, 1, 1);
+  INIT_WEIGHT_MASK(16, 32, 10, 1, 2);
+  INIT_WEIGHT_MASK(16, 64, 10, 1, 3);
+  INIT_WEIGHT_MASK(32, 8, 10, 2, 0);
+  INIT_WEIGHT_MASK(32, 16, 10, 2, 1);
+  INIT_WEIGHT_MASK(32, 32, 10, 2, 2);
+  INIT_WEIGHT_MASK(32, 64, 10, 2, 3);
+  INIT_WEIGHT_MASK(64, 16, 10, 3, 1);
+  INIT_WEIGHT_MASK(64, 32, 10, 3, 2);
+  INIT_WEIGHT_MASK(64, 64, 10, 3, 3);
+  INIT_WEIGHT_MASK(64, 128, 10, 3, 4);
+  INIT_WEIGHT_MASK(128, 64, 10, 4, 3);
+  INIT_WEIGHT_MASK(128, 128, 10, 4, 4);
+#else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+  static_cast<void>(dsp);
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_8x8
+  INIT_WEIGHT_MASK(8, 8, 10, 0, 0);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_8x16
+  INIT_WEIGHT_MASK(8, 16, 10, 0, 1);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_8x32
+  INIT_WEIGHT_MASK(8, 32, 10, 0, 2);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x8
+  INIT_WEIGHT_MASK(16, 8, 10, 1, 0);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x16
+  INIT_WEIGHT_MASK(16, 16, 10, 1, 1);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x32
+  INIT_WEIGHT_MASK(16, 32, 10, 1, 2);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_16x64
+  INIT_WEIGHT_MASK(16, 64, 10, 1, 3);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x8
+  INIT_WEIGHT_MASK(32, 8, 10, 2, 0);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x16
+  INIT_WEIGHT_MASK(32, 16, 10, 2, 1);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x32
+  INIT_WEIGHT_MASK(32, 32, 10, 2, 2);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_32x64
+  INIT_WEIGHT_MASK(32, 64, 10, 2, 3);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x16
+  INIT_WEIGHT_MASK(64, 16, 10, 3, 1);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x32
+  INIT_WEIGHT_MASK(64, 32, 10, 3, 2);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x64
+  INIT_WEIGHT_MASK(64, 64, 10, 3, 3);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_64x128
+  INIT_WEIGHT_MASK(64, 128, 10, 3, 4);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_128x64
+  INIT_WEIGHT_MASK(128, 64, 10, 4, 3);
+#endif
+#ifndef LIBGAV1_Dsp10bpp_WeightMask_128x128
+  INIT_WEIGHT_MASK(128, 128, 10, 4, 4);
+#endif
+#endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
+}
+#endif
+
+}  // namespace
+
+void WeightMaskInit_C() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  Init10bpp();
+#endif
+}
+
+}  // namespace dsp
+}  // namespace libgav1
diff --git a/libgav1/src/dsp/weight_mask.h b/libgav1/src/dsp/weight_mask.h
new file mode 100644
index 0000000..43bef05
--- /dev/null
+++ b/libgav1/src/dsp/weight_mask.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_WEIGHT_MASK_H_
+#define LIBGAV1_SRC_DSP_WEIGHT_MASK_H_
+
+// Pull in LIBGAV1_DspXXX defines representing the implementation status
+// of each function. The resulting value of each can be used by each module to
+// determine whether an implementation is needed at compile time.
+// IWYU pragma: begin_exports
+
+// ARM:
+#include "src/dsp/arm/weight_mask_neon.h"
+
+// x86:
+// Note includes should be sorted in logical order avx2/avx/sse4, etc.
+// The order of includes is important as each tests for a superior version
+// before setting the base.
+// clang-format off
+#include "src/dsp/x86/weight_mask_sse4.h"
+// clang-format on
+
+// IWYU pragma: end_exports
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::weight_mask. This function is not thread-safe.
+void WeightMaskInit_C();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_DSP_WEIGHT_MASK_H_
diff --git a/libgav1/src/dsp/x86/average_blend_sse4.cc b/libgav1/src/dsp/x86/average_blend_sse4.cc
index 264ed02..6c37658 100644
--- a/libgav1/src/dsp/x86/average_blend_sse4.cc
+++ b/libgav1/src/dsp/x86/average_blend_sse4.cc
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "src/dsp/average_blend.h"
-#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -23,6 +23,8 @@
 #include <cstddef>
 #include <cstdint>
 
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/utils/common.h"
 
@@ -30,73 +32,65 @@
 namespace dsp {
 namespace {
 
-constexpr int kBitdepth8 = 8;
 constexpr int kInterPostRoundBit = 4;
-// An offset to cancel offsets used in compound predictor generation that
-// make intermediate computations non negative.
-const __m128i kCompoundRoundOffset =
-    _mm_set1_epi16((2 << (kBitdepth8 + 4)) + (2 << (kBitdepth8 + 3)));
 
-inline void AverageBlend4Row(const uint16_t* prediction_0,
-                             const uint16_t* prediction_1, uint8_t* dest) {
+inline void AverageBlend4Row(const int16_t* prediction_0,
+                             const int16_t* prediction_1, uint8_t* dest) {
   const __m128i pred_0 = LoadLo8(prediction_0);
   const __m128i pred_1 = LoadLo8(prediction_1);
   __m128i res = _mm_add_epi16(pred_0, pred_1);
-  res = _mm_sub_epi16(res, kCompoundRoundOffset);
   res = RightShiftWithRounding_S16(res, kInterPostRoundBit + 1);
   Store4(dest, _mm_packus_epi16(res, res));
 }
 
-inline void AverageBlend8Row(const uint16_t* prediction_0,
-                             const uint16_t* prediction_1, uint8_t* dest) {
-  const __m128i pred_0 = LoadUnaligned16(prediction_0);
-  const __m128i pred_1 = LoadUnaligned16(prediction_1);
+inline void AverageBlend8Row(const int16_t* prediction_0,
+                             const int16_t* prediction_1, uint8_t* dest) {
+  const __m128i pred_0 = LoadAligned16(prediction_0);
+  const __m128i pred_1 = LoadAligned16(prediction_1);
   __m128i res = _mm_add_epi16(pred_0, pred_1);
-  res = _mm_sub_epi16(res, kCompoundRoundOffset);
   res = RightShiftWithRounding_S16(res, kInterPostRoundBit + 1);
   StoreLo8(dest, _mm_packus_epi16(res, res));
 }
 
-inline void AverageBlendLargeRow(const uint16_t* prediction_0,
-                                 const uint16_t* prediction_1, const int width,
+inline void AverageBlendLargeRow(const int16_t* prediction_0,
+                                 const int16_t* prediction_1, const int width,
                                  uint8_t* dest) {
   int x = 0;
   do {
-    const __m128i pred_00 = LoadUnaligned16(&prediction_0[x]);
-    const __m128i pred_01 = LoadUnaligned16(&prediction_1[x]);
+    const __m128i pred_00 = LoadAligned16(&prediction_0[x]);
+    const __m128i pred_01 = LoadAligned16(&prediction_1[x]);
     __m128i res0 = _mm_add_epi16(pred_00, pred_01);
-    res0 = _mm_sub_epi16(res0, kCompoundRoundOffset);
     res0 = RightShiftWithRounding_S16(res0, kInterPostRoundBit + 1);
-    const __m128i pred_10 = LoadUnaligned16(&prediction_0[x + 8]);
-    const __m128i pred_11 = LoadUnaligned16(&prediction_1[x + 8]);
+    const __m128i pred_10 = LoadAligned16(&prediction_0[x + 8]);
+    const __m128i pred_11 = LoadAligned16(&prediction_1[x + 8]);
     __m128i res1 = _mm_add_epi16(pred_10, pred_11);
-    res1 = _mm_sub_epi16(res1, kCompoundRoundOffset);
     res1 = RightShiftWithRounding_S16(res1, kInterPostRoundBit + 1);
     StoreUnaligned16(dest + x, _mm_packus_epi16(res0, res1));
     x += 16;
   } while (x < width);
 }
 
-void AverageBlend_SSE4_1(const uint16_t* prediction_0,
-                         const ptrdiff_t prediction_stride_0,
-                         const uint16_t* prediction_1,
-                         const ptrdiff_t prediction_stride_1, const int width,
-                         const int height, void* const dest,
+void AverageBlend_SSE4_1(const void* prediction_0, const void* prediction_1,
+                         const int width, const int height, void* const dest,
                          const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y = height;
 
   if (width == 4) {
     do {
-      AverageBlend4Row(prediction_0, prediction_1, dst);
+      // TODO(b/150326556): |prediction_[01]| values are packed. It is possible
+      // to load 8 values at a time.
+      AverageBlend4Row(pred_0, pred_1, dst);
       dst += dest_stride;
-      prediction_0 += prediction_stride_0;
-      prediction_1 += prediction_stride_1;
+      pred_0 += width;
+      pred_1 += width;
 
-      AverageBlend4Row(prediction_0, prediction_1, dst);
+      AverageBlend4Row(pred_0, pred_1, dst);
       dst += dest_stride;
-      prediction_0 += prediction_stride_0;
-      prediction_1 += prediction_stride_1;
+      pred_0 += width;
+      pred_1 += width;
 
       y -= 2;
     } while (y != 0);
@@ -105,15 +99,15 @@
 
   if (width == 8) {
     do {
-      AverageBlend8Row(prediction_0, prediction_1, dst);
+      AverageBlend8Row(pred_0, pred_1, dst);
       dst += dest_stride;
-      prediction_0 += prediction_stride_0;
-      prediction_1 += prediction_stride_1;
+      pred_0 += width;
+      pred_1 += width;
 
-      AverageBlend8Row(prediction_0, prediction_1, dst);
+      AverageBlend8Row(pred_0, pred_1, dst);
       dst += dest_stride;
-      prediction_0 += prediction_stride_0;
-      prediction_1 += prediction_stride_1;
+      pred_0 += width;
+      pred_1 += width;
 
       y -= 2;
     } while (y != 0);
@@ -121,22 +115,22 @@
   }
 
   do {
-    AverageBlendLargeRow(prediction_0, prediction_1, width, dst);
+    AverageBlendLargeRow(pred_0, pred_1, width, dst);
     dst += dest_stride;
-    prediction_0 += prediction_stride_0;
-    prediction_1 += prediction_stride_1;
+    pred_0 += width;
+    pred_1 += width;
 
-    AverageBlendLargeRow(prediction_0, prediction_1, width, dst);
+    AverageBlendLargeRow(pred_0, pred_1, width, dst);
     dst += dest_stride;
-    prediction_0 += prediction_stride_0;
-    prediction_1 += prediction_stride_1;
+    pred_0 += width;
+    pred_1 += width;
 
     y -= 2;
   } while (y != 0);
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
 #if DSP_ENABLED_8BPP_SSE4_1(AverageBlend)
   dsp->average_blend = AverageBlend_SSE4_1;
@@ -150,7 +144,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_SSE4_1
+#else  // !LIBGAV1_ENABLE_SSE4_1
 
 namespace libgav1 {
 namespace dsp {
diff --git a/libgav1/src/dsp/x86/average_blend_sse4.h b/libgav1/src/dsp/x86/average_blend_sse4.h
index ba4fac6..e205c2b 100644
--- a/libgav1/src/dsp/x86/average_blend_sse4.h
+++ b/libgav1/src/dsp/x86/average_blend_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_AVERAGE_BLEND_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_AVERAGE_BLEND_SSE4_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -33,7 +33,7 @@
 // optimization being enabled, signal the sse4 implementation should be used.
 #if LIBGAV1_ENABLE_SSE4_1
 #ifndef LIBGAV1_Dsp8bpp_AverageBlend
-#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_AverageBlend LIBGAV1_CPU_SSE4_1
 #endif
 
 #endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/cdef_sse4.cc b/libgav1/src/dsp/x86/cdef_sse4.cc
new file mode 100644
index 0000000..4478bde
--- /dev/null
+++ b/libgav1/src/dsp/x86/cdef_sse4.cc
@@ -0,0 +1,728 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/cdef.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <emmintrin.h>
+#include <tmmintrin.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/x86/common_sse4.h"
+#include "src/dsp/x86/transpose_sse4.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+#include "src/dsp/cdef.inc"
+
+// Used when calculating odd |cost[x]| values.
+// Holds elements 1 3 5 7 7 7 7 7
+alignas(16) constexpr uint32_t kCdefDivisionTableOddPadded[] = {
+    420, 210, 140, 105, 105, 105, 105, 105};
+
+// ----------------------------------------------------------------------------
+// Refer to CdefDirection_C().
+//
+// int32_t partial[8][15] = {};
+// for (int i = 0; i < 8; ++i) {
+//   for (int j = 0; j < 8; ++j) {
+//     const int x = 1;
+//     partial[0][i + j] += x;
+//     partial[1][i + j / 2] += x;
+//     partial[2][i] += x;
+//     partial[3][3 + i - j / 2] += x;
+//     partial[4][7 + i - j] += x;
+//     partial[5][3 - i / 2 + j] += x;
+//     partial[6][j] += x;
+//     partial[7][i / 2 + j] += x;
+//   }
+// }
+//
+// Using the code above, generate the position count for partial[8][15].
+//
+// partial[0]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
+// partial[1]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+// partial[2]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
+// partial[3]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+// partial[4]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
+// partial[5]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+// partial[6]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
+// partial[7]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
+//
+// The SIMD code shifts the input horizontally, then adds vertically to get the
+// correct partial value for the given position.
+// ----------------------------------------------------------------------------
+
+// ----------------------------------------------------------------------------
+// partial[0][i + j] += x;
+//
+// 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00
+// 00 10 11 12 13 14 15 16  17 00 00 00 00 00 00
+// 00 00 20 21 22 23 24 25  26 27 00 00 00 00 00
+// 00 00 00 30 31 32 33 34  35 36 37 00 00 00 00
+// 00 00 00 00 40 41 42 43  44 45 46 47 00 00 00
+// 00 00 00 00 00 50 51 52  53 54 55 56 57 00 00
+// 00 00 00 00 00 00 60 61  62 63 64 65 66 67 00
+// 00 00 00 00 00 00 00 70  71 72 73 74 75 76 77
+//
+// partial[4] is the same except the source is reversed.
+LIBGAV1_ALWAYS_INLINE void AddPartial_D0_D4(__m128i* v_src_16,
+                                            __m128i* partial_lo,
+                                            __m128i* partial_hi) {
+  // 00 01 02 03 04 05 06 07
+  *partial_lo = v_src_16[0];
+  // 00 00 00 00 00 00 00 00
+  *partial_hi = _mm_setzero_si128();
+
+  // 00 10 11 12 13 14 15 16
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[1], 2));
+  // 17 00 00 00 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[1], 14));
+
+  // 00 00 20 21 22 23 24 25
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[2], 4));
+  // 26 27 00 00 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[2], 12));
+
+  // 00 00 00 30 31 32 33 34
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[3], 6));
+  // 35 36 37 00 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[3], 10));
+
+  // 00 00 00 00 40 41 42 43
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[4], 8));
+  // 44 45 46 47 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[4], 8));
+
+  // 00 00 00 00 00 50 51 52
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[5], 10));
+  // 53 54 55 56 57 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[5], 6));
+
+  // 00 00 00 00 00 00 60 61
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[6], 12));
+  // 62 63 64 65 66 67 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[6], 4));
+
+  // 00 00 00 00 00 00 00 70
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_src_16[7], 14));
+  // 71 72 73 74 75 76 77 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_src_16[7], 2));
+}
+
+// ----------------------------------------------------------------------------
+// partial[1][i + j / 2] += x;
+//
+// A0 = src[0] + src[1], A1 = src[2] + src[3], ...
+//
+// A0 A1 A2 A3 00 00 00 00  00 00 00 00 00 00 00
+// 00 B0 B1 B2 B3 00 00 00  00 00 00 00 00 00 00
+// 00 00 C0 C1 C2 C3 00 00  00 00 00 00 00 00 00
+// 00 00 00 D0 D1 D2 D3 00  00 00 00 00 00 00 00
+// 00 00 00 00 E0 E1 E2 E3  00 00 00 00 00 00 00
+// 00 00 00 00 00 F0 F1 F2  F3 00 00 00 00 00 00
+// 00 00 00 00 00 00 G0 G1  G2 G3 00 00 00 00 00
+// 00 00 00 00 00 00 00 H0  H1 H2 H3 00 00 00 00
+//
+// partial[3] is the same except the source is reversed.
+LIBGAV1_ALWAYS_INLINE void AddPartial_D1_D3(__m128i* v_src_16,
+                                            __m128i* partial_lo,
+                                            __m128i* partial_hi) {
+  __m128i v_d1_temp[8];
+  const __m128i v_zero = _mm_setzero_si128();
+
+  for (int i = 0; i < 8; ++i) {
+    v_d1_temp[i] = _mm_hadd_epi16(v_src_16[i], v_zero);
+  }
+
+  *partial_lo = *partial_hi = v_zero;
+  // A0 A1 A2 A3 00 00 00 00
+  *partial_lo = _mm_add_epi16(*partial_lo, v_d1_temp[0]);
+
+  // 00 B0 B1 B2 B3 00 00 00
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[1], 2));
+
+  // 00 00 C0 C1 C2 C3 00 00
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[2], 4));
+  // 00 00 00 D0 D1 D2 D3 00
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[3], 6));
+  // 00 00 00 00 E0 E1 E2 E3
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[4], 8));
+
+  // 00 00 00 00 00 F0 F1 F2
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[5], 10));
+  // F3 00 00 00 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_d1_temp[5], 6));
+
+  // 00 00 00 00 00 00 G0 G1
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[6], 12));
+  // G2 G3 00 00 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_d1_temp[6], 4));
+
+  // 00 00 00 00 00 00 00 H0
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_d1_temp[7], 14));
+  // H1 H2 H3 00 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_d1_temp[7], 2));
+}
+
+// ----------------------------------------------------------------------------
+// partial[7][i / 2 + j] += x;
+//
+// 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00
+// 10 11 12 13 14 15 16 17  00 00 00 00 00 00 00
+// 00 20 21 22 23 24 25 26  27 00 00 00 00 00 00
+// 00 30 31 32 33 34 35 36  37 00 00 00 00 00 00
+// 00 00 40 41 42 43 44 45  46 47 00 00 00 00 00
+// 00 00 50 51 52 53 54 55  56 57 00 00 00 00 00
+// 00 00 00 60 61 62 63 64  65 66 67 00 00 00 00
+// 00 00 00 70 71 72 73 74  75 76 77 00 00 00 00
+//
+// partial[5] is the same except the source is reversed.
+LIBGAV1_ALWAYS_INLINE void AddPartial_D5_D7(__m128i* v_src, __m128i* partial_lo,
+                                            __m128i* partial_hi) {
+  __m128i v_pair_add[4];
+  // Add vertical source pairs.
+  v_pair_add[0] = _mm_add_epi16(v_src[0], v_src[1]);
+  v_pair_add[1] = _mm_add_epi16(v_src[2], v_src[3]);
+  v_pair_add[2] = _mm_add_epi16(v_src[4], v_src[5]);
+  v_pair_add[3] = _mm_add_epi16(v_src[6], v_src[7]);
+
+  // 00 01 02 03 04 05 06 07
+  // 10 11 12 13 14 15 16 17
+  *partial_lo = v_pair_add[0];
+  // 00 00 00 00 00 00 00 00
+  // 00 00 00 00 00 00 00 00
+  *partial_hi = _mm_setzero_si128();
+
+  // 00 20 21 22 23 24 25 26
+  // 00 30 31 32 33 34 35 36
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_pair_add[1], 2));
+  // 27 00 00 00 00 00 00 00
+  // 37 00 00 00 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_pair_add[1], 14));
+
+  // 00 00 40 41 42 43 44 45
+  // 00 00 50 51 52 53 54 55
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_pair_add[2], 4));
+  // 46 47 00 00 00 00 00 00
+  // 56 57 00 00 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_pair_add[2], 12));
+
+  // 00 00 00 60 61 62 63 64
+  // 00 00 00 70 71 72 73 74
+  *partial_lo = _mm_add_epi16(*partial_lo, _mm_slli_si128(v_pair_add[3], 6));
+  // 65 66 67 00 00 00 00 00
+  // 75 76 77 00 00 00 00 00
+  *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_pair_add[3], 10));
+}
+
+LIBGAV1_ALWAYS_INLINE void AddPartial(const uint8_t* src, ptrdiff_t stride,
+                                      __m128i* partial_lo,
+                                      __m128i* partial_hi) {
+  // 8x8 input
+  // 00 01 02 03 04 05 06 07
+  // 10 11 12 13 14 15 16 17
+  // 20 21 22 23 24 25 26 27
+  // 30 31 32 33 34 35 36 37
+  // 40 41 42 43 44 45 46 47
+  // 50 51 52 53 54 55 56 57
+  // 60 61 62 63 64 65 66 67
+  // 70 71 72 73 74 75 76 77
+  __m128i v_src[8];
+  for (auto& i : v_src) {
+    i = LoadLo8(src);
+    src += stride;
+  }
+
+  const __m128i v_zero = _mm_setzero_si128();
+  // partial for direction 2
+  // --------------------------------------------------------------------------
+  // partial[2][i] += x;
+  // 00 10 20 30 40 50 60 70  00 00 00 00 00 00 00 00
+  // 01 11 21 33 41 51 61 71  00 00 00 00 00 00 00 00
+  // 02 12 22 33 42 52 62 72  00 00 00 00 00 00 00 00
+  // 03 13 23 33 43 53 63 73  00 00 00 00 00 00 00 00
+  // 04 14 24 34 44 54 64 74  00 00 00 00 00 00 00 00
+  // 05 15 25 35 45 55 65 75  00 00 00 00 00 00 00 00
+  // 06 16 26 36 46 56 66 76  00 00 00 00 00 00 00 00
+  // 07 17 27 37 47 57 67 77  00 00 00 00 00 00 00 00
+  const __m128i v_src_4_0 = _mm_unpacklo_epi64(v_src[0], v_src[4]);
+  const __m128i v_src_5_1 = _mm_unpacklo_epi64(v_src[1], v_src[5]);
+  const __m128i v_src_6_2 = _mm_unpacklo_epi64(v_src[2], v_src[6]);
+  const __m128i v_src_7_3 = _mm_unpacklo_epi64(v_src[3], v_src[7]);
+  const __m128i v_hsum_4_0 = _mm_sad_epu8(v_src_4_0, v_zero);
+  const __m128i v_hsum_5_1 = _mm_sad_epu8(v_src_5_1, v_zero);
+  const __m128i v_hsum_6_2 = _mm_sad_epu8(v_src_6_2, v_zero);
+  const __m128i v_hsum_7_3 = _mm_sad_epu8(v_src_7_3, v_zero);
+  const __m128i v_hsum_1_0 = _mm_unpacklo_epi16(v_hsum_4_0, v_hsum_5_1);
+  const __m128i v_hsum_3_2 = _mm_unpacklo_epi16(v_hsum_6_2, v_hsum_7_3);
+  const __m128i v_hsum_5_4 = _mm_unpackhi_epi16(v_hsum_4_0, v_hsum_5_1);
+  const __m128i v_hsum_7_6 = _mm_unpackhi_epi16(v_hsum_6_2, v_hsum_7_3);
+  partial_lo[2] =
+      _mm_unpacklo_epi64(_mm_unpacklo_epi32(v_hsum_1_0, v_hsum_3_2),
+                         _mm_unpacklo_epi32(v_hsum_5_4, v_hsum_7_6));
+
+  __m128i v_src_16[8];
+  for (int i = 0; i < 8; ++i) {
+    v_src_16[i] = _mm_cvtepu8_epi16(v_src[i]);
+  }
+
+  // partial for direction 6
+  // --------------------------------------------------------------------------
+  // partial[6][j] += x;
+  // 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00 00
+  // 10 11 12 13 14 15 16 17  00 00 00 00 00 00 00 00
+  // 20 21 22 23 24 25 26 27  00 00 00 00 00 00 00 00
+  // 30 31 32 33 34 35 36 37  00 00 00 00 00 00 00 00
+  // 40 41 42 43 44 45 46 47  00 00 00 00 00 00 00 00
+  // 50 51 52 53 54 55 56 57  00 00 00 00 00 00 00 00
+  // 60 61 62 63 64 65 66 67  00 00 00 00 00 00 00 00
+  // 70 71 72 73 74 75 76 77  00 00 00 00 00 00 00 00
+  partial_lo[6] = v_src_16[0];
+  for (int i = 1; i < 8; ++i) {
+    partial_lo[6] = _mm_add_epi16(partial_lo[6], v_src_16[i]);
+  }
+
+  // partial for direction 0
+  AddPartial_D0_D4(v_src_16, &partial_lo[0], &partial_hi[0]);
+
+  // partial for direction 1
+  AddPartial_D1_D3(v_src_16, &partial_lo[1], &partial_hi[1]);
+
+  // partial for direction 7
+  AddPartial_D5_D7(v_src_16, &partial_lo[7], &partial_hi[7]);
+
+  __m128i v_src_reverse[8];
+  const __m128i reverser =
+      _mm_set_epi32(0x01000302, 0x05040706, 0x09080b0a, 0x0d0c0f0e);
+  for (int i = 0; i < 8; ++i) {
+    v_src_reverse[i] = _mm_shuffle_epi8(v_src_16[i], reverser);
+  }
+
+  // partial for direction 4
+  AddPartial_D0_D4(v_src_reverse, &partial_lo[4], &partial_hi[4]);
+
+  // partial for direction 3
+  AddPartial_D1_D3(v_src_reverse, &partial_lo[3], &partial_hi[3]);
+
+  // partial for direction 5
+  AddPartial_D5_D7(v_src_reverse, &partial_lo[5], &partial_hi[5]);
+}
+
+inline uint32_t SumVector_S32(__m128i a) {
+  a = _mm_hadd_epi32(a, a);
+  a = _mm_add_epi32(a, _mm_srli_si128(a, 4));
+  return _mm_cvtsi128_si32(a);
+}
+
+// |cost[0]| and |cost[4]| square the input and sum with the corresponding
+// element from the other end of the vector:
+// |kCdefDivisionTable[]| element:
+// cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) *
+//             kCdefDivisionTable[i + 1];
+// cost[0] += Square(partial[0][7]) * kCdefDivisionTable[8];
+inline uint32_t Cost0Or4(const __m128i a, const __m128i b,
+                         const __m128i division_table[2]) {
+  // Reverse and clear upper 2 bytes.
+  const __m128i reverser =
+      _mm_set_epi32(0x80800100, 0x03020504, 0x07060908, 0x0b0a0d0c);
+  // 14 13 12 11 10 09 08 ZZ
+  const __m128i b_reversed = _mm_shuffle_epi8(b, reverser);
+  // 00 14 01 13 02 12 03 11
+  const __m128i ab_lo = _mm_unpacklo_epi16(a, b_reversed);
+  // 04 10 05 09 06 08 07 ZZ
+  const __m128i ab_hi = _mm_unpackhi_epi16(a, b_reversed);
+
+  // Square(partial[0][i]) + Square(partial[0][14 - i])
+  const __m128i square_lo = _mm_madd_epi16(ab_lo, ab_lo);
+  const __m128i square_hi = _mm_madd_epi16(ab_hi, ab_hi);
+
+  const __m128i c = _mm_mullo_epi32(square_lo, division_table[0]);
+  const __m128i d = _mm_mullo_epi32(square_hi, division_table[1]);
+  return SumVector_S32(_mm_add_epi32(c, d));
+}
+
+inline uint32_t CostOdd(const __m128i a, const __m128i b,
+                        const __m128i division_table[2]) {
+  // Reverse and clear upper 10 bytes.
+  const __m128i reverser =
+      _mm_set_epi32(0x80808080, 0x80808080, 0x80800100, 0x03020504);
+  // 10 09 08 ZZ ZZ ZZ ZZ ZZ
+  const __m128i b_reversed = _mm_shuffle_epi8(b, reverser);
+  // 00 10 01 09 02 08 03 ZZ
+  const __m128i ab_lo = _mm_unpacklo_epi16(a, b_reversed);
+  // 04 ZZ 05 ZZ 06 ZZ 07 ZZ
+  const __m128i ab_hi = _mm_unpackhi_epi16(a, b_reversed);
+
+  // Square(partial[0][i]) + Square(partial[0][10 - i])
+  const __m128i square_lo = _mm_madd_epi16(ab_lo, ab_lo);
+  const __m128i square_hi = _mm_madd_epi16(ab_hi, ab_hi);
+
+  const __m128i c = _mm_mullo_epi32(square_lo, division_table[0]);
+  const __m128i d = _mm_mullo_epi32(square_hi, division_table[1]);
+  return SumVector_S32(_mm_add_epi32(c, d));
+}
+
+// Sum of squared elements.
+inline uint32_t SquareSum_S16(const __m128i a) {
+  const __m128i square = _mm_madd_epi16(a, a);
+  return SumVector_S32(square);
+}
+
+void CdefDirection_SSE4_1(const void* const source, ptrdiff_t stride,
+                          int* const direction, int* const variance) {
+  assert(direction != nullptr);
+  assert(variance != nullptr);
+  const auto* src = static_cast<const uint8_t*>(source);
+  uint32_t cost[8];
+  __m128i partial_lo[8], partial_hi[8];
+
+  AddPartial(src, stride, partial_lo, partial_hi);
+
+  cost[2] = kCdefDivisionTable[7] * SquareSum_S16(partial_lo[2]);
+  cost[6] = kCdefDivisionTable[7] * SquareSum_S16(partial_lo[6]);
+
+  const __m128i division_table[2] = {LoadUnaligned16(kCdefDivisionTable),
+                                     LoadUnaligned16(kCdefDivisionTable + 4)};
+
+  cost[0] = Cost0Or4(partial_lo[0], partial_hi[0], division_table);
+  cost[4] = Cost0Or4(partial_lo[4], partial_hi[4], division_table);
+
+  const __m128i division_table_odd[2] = {
+      LoadUnaligned16(kCdefDivisionTableOddPadded),
+      LoadUnaligned16(kCdefDivisionTableOddPadded + 4)};
+
+  cost[1] = CostOdd(partial_lo[1], partial_hi[1], division_table_odd);
+  cost[3] = CostOdd(partial_lo[3], partial_hi[3], division_table_odd);
+  cost[5] = CostOdd(partial_lo[5], partial_hi[5], division_table_odd);
+  cost[7] = CostOdd(partial_lo[7], partial_hi[7], division_table_odd);
+
+  uint32_t best_cost = 0;
+  *direction = 0;
+  for (int i = 0; i < 8; ++i) {
+    if (cost[i] > best_cost) {
+      best_cost = cost[i];
+      *direction = i;
+    }
+  }
+  *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10;
+}
+
+// -------------------------------------------------------------------------
+// CdefFilter
+
+// Load 4 vectors based on the given |direction|.
+inline void LoadDirection(const uint16_t* const src, const ptrdiff_t stride,
+                          __m128i* output, const int direction) {
+  // Each |direction| describes a different set of source values. Expand this
+  // set by negating each set. For |direction| == 0 this gives a diagonal line
+  // from top right to bottom left. The first value is y, the second x. Negative
+  // y values move up.
+  //    a       b         c       d
+  // {-1, 1}, {1, -1}, {-2, 2}, {2, -2}
+  //         c
+  //       a
+  //     0
+  //   b
+  // d
+  const int y_0 = kCdefDirections[direction][0][0];
+  const int x_0 = kCdefDirections[direction][0][1];
+  const int y_1 = kCdefDirections[direction][1][0];
+  const int x_1 = kCdefDirections[direction][1][1];
+  output[0] = LoadUnaligned16(src - y_0 * stride - x_0);
+  output[1] = LoadUnaligned16(src + y_0 * stride + x_0);
+  output[2] = LoadUnaligned16(src - y_1 * stride - x_1);
+  output[3] = LoadUnaligned16(src + y_1 * stride + x_1);
+}
+
+// Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to
+// do 2 rows at a time.
+void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride,
+                    __m128i* output, const int direction) {
+  const int y_0 = kCdefDirections[direction][0][0];
+  const int x_0 = kCdefDirections[direction][0][1];
+  const int y_1 = kCdefDirections[direction][1][0];
+  const int x_1 = kCdefDirections[direction][1][1];
+  output[0] = LoadHi8(LoadLo8(src - y_0 * stride - x_0),
+                      src - y_0 * stride + stride - x_0);
+  output[1] = LoadHi8(LoadLo8(src + y_0 * stride + x_0),
+                      src + y_0 * stride + stride + x_0);
+  output[2] = LoadHi8(LoadLo8(src - y_1 * stride - x_1),
+                      src - y_1 * stride + stride - x_1);
+  output[3] = LoadHi8(LoadLo8(src + y_1 * stride + x_1),
+                      src + y_1 * stride + stride + x_1);
+}
+
+inline __m128i Constrain(const __m128i& pixel, const __m128i& reference,
+                         const __m128i& damping, const __m128i& threshold) {
+  const __m128i diff = _mm_sub_epi16(pixel, reference);
+  const __m128i abs_diff = _mm_abs_epi16(diff);
+  // sign(diff) * Clip3(threshold - (std::abs(diff) >> damping),
+  //                    0, std::abs(diff))
+  const __m128i shifted_diff = _mm_srl_epi16(abs_diff, damping);
+  // For bitdepth == 8, the threshold range is [0, 15] and the damping range is
+  // [3, 6]. If pixel == kCdefLargeValue(0x4000), shifted_diff will always be
+  // larger than threshold. Subtract using saturation will return 0 when pixel
+  // == kCdefLargeValue.
+  static_assert(kCdefLargeValue == 0x4000, "Invalid kCdefLargeValue");
+  const __m128i thresh_minus_shifted_diff =
+      _mm_subs_epu16(threshold, shifted_diff);
+  const __m128i clamp_abs_diff =
+      _mm_min_epi16(thresh_minus_shifted_diff, abs_diff);
+  // Restore the sign.
+  return _mm_sign_epi16(clamp_abs_diff, diff);
+}
+
+inline __m128i ApplyConstrainAndTap(const __m128i& pixel, const __m128i& val,
+                                    const __m128i& tap, const __m128i& damping,
+                                    const __m128i& threshold) {
+  const __m128i constrained = Constrain(val, pixel, damping, threshold);
+  return _mm_mullo_epi16(constrained, tap);
+}
+
+template <int width, bool enable_primary = true, bool enable_secondary = true>
+void CdefFilter_SSE4_1(const uint16_t* src, const ptrdiff_t src_stride,
+                       const int height, const int primary_strength,
+                       const int secondary_strength, const int damping,
+                       const int direction, void* dest,
+                       const ptrdiff_t dst_stride) {
+  static_assert(width == 8 || width == 4, "Invalid CDEF width.");
+  static_assert(enable_primary || enable_secondary, "");
+  constexpr bool clipping_required = enable_primary && enable_secondary;
+  auto* dst = static_cast<uint8_t*>(dest);
+  __m128i primary_damping_shift, secondary_damping_shift;
+
+  // FloorLog2() requires input to be > 0.
+  // 8-bit damping range: Y: [3, 6], UV: [2, 5].
+  if (enable_primary) {
+    // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary
+    // for UV filtering.
+    primary_damping_shift =
+        _mm_cvtsi32_si128(std::max(0, damping - FloorLog2(primary_strength)));
+  }
+  if (enable_secondary) {
+    // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is
+    // necessary.
+    assert(damping - FloorLog2(secondary_strength) >= 0);
+    secondary_damping_shift =
+        _mm_cvtsi32_si128(damping - FloorLog2(secondary_strength));
+  }
+
+  const __m128i primary_tap_0 =
+      _mm_set1_epi16(kCdefPrimaryTaps[primary_strength & 1][0]);
+  const __m128i primary_tap_1 =
+      _mm_set1_epi16(kCdefPrimaryTaps[primary_strength & 1][1]);
+  const __m128i secondary_tap_0 = _mm_set1_epi16(kCdefSecondaryTap0);
+  const __m128i secondary_tap_1 = _mm_set1_epi16(kCdefSecondaryTap1);
+  const __m128i cdef_large_value_mask =
+      _mm_set1_epi16(static_cast<int16_t>(~kCdefLargeValue));
+  const __m128i primary_threshold = _mm_set1_epi16(primary_strength);
+  const __m128i secondary_threshold = _mm_set1_epi16(secondary_strength);
+
+  int y = height;
+  do {
+    __m128i pixel;
+    if (width == 8) {
+      pixel = LoadUnaligned16(src);
+    } else {
+      pixel = LoadHi8(LoadLo8(src), src + src_stride);
+    }
+
+    __m128i min = pixel;
+    __m128i max = pixel;
+    __m128i sum;
+
+    if (enable_primary) {
+      // Primary |direction|.
+      __m128i primary_val[4];
+      if (width == 8) {
+        LoadDirection(src, src_stride, primary_val, direction);
+      } else {
+        LoadDirection4(src, src_stride, primary_val, direction);
+      }
+
+      if (clipping_required) {
+        min = _mm_min_epu16(min, primary_val[0]);
+        min = _mm_min_epu16(min, primary_val[1]);
+        min = _mm_min_epu16(min, primary_val[2]);
+        min = _mm_min_epu16(min, primary_val[3]);
+
+        // The source is 16 bits, however, we only really care about the lower
+        // 8 bits.  The upper 8 bits contain the "large" flag.  After the final
+        // primary max has been calculated, zero out the upper 8 bits.  Use this
+        // to find the "16 bit" max.
+        const __m128i max_p01 = _mm_max_epu8(primary_val[0], primary_val[1]);
+        const __m128i max_p23 = _mm_max_epu8(primary_val[2], primary_val[3]);
+        const __m128i max_p = _mm_max_epu8(max_p01, max_p23);
+        max = _mm_max_epu16(max, _mm_and_si128(max_p, cdef_large_value_mask));
+      }
+
+      sum = ApplyConstrainAndTap(pixel, primary_val[0], primary_tap_0,
+                                 primary_damping_shift, primary_threshold);
+      sum = _mm_add_epi16(
+          sum, ApplyConstrainAndTap(pixel, primary_val[1], primary_tap_0,
+                                    primary_damping_shift, primary_threshold));
+      sum = _mm_add_epi16(
+          sum, ApplyConstrainAndTap(pixel, primary_val[2], primary_tap_1,
+                                    primary_damping_shift, primary_threshold));
+      sum = _mm_add_epi16(
+          sum, ApplyConstrainAndTap(pixel, primary_val[3], primary_tap_1,
+                                    primary_damping_shift, primary_threshold));
+    } else {
+      sum = _mm_setzero_si128();
+    }
+
+    if (enable_secondary) {
+      // Secondary |direction| values (+/- 2). Clamp |direction|.
+      __m128i secondary_val[8];
+      if (width == 8) {
+        LoadDirection(src, src_stride, secondary_val, direction + 2);
+        LoadDirection(src, src_stride, secondary_val + 4, direction - 2);
+      } else {
+        LoadDirection4(src, src_stride, secondary_val, direction + 2);
+        LoadDirection4(src, src_stride, secondary_val + 4, direction - 2);
+      }
+
+      if (clipping_required) {
+        min = _mm_min_epu16(min, secondary_val[0]);
+        min = _mm_min_epu16(min, secondary_val[1]);
+        min = _mm_min_epu16(min, secondary_val[2]);
+        min = _mm_min_epu16(min, secondary_val[3]);
+        min = _mm_min_epu16(min, secondary_val[4]);
+        min = _mm_min_epu16(min, secondary_val[5]);
+        min = _mm_min_epu16(min, secondary_val[6]);
+        min = _mm_min_epu16(min, secondary_val[7]);
+
+        const __m128i max_s01 =
+            _mm_max_epu8(secondary_val[0], secondary_val[1]);
+        const __m128i max_s23 =
+            _mm_max_epu8(secondary_val[2], secondary_val[3]);
+        const __m128i max_s45 =
+            _mm_max_epu8(secondary_val[4], secondary_val[5]);
+        const __m128i max_s67 =
+            _mm_max_epu8(secondary_val[6], secondary_val[7]);
+        const __m128i max_s = _mm_max_epu8(_mm_max_epu8(max_s01, max_s23),
+                                           _mm_max_epu8(max_s45, max_s67));
+        max = _mm_max_epu16(max, _mm_and_si128(max_s, cdef_large_value_mask));
+      }
+
+      sum = _mm_add_epi16(
+          sum,
+          ApplyConstrainAndTap(pixel, secondary_val[0], secondary_tap_0,
+                               secondary_damping_shift, secondary_threshold));
+      sum = _mm_add_epi16(
+          sum,
+          ApplyConstrainAndTap(pixel, secondary_val[1], secondary_tap_0,
+                               secondary_damping_shift, secondary_threshold));
+      sum = _mm_add_epi16(
+          sum,
+          ApplyConstrainAndTap(pixel, secondary_val[2], secondary_tap_1,
+                               secondary_damping_shift, secondary_threshold));
+      sum = _mm_add_epi16(
+          sum,
+          ApplyConstrainAndTap(pixel, secondary_val[3], secondary_tap_1,
+                               secondary_damping_shift, secondary_threshold));
+      sum = _mm_add_epi16(
+          sum,
+          ApplyConstrainAndTap(pixel, secondary_val[4], secondary_tap_0,
+                               secondary_damping_shift, secondary_threshold));
+      sum = _mm_add_epi16(
+          sum,
+          ApplyConstrainAndTap(pixel, secondary_val[5], secondary_tap_0,
+                               secondary_damping_shift, secondary_threshold));
+      sum = _mm_add_epi16(
+          sum,
+          ApplyConstrainAndTap(pixel, secondary_val[6], secondary_tap_1,
+                               secondary_damping_shift, secondary_threshold));
+      sum = _mm_add_epi16(
+          sum,
+          ApplyConstrainAndTap(pixel, secondary_val[7], secondary_tap_1,
+                               secondary_damping_shift, secondary_threshold));
+    }
+    // Clip3(pixel + ((8 + sum - (sum < 0)) >> 4), min, max))
+    const __m128i sum_lt_0 = _mm_srai_epi16(sum, 15);
+    // 8 + sum
+    sum = _mm_add_epi16(sum, _mm_set1_epi16(8));
+    // (... - (sum < 0)) >> 4
+    sum = _mm_add_epi16(sum, sum_lt_0);
+    sum = _mm_srai_epi16(sum, 4);
+    // pixel + ...
+    sum = _mm_add_epi16(sum, pixel);
+    if (clipping_required) {
+      // Clip3
+      sum = _mm_min_epi16(sum, max);
+      sum = _mm_max_epi16(sum, min);
+    }
+
+    const __m128i result = _mm_packus_epi16(sum, sum);
+    if (width == 8) {
+      src += src_stride;
+      StoreLo8(dst, result);
+      dst += dst_stride;
+      --y;
+    } else {
+      src += src_stride << 1;
+      Store4(dst, result);
+      dst += dst_stride;
+      Store4(dst, _mm_srli_si128(result, 4));
+      dst += dst_stride;
+      y -= 2;
+    }
+  } while (y != 0);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  assert(dsp != nullptr);
+  dsp->cdef_direction = CdefDirection_SSE4_1;
+  dsp->cdef_filters[0][0] = CdefFilter_SSE4_1<4>;
+  dsp->cdef_filters[0][1] =
+      CdefFilter_SSE4_1<4, /*enable_primary=*/true, /*enable_secondary=*/false>;
+  dsp->cdef_filters[0][2] = CdefFilter_SSE4_1<4, /*enable_primary=*/false>;
+  dsp->cdef_filters[1][0] = CdefFilter_SSE4_1<8>;
+  dsp->cdef_filters[1][1] =
+      CdefFilter_SSE4_1<8, /*enable_primary=*/true, /*enable_secondary=*/false>;
+  dsp->cdef_filters[1][2] = CdefFilter_SSE4_1<8, /*enable_primary=*/false>;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void CdefInit_SSE4_1() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+#else  // !LIBGAV1_ENABLE_SSE4_1
+namespace libgav1 {
+namespace dsp {
+
+void CdefInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/cdef_sse4.h b/libgav1/src/dsp/x86/cdef_sse4.h
new file mode 100644
index 0000000..2593c72
--- /dev/null
+++ b/libgav1/src/dsp/x86/cdef_sse4.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_X86_CDEF_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_CDEF_SSE4_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::cdef_direction and Dsp::cdef_filters. This function is not
+// thread-safe.
+void CdefInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_SSE4_1
+#define LIBGAV1_Dsp8bpp_CdefDirection LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_CdefFilters LIBGAV1_CPU_SSE4_1
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_CDEF_SSE4_H_
diff --git a/libgav1/src/dsp/x86/common_sse4.h b/libgav1/src/dsp/x86/common_sse4.h
index 8ff1ecb..24c801f 100644
--- a/libgav1/src/dsp/x86/common_sse4.h
+++ b/libgav1/src/dsp/x86/common_sse4.h
@@ -17,7 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_COMMON_SSE4_H_
 
-#include "src/dsp/dsp.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -26,6 +27,7 @@
 
 #include <cassert>
 #include <cstdint>
+#include <cstdlib>
 #include <cstring>
 
 #if 0
@@ -90,6 +92,14 @@
   return _mm_cvtsi32_si128(val1 | (val2 << 16));
 }
 
+// Load 2 uint8_t values into |lane| * 2 and |lane| * 2 + 1.
+template <int lane>
+inline __m128i Load2(const void* const buf, __m128i val) {
+  uint16_t temp;
+  memcpy(&temp, buf, 2);
+  return _mm_insert_epi16(val, temp, lane);
+}
+
 inline __m128i Load4(const void* src) {
   // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32
   // intrinsic. Both _mm_loadu_si32(src) and the code here are compiled into a
@@ -102,6 +112,19 @@
   return _mm_cvtsi32_si128(val);
 }
 
+inline __m128i Load4x2(const void* src1, const void* src2) {
+  // With new compilers such as clang 8.0.0 we can use the new _mm_loadu_si32
+  // intrinsic. Both _mm_loadu_si32(src) and the code here are compiled into a
+  // movss instruction.
+  //
+  // Until compiler support of _mm_loadu_si32 is widespread, use of
+  // _mm_loadu_si32 is banned.
+  int val1, val2;
+  memcpy(&val1, src1, sizeof(val1));
+  memcpy(&val2, src2, sizeof(val2));
+  return _mm_insert_epi32(_mm_cvtsi32_si128(val1), val2, 1);
+}
+
 inline __m128i LoadLo8(const void* a) {
   return _mm_loadl_epi64(static_cast<const __m128i*>(a));
 }
@@ -116,6 +139,46 @@
   return _mm_loadu_si128(static_cast<const __m128i*>(a));
 }
 
+inline __m128i LoadAligned16(const void* a) {
+  assert((reinterpret_cast<uintptr_t>(a) & 0xf) == 0);
+  return _mm_load_si128(static_cast<const __m128i*>(a));
+}
+
+//------------------------------------------------------------------------------
+// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning.
+
+inline __m128i MaskOverreads(const __m128i source,
+                             const int over_read_in_bytes) {
+  __m128i dst = source;
+#if LIBGAV1_MSAN
+  if (over_read_in_bytes > 0) {
+    __m128i mask = _mm_set1_epi8(-1);
+    for (int i = 0; i < over_read_in_bytes; ++i) {
+      mask = _mm_srli_si128(mask, 1);
+    }
+    dst = _mm_and_si128(dst, mask);
+  }
+#else
+  static_cast<void>(over_read_in_bytes);
+#endif
+  return dst;
+}
+
+inline __m128i LoadLo8Msan(const void* const source,
+                           const int over_read_in_bytes) {
+  return MaskOverreads(LoadLo8(source), over_read_in_bytes + 8);
+}
+
+inline __m128i LoadAligned16Msan(const void* const source,
+                                 const int over_read_in_bytes) {
+  return MaskOverreads(LoadAligned16(source), over_read_in_bytes);
+}
+
+inline __m128i LoadUnaligned16Msan(const void* const source,
+                                   const int over_read_in_bytes) {
+  return MaskOverreads(LoadUnaligned16(source), over_read_in_bytes);
+}
+
 //------------------------------------------------------------------------------
 // Store functions.
 
@@ -137,6 +200,10 @@
   _mm_storeh_pi(static_cast<__m64*>(a), _mm_castsi128_ps(v));
 }
 
+inline void StoreAligned16(void* a, const __m128i v) {
+  _mm_store_si128(static_cast<__m128i*>(a), v);
+}
+
 inline void StoreUnaligned16(void* a, const __m128i v) {
   _mm_storeu_si128(static_cast<__m128i*>(a), v);
 }
@@ -175,13 +242,13 @@
 //------------------------------------------------------------------------------
 // Masking utilities
 inline __m128i MaskHighNBytes(int n) {
-  const uint8_t lu_table[32] = {
+  static constexpr uint8_t kMask[32] = {
       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
       0,   0,   0,   0,   0,   255, 255, 255, 255, 255, 255,
       255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
   };
 
-  return LoadUnaligned16(lu_table + n);
+  return LoadUnaligned16(kMask + n);
 }
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/x86/convolve_sse4.cc b/libgav1/src/dsp/x86/convolve_sse4.cc
index a22df1b..ff9a373 100644
--- a/libgav1/src/dsp/x86/convolve_sse4.cc
+++ b/libgav1/src/dsp/x86/convolve_sse4.cc
@@ -13,18 +13,2300 @@
 // limitations under the License.
 
 #include "src/dsp/convolve.h"
-#include "src/dsp/dsp.h"
+#include "src/utils/constants.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
+#include <smmintrin.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/common.h"
+
 namespace libgav1 {
 namespace dsp {
+namespace low_bitdepth {
+namespace {
 
-void ConvolveInit_SSE4_1() {}
+// TODO(slavarnway): Move to common neon/sse4 file.
+int GetNumTapsInFilter(const int filter_index) {
+  if (filter_index < 2) {
+    // Despite the names these only use 6 taps.
+    // kInterpolationFilterEightTap
+    // kInterpolationFilterEightTapSmooth
+    return 6;
+  }
+
+  if (filter_index == 2) {
+    // kInterpolationFilterEightTapSharp
+    return 8;
+  }
+
+  if (filter_index == 3) {
+    // kInterpolationFilterBilinear
+    return 2;
+  }
+
+  assert(filter_index > 3);
+  // For small sizes (width/height <= 4) the large filters are replaced with 4
+  // tap options.
+  // If the original filters were |kInterpolationFilterEightTap| or
+  // |kInterpolationFilterEightTapSharp| then it becomes
+  // |kInterpolationFilterSwitchable|.
+  // If it was |kInterpolationFilterEightTapSmooth| then it becomes an unnamed 4
+  // tap filter.
+  return 4;
+}
+
+constexpr int kIntermediateStride = kMaxSuperBlockSizeInPixels;
+constexpr int kHorizontalOffset = 3;
+constexpr int kFilterIndexShift = 6;
+
+// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and
+// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final
+// sum from outranging int16_t.
+template <int filter_index>
+__m128i SumOnePassTaps(const __m128i* const src, const __m128i* const taps) {
+  __m128i sum;
+  if (filter_index < 2) {
+    // 6 taps.
+    const __m128i v_madd_21 = _mm_maddubs_epi16(src[0], taps[0]);  // k2k1
+    const __m128i v_madd_43 = _mm_maddubs_epi16(src[1], taps[1]);  // k4k3
+    const __m128i v_madd_65 = _mm_maddubs_epi16(src[2], taps[2]);  // k6k5
+    sum = _mm_add_epi16(v_madd_21, v_madd_43);
+    sum = _mm_add_epi16(sum, v_madd_65);
+  } else if (filter_index == 2) {
+    // 8 taps.
+    const __m128i v_madd_10 = _mm_maddubs_epi16(src[0], taps[0]);  // k1k0
+    const __m128i v_madd_32 = _mm_maddubs_epi16(src[1], taps[1]);  // k3k2
+    const __m128i v_madd_54 = _mm_maddubs_epi16(src[2], taps[2]);  // k5k4
+    const __m128i v_madd_76 = _mm_maddubs_epi16(src[3], taps[3]);  // k7k6
+    const __m128i v_sum_3210 = _mm_add_epi16(v_madd_10, v_madd_32);
+    const __m128i v_sum_7654 = _mm_add_epi16(v_madd_54, v_madd_76);
+    sum = _mm_add_epi16(v_sum_7654, v_sum_3210);
+  } else if (filter_index == 3) {
+    // 2 taps.
+    sum = _mm_maddubs_epi16(src[0], taps[0]);  // k4k3
+  } else {
+    // 4 taps.
+    const __m128i v_madd_32 = _mm_maddubs_epi16(src[0], taps[0]);  // k3k2
+    const __m128i v_madd_54 = _mm_maddubs_epi16(src[1], taps[1]);  // k5k4
+    sum = _mm_add_epi16(v_madd_32, v_madd_54);
+  }
+  return sum;
+}
+
+template <int filter_index>
+__m128i SumHorizontalTaps(const uint8_t* const src,
+                          const __m128i* const v_tap) {
+  __m128i v_src[4];
+  const __m128i src_long = LoadUnaligned16(src);
+  const __m128i src_long_dup_lo = _mm_unpacklo_epi8(src_long, src_long);
+  const __m128i src_long_dup_hi = _mm_unpackhi_epi8(src_long, src_long);
+
+  if (filter_index < 2) {
+    // 6 taps.
+    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 3);   // _21
+    v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7);   // _43
+    v_src[2] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 11);  // _65
+  } else if (filter_index == 2) {
+    // 8 taps.
+    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 1);   // _10
+    v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5);   // _32
+    v_src[2] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9);   // _54
+    v_src[3] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 13);  // _76
+  } else if (filter_index == 3) {
+    // 2 taps.
+    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7);  // _43
+  } else if (filter_index > 3) {
+    // 4 taps.
+    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5);  // _32
+    v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9);  // _54
+  }
+  const __m128i sum = SumOnePassTaps<filter_index>(v_src, v_tap);
+  return sum;
+}
+
+template <int filter_index>
+__m128i SimpleHorizontalTaps(const uint8_t* const src,
+                             const __m128i* const v_tap) {
+  __m128i sum = SumHorizontalTaps<filter_index>(src, v_tap);
+
+  // Normally the Horizontal pass does the downshift in two passes:
+  // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
+  // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
+  // requires adding the rounding offset from the skipped shift.
+  constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);
+
+  sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit));
+  sum = RightShiftWithRounding_S16(sum, kFilterBits - 1);
+  return _mm_packus_epi16(sum, sum);
+}
+
+template <int filter_index>
+__m128i HorizontalTaps8To16(const uint8_t* const src,
+                            const __m128i* const v_tap) {
+  const __m128i sum = SumHorizontalTaps<filter_index>(src, v_tap);
+
+  return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
+}
+
+template <int filter_index>
+__m128i SumHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride,
+                             const __m128i* const v_tap) {
+  const __m128i input0 = LoadLo8(&src[2]);
+  const __m128i input1 = LoadLo8(&src[2 + src_stride]);
+
+  if (filter_index == 3) {
+    // 03 04 04 05 05 06 06 07 ....
+    const __m128i input0_dup =
+        _mm_srli_si128(_mm_unpacklo_epi8(input0, input0), 3);
+    // 13 14 14 15 15 16 16 17 ....
+    const __m128i input1_dup =
+        _mm_srli_si128(_mm_unpacklo_epi8(input1, input1), 3);
+    const __m128i v_src_43 = _mm_unpacklo_epi64(input0_dup, input1_dup);
+    const __m128i v_sum_43 = _mm_maddubs_epi16(v_src_43, v_tap[0]);  // k4k3
+    return v_sum_43;
+  }
+
+  // 02 03 03 04 04 05 05 06 06 07 ....
+  const __m128i input0_dup =
+      _mm_srli_si128(_mm_unpacklo_epi8(input0, input0), 1);
+  // 12 13 13 14 14 15 15 16 16 17 ....
+  const __m128i input1_dup =
+      _mm_srli_si128(_mm_unpacklo_epi8(input1, input1), 1);
+  // 04 05 05 06 06 07 07 08 ...
+  const __m128i input0_dup_54 = _mm_srli_si128(input0_dup, 4);
+  // 14 15 15 16 16 17 17 18 ...
+  const __m128i input1_dup_54 = _mm_srli_si128(input1_dup, 4);
+  const __m128i v_src_32 = _mm_unpacklo_epi64(input0_dup, input1_dup);
+  const __m128i v_src_54 = _mm_unpacklo_epi64(input0_dup_54, input1_dup_54);
+  const __m128i v_madd_32 = _mm_maddubs_epi16(v_src_32, v_tap[0]);  // k3k2
+  const __m128i v_madd_54 = _mm_maddubs_epi16(v_src_54, v_tap[1]);  // k5k4
+  const __m128i v_sum_5432 = _mm_add_epi16(v_madd_54, v_madd_32);
+  return v_sum_5432;
+}
+
+template <int filter_index>
+__m128i SimpleHorizontalTaps2x2(const uint8_t* src, const ptrdiff_t src_stride,
+                                const __m128i* const v_tap) {
+  __m128i sum = SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
+
+  // Normally the Horizontal pass does the downshift in two passes:
+  // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
+  // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
+  // requires adding the rounding offset from the skipped shift.
+  constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);
+
+  sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit));
+  sum = RightShiftWithRounding_S16(sum, kFilterBits - 1);
+  return _mm_packus_epi16(sum, sum);
+}
+
+template <int filter_index>
+__m128i HorizontalTaps8To16_2x2(const uint8_t* src, const ptrdiff_t src_stride,
+                                const __m128i* const v_tap) {
+  const __m128i sum =
+      SumHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
+
+  return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
+}
+
+template <int num_taps, int step, int filter_index, bool is_2d = false,
+          bool is_compound = false>
+void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride,
+                      void* const dest, const ptrdiff_t pred_stride,
+                      const int width, const int height,
+                      const __m128i* const v_tap) {
+  auto* dest8 = static_cast<uint8_t*>(dest);
+  auto* dest16 = static_cast<uint16_t*>(dest);
+
+  // 4 tap filters are never used when width > 4.
+  if (num_taps != 4 && width > 4) {
+    int y = 0;
+    do {
+      int x = 0;
+      do {
+        if (is_2d || is_compound) {
+          const __m128i v_sum =
+              HorizontalTaps8To16<filter_index>(&src[x], v_tap);
+          if (is_2d) {
+            StoreAligned16(&dest16[x], v_sum);
+          } else {
+            StoreUnaligned16(&dest16[x], v_sum);
+          }
+        } else {
+          const __m128i result =
+              SimpleHorizontalTaps<filter_index>(&src[x], v_tap);
+          StoreLo8(&dest8[x], result);
+        }
+        x += step;
+      } while (x < width);
+      src += src_stride;
+      dest8 += pred_stride;
+      dest16 += pred_stride;
+    } while (++y < height);
+    return;
+  }
+
+  // Horizontal passes only needs to account for |num_taps| 2 and 4 when
+  // |width| <= 4.
+  assert(width <= 4);
+  assert(num_taps <= 4);
+  if (num_taps <= 4) {
+    if (width == 4) {
+      int y = 0;
+      do {
+        if (is_2d || is_compound) {
+          const __m128i v_sum = HorizontalTaps8To16<filter_index>(src, v_tap);
+          StoreLo8(dest16, v_sum);
+        } else {
+          const __m128i result = SimpleHorizontalTaps<filter_index>(src, v_tap);
+          Store4(&dest8[0], result);
+        }
+        src += src_stride;
+        dest8 += pred_stride;
+        dest16 += pred_stride;
+      } while (++y < height);
+      return;
+    }
+
+    if (!is_compound) {
+      int y = 0;
+      do {
+        if (is_2d) {
+          const __m128i sum =
+              HorizontalTaps8To16_2x2<filter_index>(src, src_stride, v_tap);
+          Store4(&dest16[0], sum);
+          dest16 += pred_stride;
+          Store4(&dest16[0], _mm_srli_si128(sum, 8));
+          dest16 += pred_stride;
+        } else {
+          const __m128i sum =
+              SimpleHorizontalTaps2x2<filter_index>(src, src_stride, v_tap);
+          Store2(dest8, sum);
+          dest8 += pred_stride;
+          Store2(dest8, _mm_srli_si128(sum, 4));
+          dest8 += pred_stride;
+        }
+
+        src += src_stride << 1;
+        y += 2;
+      } while (y < height - 1);
+
+      // The 2d filters have an odd |height| because the horizontal pass
+      // generates context for the vertical pass.
+      if (is_2d) {
+        assert(height % 2 == 1);
+        __m128i sum;
+        const __m128i input = LoadLo8(&src[2]);
+        if (filter_index == 3) {
+          // 03 04 04 05 05 06 06 07 ....
+          const __m128i v_src_43 =
+              _mm_srli_si128(_mm_unpacklo_epi8(input, input), 3);
+          sum = _mm_maddubs_epi16(v_src_43, v_tap[0]);  // k4k3
+        } else {
+          // 02 03 03 04 04 05 05 06 06 07 ....
+          const __m128i v_src_32 =
+              _mm_srli_si128(_mm_unpacklo_epi8(input, input), 1);
+          // 04 05 05 06 06 07 07 08 ...
+          const __m128i v_src_54 = _mm_srli_si128(v_src_32, 4);
+          const __m128i v_madd_32 =
+              _mm_maddubs_epi16(v_src_32, v_tap[0]);  // k3k2
+          const __m128i v_madd_54 =
+              _mm_maddubs_epi16(v_src_54, v_tap[1]);  // k5k4
+          sum = _mm_add_epi16(v_madd_54, v_madd_32);
+        }
+        sum = RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
+        Store4(dest16, sum);
+      }
+    }
+  }
+}
+
+template <int num_taps, bool is_2d_vertical = false>
+LIBGAV1_ALWAYS_INLINE void SetupTaps(const __m128i* const filter,
+                                     __m128i* v_tap) {
+  if (num_taps == 8) {
+    v_tap[0] = _mm_shufflelo_epi16(*filter, 0x0);   // k1k0
+    v_tap[1] = _mm_shufflelo_epi16(*filter, 0x55);  // k3k2
+    v_tap[2] = _mm_shufflelo_epi16(*filter, 0xaa);  // k5k4
+    v_tap[3] = _mm_shufflelo_epi16(*filter, 0xff);  // k7k6
+    if (is_2d_vertical) {
+      v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]);
+      v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]);
+      v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]);
+      v_tap[3] = _mm_cvtepi8_epi16(v_tap[3]);
+    } else {
+      v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]);
+      v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]);
+      v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]);
+      v_tap[3] = _mm_unpacklo_epi64(v_tap[3], v_tap[3]);
+    }
+  } else if (num_taps == 6) {
+    const __m128i adjusted_filter = _mm_srli_si128(*filter, 1);
+    v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x0);   // k2k1
+    v_tap[1] = _mm_shufflelo_epi16(adjusted_filter, 0x55);  // k4k3
+    v_tap[2] = _mm_shufflelo_epi16(adjusted_filter, 0xaa);  // k6k5
+    if (is_2d_vertical) {
+      v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]);
+      v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]);
+      v_tap[2] = _mm_cvtepi8_epi16(v_tap[2]);
+    } else {
+      v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]);
+      v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]);
+      v_tap[2] = _mm_unpacklo_epi64(v_tap[2], v_tap[2]);
+    }
+  } else if (num_taps == 4) {
+    v_tap[0] = _mm_shufflelo_epi16(*filter, 0x55);  // k3k2
+    v_tap[1] = _mm_shufflelo_epi16(*filter, 0xaa);  // k5k4
+    if (is_2d_vertical) {
+      v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]);
+      v_tap[1] = _mm_cvtepi8_epi16(v_tap[1]);
+    } else {
+      v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]);
+      v_tap[1] = _mm_unpacklo_epi64(v_tap[1], v_tap[1]);
+    }
+  } else {  // num_taps == 2
+    const __m128i adjusted_filter = _mm_srli_si128(*filter, 1);
+    v_tap[0] = _mm_shufflelo_epi16(adjusted_filter, 0x55);  // k4k3
+    if (is_2d_vertical) {
+      v_tap[0] = _mm_cvtepi8_epi16(v_tap[0]);
+    } else {
+      v_tap[0] = _mm_unpacklo_epi64(v_tap[0], v_tap[0]);
+    }
+  }
+}
+
+template <int num_taps, bool is_compound>
+__m128i SimpleSum2DVerticalTaps(const __m128i* const src,
+                                const __m128i* const taps) {
+  __m128i sum_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[0], src[1]), taps[0]);
+  __m128i sum_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[0], src[1]), taps[0]);
+  if (num_taps >= 4) {
+    __m128i madd_lo =
+        _mm_madd_epi16(_mm_unpacklo_epi16(src[2], src[3]), taps[1]);
+    __m128i madd_hi =
+        _mm_madd_epi16(_mm_unpackhi_epi16(src[2], src[3]), taps[1]);
+    sum_lo = _mm_add_epi32(sum_lo, madd_lo);
+    sum_hi = _mm_add_epi32(sum_hi, madd_hi);
+    if (num_taps >= 6) {
+      madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[4], src[5]), taps[2]);
+      madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[4], src[5]), taps[2]);
+      sum_lo = _mm_add_epi32(sum_lo, madd_lo);
+      sum_hi = _mm_add_epi32(sum_hi, madd_hi);
+      if (num_taps == 8) {
+        madd_lo = _mm_madd_epi16(_mm_unpacklo_epi16(src[6], src[7]), taps[3]);
+        madd_hi = _mm_madd_epi16(_mm_unpackhi_epi16(src[6], src[7]), taps[3]);
+        sum_lo = _mm_add_epi32(sum_lo, madd_lo);
+        sum_hi = _mm_add_epi32(sum_hi, madd_hi);
+      }
+    }
+  }
+
+  if (is_compound) {
+    return _mm_packs_epi32(
+        RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1),
+        RightShiftWithRounding_S32(sum_hi,
+                                   kInterRoundBitsCompoundVertical - 1));
+  }
+
+  return _mm_packs_epi32(
+      RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1),
+      RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1));
+}
+
+template <int num_taps, bool is_compound = false>
+void Filter2DVertical(const uint16_t* src, void* const dst,
+                      const ptrdiff_t dst_stride, const int width,
+                      const int height, const __m128i* const taps) {
+  assert(width >= 8);
+  constexpr int next_row = num_taps - 1;
+  // The Horizontal pass uses |width| as |stride| for the intermediate buffer.
+  const ptrdiff_t src_stride = width;
+
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+
+  int x = 0;
+  do {
+    __m128i srcs[8];
+    const uint16_t* src_x = src + x;
+    srcs[0] = LoadAligned16(src_x);
+    src_x += src_stride;
+    if (num_taps >= 4) {
+      srcs[1] = LoadAligned16(src_x);
+      src_x += src_stride;
+      srcs[2] = LoadAligned16(src_x);
+      src_x += src_stride;
+      if (num_taps >= 6) {
+        srcs[3] = LoadAligned16(src_x);
+        src_x += src_stride;
+        srcs[4] = LoadAligned16(src_x);
+        src_x += src_stride;
+        if (num_taps == 8) {
+          srcs[5] = LoadAligned16(src_x);
+          src_x += src_stride;
+          srcs[6] = LoadAligned16(src_x);
+          src_x += src_stride;
+        }
+      }
+    }
+
+    int y = 0;
+    do {
+      srcs[next_row] = LoadAligned16(src_x);
+      src_x += src_stride;
+
+      const __m128i sum =
+          SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
+      if (is_compound) {
+        StoreUnaligned16(dst16 + x + y * dst_stride, sum);
+      } else {
+        StoreLo8(dst8 + x + y * dst_stride, _mm_packus_epi16(sum, sum));
+      }
+
+      srcs[0] = srcs[1];
+      if (num_taps >= 4) {
+        srcs[1] = srcs[2];
+        srcs[2] = srcs[3];
+        if (num_taps >= 6) {
+          srcs[3] = srcs[4];
+          srcs[4] = srcs[5];
+          if (num_taps == 8) {
+            srcs[5] = srcs[6];
+            srcs[6] = srcs[7];
+          }
+        }
+      }
+    } while (++y < height);
+    x += 8;
+  } while (x < width);
+}
+
+// Take advantage of |src_stride| == |width| to process two rows at a time.
+template <int num_taps, bool is_compound = false>
+void Filter2DVertical4xH(const uint16_t* src, void* const dst,
+                         const ptrdiff_t dst_stride, const int height,
+                         const __m128i* const taps) {
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+
+  __m128i srcs[9];
+  srcs[0] = LoadAligned16(src);
+  src += 8;
+  if (num_taps >= 4) {
+    srcs[2] = LoadAligned16(src);
+    src += 8;
+    srcs[1] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[0], 8), srcs[2]);
+    if (num_taps >= 6) {
+      srcs[4] = LoadAligned16(src);
+      src += 8;
+      srcs[3] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[2], 8), srcs[4]);
+      if (num_taps == 8) {
+        srcs[6] = LoadAligned16(src);
+        src += 8;
+        srcs[5] = _mm_unpacklo_epi64(_mm_srli_si128(srcs[4], 8), srcs[6]);
+      }
+    }
+  }
+
+  int y = 0;
+  do {
+    srcs[num_taps] = LoadAligned16(src);
+    src += 8;
+    srcs[num_taps - 1] = _mm_unpacklo_epi64(
+        _mm_srli_si128(srcs[num_taps - 2], 8), srcs[num_taps]);
+
+    const __m128i sum =
+        SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
+    if (is_compound) {
+      StoreUnaligned16(dst16, sum);
+      dst16 += 4 << 1;
+    } else {
+      const __m128i results = _mm_packus_epi16(sum, sum);
+      Store4(dst8, results);
+      dst8 += dst_stride;
+      Store4(dst8, _mm_srli_si128(results, 4));
+      dst8 += dst_stride;
+    }
+
+    srcs[0] = srcs[2];
+    if (num_taps >= 4) {
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      if (num_taps >= 6) {
+        srcs[3] = srcs[5];
+        srcs[4] = srcs[6];
+        if (num_taps == 8) {
+          srcs[5] = srcs[7];
+          srcs[6] = srcs[8];
+        }
+      }
+    }
+    y += 2;
+  } while (y < height);
+}
+
+// Take advantage of |src_stride| == |width| to process four rows at a time.
+template <int num_taps>
+void Filter2DVertical2xH(const uint16_t* src, void* const dst,
+                         const ptrdiff_t dst_stride, const int height,
+                         const __m128i* const taps) {
+  constexpr int next_row = (num_taps < 6) ? 4 : 8;
+
+  auto* dst8 = static_cast<uint8_t*>(dst);
+
+  __m128i srcs[9];
+  srcs[0] = LoadAligned16(src);
+  src += 8;
+  if (num_taps >= 6) {
+    srcs[4] = LoadAligned16(src);
+    src += 8;
+    srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4);
+    if (num_taps == 8) {
+      srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8);
+      srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12);
+    }
+  }
+
+  int y = 0;
+  do {
+    srcs[next_row] = LoadAligned16(src);
+    src += 8;
+    if (num_taps == 2) {
+      srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4);
+    } else if (num_taps == 4) {
+      srcs[1] = _mm_alignr_epi8(srcs[4], srcs[0], 4);
+      srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8);
+      srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12);
+    } else if (num_taps == 6) {
+      srcs[2] = _mm_alignr_epi8(srcs[4], srcs[0], 8);
+      srcs[3] = _mm_alignr_epi8(srcs[4], srcs[0], 12);
+      srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4);
+    } else if (num_taps == 8) {
+      srcs[5] = _mm_alignr_epi8(srcs[8], srcs[4], 4);
+      srcs[6] = _mm_alignr_epi8(srcs[8], srcs[4], 8);
+      srcs[7] = _mm_alignr_epi8(srcs[8], srcs[4], 12);
+    }
+
+    const __m128i sum =
+        SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps);
+    const __m128i results = _mm_packus_epi16(sum, sum);
+
+    Store2(dst8, results);
+    dst8 += dst_stride;
+    Store2(dst8, _mm_srli_si128(results, 2));
+    // When |height| <= 4 the taps are restricted to 2 and 4 tap variants.
+    // Therefore we don't need to check this condition when |height| > 4.
+    if (num_taps <= 4 && height == 2) return;
+    dst8 += dst_stride;
+    Store2(dst8, _mm_srli_si128(results, 4));
+    dst8 += dst_stride;
+    Store2(dst8, _mm_srli_si128(results, 6));
+    dst8 += dst_stride;
+
+    srcs[0] = srcs[4];
+    if (num_taps == 6) {
+      srcs[1] = srcs[5];
+      srcs[4] = srcs[8];
+    } else if (num_taps == 8) {
+      srcs[1] = srcs[5];
+      srcs[2] = srcs[6];
+      srcs[3] = srcs[7];
+      srcs[4] = srcs[8];
+    }
+
+    y += 4;
+  } while (y < height);
+}
+
+template <bool is_2d = false, bool is_compound = false>
+LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
+    const uint8_t* const src, const ptrdiff_t src_stride, void* const dst,
+    const ptrdiff_t dst_stride, const int width, const int height,
+    const int subpixel, const int filter_index) {
+  const int filter_id = (subpixel >> 6) & kSubPixelMask;
+  assert(filter_id != 0);
+  __m128i v_tap[4];
+  const __m128i v_horizontal_filter =
+      LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]);
+
+  if (filter_index == 2) {  // 8 tap.
+    SetupTaps<8>(&v_horizontal_filter, v_tap);
+    FilterHorizontal<8, 8, 2, is_2d, is_compound>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  } else if (filter_index == 1) {  // 6 tap.
+    SetupTaps<6>(&v_horizontal_filter, v_tap);
+    FilterHorizontal<6, 8, 1, is_2d, is_compound>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  } else if (filter_index == 0) {  // 6 tap.
+    SetupTaps<6>(&v_horizontal_filter, v_tap);
+    FilterHorizontal<6, 8, 0, is_2d, is_compound>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  } else if (filter_index == 4) {  // 4 tap.
+    SetupTaps<4>(&v_horizontal_filter, v_tap);
+    FilterHorizontal<4, 8, 4, is_2d, is_compound>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  } else if (filter_index == 5) {  // 4 tap.
+    SetupTaps<4>(&v_horizontal_filter, v_tap);
+    FilterHorizontal<4, 8, 5, is_2d, is_compound>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  } else {  // 2 tap.
+    SetupTaps<2>(&v_horizontal_filter, v_tap);
+    FilterHorizontal<2, 8, 3, is_2d, is_compound>(
+        src, src_stride, dst, dst_stride, width, height, v_tap);
+  }
+}
+
+void Convolve2D_SSE4_1(const void* const reference,
+                       const ptrdiff_t reference_stride,
+                       const int horizontal_filter_index,
+                       const int vertical_filter_index, const int subpixel_x,
+                       const int subpixel_y, const int width, const int height,
+                       void* prediction, const ptrdiff_t pred_stride) {
+  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
+
+  // The output of the horizontal filter is guaranteed to fit in 16 bits.
+  alignas(16) uint16_t
+      intermediate_result[kMaxSuperBlockSizeInPixels *
+                          (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
+  const int intermediate_height = height + vertical_taps - 1;
+
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* src = static_cast<const uint8_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
+
+  DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, width,
+                                   width, intermediate_height, subpixel_x,
+                                   horiz_filter_index);
+
+  // Vertical filter.
+  auto* dest = static_cast<uint8_t*>(prediction);
+  const ptrdiff_t dest_stride = pred_stride;
+  const int filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
+  assert(filter_id != 0);
+
+  __m128i taps[4];
+  const __m128i v_filter =
+      LoadLo8(kHalfSubPixelFilters[vert_filter_index][filter_id]);
+
+  if (vertical_taps == 8) {
+    SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps);
+    if (width == 2) {
+      Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else if (width == 4) {
+      Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else {
+      Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, height,
+                          taps);
+    }
+  } else if (vertical_taps == 6) {
+    SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps);
+    if (width == 2) {
+      Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else if (width == 4) {
+      Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else {
+      Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, height,
+                          taps);
+    }
+  } else if (vertical_taps == 4) {
+    SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps);
+    if (width == 2) {
+      Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else if (width == 4) {
+      Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else {
+      Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, height,
+                          taps);
+    }
+  } else {  // |vertical_taps| == 2
+    SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps);
+    if (width == 2) {
+      Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else if (width == 4) {
+      Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height,
+                             taps);
+    } else {
+      Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, height,
+                          taps);
+    }
+  }
+}
+
+// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D
+// Vertical calculations.
+__m128i Compound1DShift(const __m128i sum) {
+  return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
+}
+
+template <int filter_index>
+__m128i SumVerticalTaps(const __m128i* const srcs, const __m128i* const v_tap) {
+  __m128i v_src[4];
+
+  if (filter_index < 2) {
+    // 6 taps.
+    v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]);
+    v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]);
+    v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]);
+  } else if (filter_index == 2) {
+    // 8 taps.
+    v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]);
+    v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]);
+    v_src[2] = _mm_unpacklo_epi8(srcs[4], srcs[5]);
+    v_src[3] = _mm_unpacklo_epi8(srcs[6], srcs[7]);
+  } else if (filter_index == 3) {
+    // 2 taps.
+    v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]);
+  } else if (filter_index > 3) {
+    // 4 taps.
+    v_src[0] = _mm_unpacklo_epi8(srcs[0], srcs[1]);
+    v_src[1] = _mm_unpacklo_epi8(srcs[2], srcs[3]);
+  }
+  const __m128i sum = SumOnePassTaps<filter_index>(v_src, v_tap);
+  return sum;
+}
+
+template <int filter_index, bool is_compound = false>
+void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride,
+                    void* const dst, const ptrdiff_t dst_stride,
+                    const int width, const int height,
+                    const __m128i* const v_tap) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  const int next_row = num_taps - 1;
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+  assert(width >= 8);
+
+  int x = 0;
+  do {
+    const uint8_t* src_x = src + x;
+    __m128i srcs[8];
+    srcs[0] = LoadLo8(src_x);
+    src_x += src_stride;
+    if (num_taps >= 4) {
+      srcs[1] = LoadLo8(src_x);
+      src_x += src_stride;
+      srcs[2] = LoadLo8(src_x);
+      src_x += src_stride;
+      if (num_taps >= 6) {
+        srcs[3] = LoadLo8(src_x);
+        src_x += src_stride;
+        srcs[4] = LoadLo8(src_x);
+        src_x += src_stride;
+        if (num_taps == 8) {
+          srcs[5] = LoadLo8(src_x);
+          src_x += src_stride;
+          srcs[6] = LoadLo8(src_x);
+          src_x += src_stride;
+        }
+      }
+    }
+
+    int y = 0;
+    do {
+      srcs[next_row] = LoadLo8(src_x);
+      src_x += src_stride;
+
+      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
+      if (is_compound) {
+        const __m128i results = Compound1DShift(sums);
+        StoreUnaligned16(dst16 + x + y * dst_stride, results);
+      } else {
+        const __m128i results =
+            RightShiftWithRounding_S16(sums, kFilterBits - 1);
+        StoreLo8(dst8 + x + y * dst_stride, _mm_packus_epi16(results, results));
+      }
+
+      srcs[0] = srcs[1];
+      if (num_taps >= 4) {
+        srcs[1] = srcs[2];
+        srcs[2] = srcs[3];
+        if (num_taps >= 6) {
+          srcs[3] = srcs[4];
+          srcs[4] = srcs[5];
+          if (num_taps == 8) {
+            srcs[5] = srcs[6];
+            srcs[6] = srcs[7];
+          }
+        }
+      }
+    } while (++y < height);
+    x += 8;
+  } while (x < width);
+}
+
+template <int filter_index, bool is_compound = false>
+void FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride,
+                       void* const dst, const ptrdiff_t dst_stride,
+                       const int height, const __m128i* const v_tap) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  auto* dst8 = static_cast<uint8_t*>(dst);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+
+  __m128i srcs[9];
+
+  if (num_taps == 2) {
+    srcs[2] = _mm_setzero_si128();
+    // 00 01 02 03
+    srcs[0] = Load4(src);
+    src += src_stride;
+
+    int y = 0;
+    do {
+      // 10 11 12 13
+      const __m128i a = Load4(src);
+      // 00 01 02 03 10 11 12 13
+      srcs[0] = _mm_unpacklo_epi32(srcs[0], a);
+      src += src_stride;
+      // 20 21 22 23
+      srcs[2] = Load4(src);
+      src += src_stride;
+      // 10 11 12 13 20 21 22 23
+      srcs[1] = _mm_unpacklo_epi32(a, srcs[2]);
+
+      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
+      if (is_compound) {
+        const __m128i results = Compound1DShift(sums);
+        StoreUnaligned16(dst16, results);
+        dst16 += 4 << 1;
+      } else {
+        const __m128i results_16 =
+            RightShiftWithRounding_S16(sums, kFilterBits - 1);
+        const __m128i results = _mm_packus_epi16(results_16, results_16);
+        Store4(dst8, results);
+        dst8 += dst_stride;
+        Store4(dst8, _mm_srli_si128(results, 4));
+        dst8 += dst_stride;
+      }
+
+      srcs[0] = srcs[2];
+      y += 2;
+    } while (y < height);
+  } else if (num_taps == 4) {
+    srcs[4] = _mm_setzero_si128();
+    // 00 01 02 03
+    srcs[0] = Load4(src);
+    src += src_stride;
+    // 10 11 12 13
+    const __m128i a = Load4(src);
+    // 00 01 02 03 10 11 12 13
+    srcs[0] = _mm_unpacklo_epi32(srcs[0], a);
+    src += src_stride;
+    // 20 21 22 23
+    srcs[2] = Load4(src);
+    src += src_stride;
+    // 10 11 12 13 20 21 22 23
+    srcs[1] = _mm_unpacklo_epi32(a, srcs[2]);
+
+    int y = 0;
+    do {
+      // 30 31 32 33
+      const __m128i b = Load4(src);
+      // 20 21 22 23 30 31 32 33
+      srcs[2] = _mm_unpacklo_epi32(srcs[2], b);
+      src += src_stride;
+      // 40 41 42 43
+      srcs[4] = Load4(src);
+      src += src_stride;
+      // 30 31 32 33 40 41 42 43
+      srcs[3] = _mm_unpacklo_epi32(b, srcs[4]);
+
+      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
+      if (is_compound) {
+        const __m128i results = Compound1DShift(sums);
+        StoreUnaligned16(dst16, results);
+        dst16 += 4 << 1;
+      } else {
+        const __m128i results_16 =
+            RightShiftWithRounding_S16(sums, kFilterBits - 1);
+        const __m128i results = _mm_packus_epi16(results_16, results_16);
+        Store4(dst8, results);
+        dst8 += dst_stride;
+        Store4(dst8, _mm_srli_si128(results, 4));
+        dst8 += dst_stride;
+      }
+
+      srcs[0] = srcs[2];
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      y += 2;
+    } while (y < height);
+  } else if (num_taps == 6) {
+    srcs[6] = _mm_setzero_si128();
+    // 00 01 02 03
+    srcs[0] = Load4(src);
+    src += src_stride;
+    // 10 11 12 13
+    const __m128i a = Load4(src);
+    // 00 01 02 03 10 11 12 13
+    srcs[0] = _mm_unpacklo_epi32(srcs[0], a);
+    src += src_stride;
+    // 20 21 22 23
+    srcs[2] = Load4(src);
+    src += src_stride;
+    // 10 11 12 13 20 21 22 23
+    srcs[1] = _mm_unpacklo_epi32(a, srcs[2]);
+    // 30 31 32 33
+    const __m128i b = Load4(src);
+    // 20 21 22 23 30 31 32 33
+    srcs[2] = _mm_unpacklo_epi32(srcs[2], b);
+    src += src_stride;
+    // 40 41 42 43
+    srcs[4] = Load4(src);
+    src += src_stride;
+    // 30 31 32 33 40 41 42 43
+    srcs[3] = _mm_unpacklo_epi32(b, srcs[4]);
+
+    int y = 0;
+    do {
+      // 50 51 52 53
+      const __m128i c = Load4(src);
+      // 40 41 42 43 50 51 52 53
+      srcs[4] = _mm_unpacklo_epi32(srcs[4], c);
+      src += src_stride;
+      // 60 61 62 63
+      srcs[6] = Load4(src);
+      src += src_stride;
+      // 50 51 52 53 60 61 62 63
+      srcs[5] = _mm_unpacklo_epi32(c, srcs[6]);
+
+      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
+      if (is_compound) {
+        const __m128i results = Compound1DShift(sums);
+        StoreUnaligned16(dst16, results);
+        dst16 += 4 << 1;
+      } else {
+        const __m128i results_16 =
+            RightShiftWithRounding_S16(sums, kFilterBits - 1);
+        const __m128i results = _mm_packus_epi16(results_16, results_16);
+        Store4(dst8, results);
+        dst8 += dst_stride;
+        Store4(dst8, _mm_srli_si128(results, 4));
+        dst8 += dst_stride;
+      }
+
+      srcs[0] = srcs[2];
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      srcs[3] = srcs[5];
+      srcs[4] = srcs[6];
+      y += 2;
+    } while (y < height);
+  } else if (num_taps == 8) {
+    srcs[8] = _mm_setzero_si128();
+    // 00 01 02 03
+    srcs[0] = Load4(src);
+    src += src_stride;
+    // 10 11 12 13
+    const __m128i a = Load4(src);
+    // 00 01 02 03 10 11 12 13
+    srcs[0] = _mm_unpacklo_epi32(srcs[0], a);
+    src += src_stride;
+    // 20 21 22 23
+    srcs[2] = Load4(src);
+    src += src_stride;
+    // 10 11 12 13 20 21 22 23
+    srcs[1] = _mm_unpacklo_epi32(a, srcs[2]);
+    // 30 31 32 33
+    const __m128i b = Load4(src);
+    // 20 21 22 23 30 31 32 33
+    srcs[2] = _mm_unpacklo_epi32(srcs[2], b);
+    src += src_stride;
+    // 40 41 42 43
+    srcs[4] = Load4(src);
+    src += src_stride;
+    // 30 31 32 33 40 41 42 43
+    srcs[3] = _mm_unpacklo_epi32(b, srcs[4]);
+    // 50 51 52 53
+    const __m128i c = Load4(src);
+    // 40 41 42 43 50 51 52 53
+    srcs[4] = _mm_unpacklo_epi32(srcs[4], c);
+    src += src_stride;
+    // 60 61 62 63
+    srcs[6] = Load4(src);
+    src += src_stride;
+    // 50 51 52 53 60 61 62 63
+    srcs[5] = _mm_unpacklo_epi32(c, srcs[6]);
+
+    int y = 0;
+    do {
+      // 70 71 72 73
+      const __m128i d = Load4(src);
+      // 60 61 62 63 70 71 72 73
+      srcs[6] = _mm_unpacklo_epi32(srcs[6], d);
+      src += src_stride;
+      // 80 81 82 83
+      srcs[8] = Load4(src);
+      src += src_stride;
+      // 70 71 72 73 80 81 82 83
+      srcs[7] = _mm_unpacklo_epi32(d, srcs[8]);
+
+      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
+      if (is_compound) {
+        const __m128i results = Compound1DShift(sums);
+        StoreUnaligned16(dst16, results);
+        dst16 += 4 << 1;
+      } else {
+        const __m128i results_16 =
+            RightShiftWithRounding_S16(sums, kFilterBits - 1);
+        const __m128i results = _mm_packus_epi16(results_16, results_16);
+        Store4(dst8, results);
+        dst8 += dst_stride;
+        Store4(dst8, _mm_srli_si128(results, 4));
+        dst8 += dst_stride;
+      }
+
+      srcs[0] = srcs[2];
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      srcs[3] = srcs[5];
+      srcs[4] = srcs[6];
+      srcs[5] = srcs[7];
+      srcs[6] = srcs[8];
+      y += 2;
+    } while (y < height);
+  }
+}
+
+template <int filter_index, bool negative_outside_taps = false>
+void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride,
+                       void* const dst, const ptrdiff_t dst_stride,
+                       const int height, const __m128i* const v_tap) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  auto* dst8 = static_cast<uint8_t*>(dst);
+
+  __m128i srcs[9];
+
+  if (num_taps == 2) {
+    srcs[2] = _mm_setzero_si128();
+    // 00 01
+    srcs[0] = Load2(src);
+    src += src_stride;
+
+    int y = 0;
+    do {
+      // 00 01 10 11
+      srcs[0] = Load2<1>(src, srcs[0]);
+      src += src_stride;
+      // 00 01 10 11 20 21
+      srcs[0] = Load2<2>(src, srcs[0]);
+      src += src_stride;
+      // 00 01 10 11 20 21 30 31
+      srcs[0] = Load2<3>(src, srcs[0]);
+      src += src_stride;
+      // 40 41
+      srcs[2] = Load2<0>(src, srcs[2]);
+      src += src_stride;
+      // 00 01 10 11 20 21 30 31 40 41
+      const __m128i srcs_0_2 = _mm_unpacklo_epi64(srcs[0], srcs[2]);
+      // 10 11 20 21 30 31 40 41
+      srcs[1] = _mm_srli_si128(srcs_0_2, 2);
+      // This uses srcs[0]..srcs[1].
+      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
+      const __m128i results_16 =
+          RightShiftWithRounding_S16(sums, kFilterBits - 1);
+      const __m128i results = _mm_packus_epi16(results_16, results_16);
+
+      Store2(dst8, results);
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 2));
+      if (height == 2) return;
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 4));
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 6));
+      dst8 += dst_stride;
+
+      srcs[0] = srcs[2];
+      y += 4;
+    } while (y < height);
+  } else if (num_taps == 4) {
+    srcs[4] = _mm_setzero_si128();
+
+    // 00 01
+    srcs[0] = Load2(src);
+    src += src_stride;
+    // 00 01 10 11
+    srcs[0] = Load2<1>(src, srcs[0]);
+    src += src_stride;
+    // 00 01 10 11 20 21
+    srcs[0] = Load2<2>(src, srcs[0]);
+    src += src_stride;
+
+    int y = 0;
+    do {
+      // 00 01 10 11 20 21 30 31
+      srcs[0] = Load2<3>(src, srcs[0]);
+      src += src_stride;
+      // 40 41
+      srcs[4] = Load2<0>(src, srcs[4]);
+      src += src_stride;
+      // 40 41 50 51
+      srcs[4] = Load2<1>(src, srcs[4]);
+      src += src_stride;
+      // 40 41 50 51 60 61
+      srcs[4] = Load2<2>(src, srcs[4]);
+      src += src_stride;
+      // 00 01 10 11 20 21 30 31 40 41 50 51 60 61
+      const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]);
+      // 10 11 20 21 30 31 40 41
+      srcs[1] = _mm_srli_si128(srcs_0_4, 2);
+      // 20 21 30 31 40 41 50 51
+      srcs[2] = _mm_srli_si128(srcs_0_4, 4);
+      // 30 31 40 41 50 51 60 61
+      srcs[3] = _mm_srli_si128(srcs_0_4, 6);
+
+      // This uses srcs[0]..srcs[3].
+      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
+      const __m128i results_16 =
+          RightShiftWithRounding_S16(sums, kFilterBits - 1);
+      const __m128i results = _mm_packus_epi16(results_16, results_16);
+
+      Store2(dst8, results);
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 2));
+      if (height == 2) return;
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 4));
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 6));
+      dst8 += dst_stride;
+
+      srcs[0] = srcs[4];
+      y += 4;
+    } while (y < height);
+  } else if (num_taps == 6) {
+    // During the vertical pass the number of taps is restricted when
+    // |height| <= 4.
+    assert(height > 4);
+    srcs[8] = _mm_setzero_si128();
+
+    // 00 01
+    srcs[0] = Load2(src);
+    src += src_stride;
+    // 00 01 10 11
+    srcs[0] = Load2<1>(src, srcs[0]);
+    src += src_stride;
+    // 00 01 10 11 20 21
+    srcs[0] = Load2<2>(src, srcs[0]);
+    src += src_stride;
+    // 00 01 10 11 20 21 30 31
+    srcs[0] = Load2<3>(src, srcs[0]);
+    src += src_stride;
+    // 40 41
+    srcs[4] = Load2(src);
+    src += src_stride;
+    // 00 01 10 11 20 21 30 31 40 41 50 51 60 61
+    const __m128i srcs_0_4x = _mm_unpacklo_epi64(srcs[0], srcs[4]);
+    // 10 11 20 21 30 31 40 41
+    srcs[1] = _mm_srli_si128(srcs_0_4x, 2);
+
+    int y = 0;
+    do {
+      // 40 41 50 51
+      srcs[4] = Load2<1>(src, srcs[4]);
+      src += src_stride;
+      // 40 41 50 51 60 61
+      srcs[4] = Load2<2>(src, srcs[4]);
+      src += src_stride;
+      // 40 41 50 51 60 61 70 71
+      srcs[4] = Load2<3>(src, srcs[4]);
+      src += src_stride;
+      // 80 81
+      srcs[8] = Load2<0>(src, srcs[8]);
+      src += src_stride;
+      // 00 01 10 11 20 21 30 31 40 41 50 51 60 61
+      const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]);
+      // 20 21 30 31 40 41 50 51
+      srcs[2] = _mm_srli_si128(srcs_0_4, 4);
+      // 30 31 40 41 50 51 60 61
+      srcs[3] = _mm_srli_si128(srcs_0_4, 6);
+      const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]);
+      // 50 51 60 61 70 71 80 81
+      srcs[5] = _mm_srli_si128(srcs_4_8, 2);
+
+      // This uses srcs[0]..srcs[5].
+      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
+      const __m128i results_16 =
+          RightShiftWithRounding_S16(sums, kFilterBits - 1);
+      const __m128i results = _mm_packus_epi16(results_16, results_16);
+
+      Store2(dst8, results);
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 2));
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 4));
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 6));
+      dst8 += dst_stride;
+
+      srcs[0] = srcs[4];
+      srcs[1] = srcs[5];
+      srcs[4] = srcs[8];
+      y += 4;
+    } while (y < height);
+  } else if (num_taps == 8) {
+    // During the vertical pass the number of taps is restricted when
+    // |height| <= 4.
+    assert(height > 4);
+    srcs[8] = _mm_setzero_si128();
+    // 00 01
+    srcs[0] = Load2(src);
+    src += src_stride;
+    // 00 01 10 11
+    srcs[0] = Load2<1>(src, srcs[0]);
+    src += src_stride;
+    // 00 01 10 11 20 21
+    srcs[0] = Load2<2>(src, srcs[0]);
+    src += src_stride;
+    // 00 01 10 11 20 21 30 31
+    srcs[0] = Load2<3>(src, srcs[0]);
+    src += src_stride;
+    // 40 41
+    srcs[4] = Load2(src);
+    src += src_stride;
+    // 40 41 50 51
+    srcs[4] = Load2<1>(src, srcs[4]);
+    src += src_stride;
+    // 40 41 50 51 60 61
+    srcs[4] = Load2<2>(src, srcs[4]);
+    src += src_stride;
+
+    // 00 01 10 11 20 21 30 31 40 41 50 51 60 61
+    const __m128i srcs_0_4 = _mm_unpacklo_epi64(srcs[0], srcs[4]);
+    // 10 11 20 21 30 31 40 41
+    srcs[1] = _mm_srli_si128(srcs_0_4, 2);
+    // 20 21 30 31 40 41 50 51
+    srcs[2] = _mm_srli_si128(srcs_0_4, 4);
+    // 30 31 40 41 50 51 60 61
+    srcs[3] = _mm_srli_si128(srcs_0_4, 6);
+
+    int y = 0;
+    do {
+      // 40 41 50 51 60 61 70 71
+      srcs[4] = Load2<3>(src, srcs[4]);
+      src += src_stride;
+      // 80 81
+      srcs[8] = Load2<0>(src, srcs[8]);
+      src += src_stride;
+      // 80 81 90 91
+      srcs[8] = Load2<1>(src, srcs[8]);
+      src += src_stride;
+      // 80 81 90 91 a0 a1
+      srcs[8] = Load2<2>(src, srcs[8]);
+      src += src_stride;
+
+      // 40 41 50 51 60 61 70 71 80 81 90 91 a0 a1
+      const __m128i srcs_4_8 = _mm_unpacklo_epi64(srcs[4], srcs[8]);
+      // 50 51 60 61 70 71 80 81
+      srcs[5] = _mm_srli_si128(srcs_4_8, 2);
+      // 60 61 70 71 80 81 90 91
+      srcs[6] = _mm_srli_si128(srcs_4_8, 4);
+      // 70 71 80 81 90 91 a0 a1
+      srcs[7] = _mm_srli_si128(srcs_4_8, 6);
+
+      // This uses srcs[0]..srcs[7].
+      const __m128i sums = SumVerticalTaps<filter_index>(srcs, v_tap);
+      const __m128i results_16 =
+          RightShiftWithRounding_S16(sums, kFilterBits - 1);
+      const __m128i results = _mm_packus_epi16(results_16, results_16);
+
+      Store2(dst8, results);
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 2));
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 4));
+      dst8 += dst_stride;
+      Store2(dst8, _mm_srli_si128(results, 6));
+      dst8 += dst_stride;
+
+      srcs[0] = srcs[4];
+      srcs[1] = srcs[5];
+      srcs[2] = srcs[6];
+      srcs[3] = srcs[7];
+      srcs[4] = srcs[8];
+      y += 4;
+    } while (y < height);
+  }
+}
+
+void ConvolveVertical_SSE4_1(const void* const reference,
+                             const ptrdiff_t reference_stride,
+                             const int /*horizontal_filter_index*/,
+                             const int vertical_filter_index,
+                             const int /*subpixel_x*/, const int subpixel_y,
+                             const int width, const int height,
+                             void* prediction, const ptrdiff_t pred_stride) {
+  const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(filter_index);
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* src = static_cast<const uint8_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride;
+  auto* dest = static_cast<uint8_t*>(prediction);
+  const ptrdiff_t dest_stride = pred_stride;
+  const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
+  assert(filter_id != 0);
+
+  __m128i taps[4];
+  const __m128i v_filter =
+      LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]);
+
+  if (filter_index < 2) {  // 6 tap.
+    SetupTaps<6>(&v_filter, taps);
+    if (width == 2) {
+      FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height, taps);
+    } else if (width == 4) {
+      FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height, taps);
+    } else {
+      FilterVertical<0>(src, src_stride, dest, dest_stride, width, height,
+                        taps);
+    }
+  } else if (filter_index == 2) {  // 8 tap.
+    SetupTaps<8>(&v_filter, taps);
+    if (width == 2) {
+      FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
+    } else if (width == 4) {
+      FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
+    } else {
+      FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
+                        taps);
+    }
+  } else if (filter_index == 3) {  // 2 tap.
+    SetupTaps<2>(&v_filter, taps);
+    if (width == 2) {
+      FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height, taps);
+    } else if (width == 4) {
+      FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height, taps);
+    } else {
+      FilterVertical<3>(src, src_stride, dest, dest_stride, width, height,
+                        taps);
+    }
+  } else if (filter_index == 4) {  // 4 tap.
+    SetupTaps<4>(&v_filter, taps);
+    if (width == 2) {
+      FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height, taps);
+    } else if (width == 4) {
+      FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height, taps);
+    } else {
+      FilterVertical<4>(src, src_stride, dest, dest_stride, width, height,
+                        taps);
+    }
+  } else {
+    // TODO(slavarnway): Investigate adding |filter_index| == 1 special cases.
+    // See convolve_neon.cc
+    SetupTaps<4>(&v_filter, taps);
+
+    if (width == 2) {
+      FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height, taps);
+    } else if (width == 4) {
+      FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height, taps);
+    } else {
+      FilterVertical<5>(src, src_stride, dest, dest_stride, width, height,
+                        taps);
+    }
+  }
+}
+
+void ConvolveCompoundCopy_SSE4(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+    const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
+    const int height, void* prediction, const ptrdiff_t pred_stride) {
+  const auto* src = static_cast<const uint8_t*>(reference);
+  const ptrdiff_t src_stride = reference_stride;
+  auto* dest = static_cast<uint16_t*>(prediction);
+  constexpr int kRoundBitsVertical =
+      kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
+  if (width >= 16) {
+    int y = height;
+    do {
+      int x = 0;
+      do {
+        const __m128i v_src = LoadUnaligned16(&src[x]);
+        const __m128i v_src_ext_lo = _mm_cvtepu8_epi16(v_src);
+        const __m128i v_src_ext_hi =
+            _mm_cvtepu8_epi16(_mm_srli_si128(v_src, 8));
+        const __m128i v_dest_lo =
+            _mm_slli_epi16(v_src_ext_lo, kRoundBitsVertical);
+        const __m128i v_dest_hi =
+            _mm_slli_epi16(v_src_ext_hi, kRoundBitsVertical);
+        // TODO(slavarnway): Investigate using aligned stores.
+        StoreUnaligned16(&dest[x], v_dest_lo);
+        StoreUnaligned16(&dest[x + 8], v_dest_hi);
+        x += 16;
+      } while (x < width);
+      src += src_stride;
+      dest += pred_stride;
+    } while (--y != 0);
+  } else if (width == 8) {
+    int y = height;
+    do {
+      const __m128i v_src = LoadLo8(&src[0]);
+      const __m128i v_src_ext = _mm_cvtepu8_epi16(v_src);
+      const __m128i v_dest = _mm_slli_epi16(v_src_ext, kRoundBitsVertical);
+      StoreUnaligned16(&dest[0], v_dest);
+      src += src_stride;
+      dest += pred_stride;
+    } while (--y != 0);
+  } else { /* width == 4 */
+    int y = height;
+    do {
+      const __m128i v_src0 = Load4(&src[0]);
+      const __m128i v_src1 = Load4(&src[src_stride]);
+      const __m128i v_src = _mm_unpacklo_epi32(v_src0, v_src1);
+      const __m128i v_src_ext = _mm_cvtepu8_epi16(v_src);
+      const __m128i v_dest = _mm_slli_epi16(v_src_ext, kRoundBitsVertical);
+      StoreLo8(&dest[0], v_dest);
+      StoreHi8(&dest[pred_stride], v_dest);
+      src += src_stride * 2;
+      dest += pred_stride * 2;
+      y -= 2;
+    } while (y != 0);
+  }
+}
+
+void ConvolveCompoundVertical_SSE4_1(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int /*horizontal_filter_index*/, const int vertical_filter_index,
+    const int /*subpixel_x*/, const int subpixel_y, const int width,
+    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
+  const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(filter_index);
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* src = static_cast<const uint8_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride;
+  auto* dest = static_cast<uint16_t*>(prediction);
+  const int filter_id = (subpixel_y >> 6) & kSubPixelMask;
+  assert(filter_id != 0);
+
+  __m128i taps[4];
+  const __m128i v_filter =
+      LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]);
+
+  if (filter_index < 2) {  // 6 tap.
+    SetupTaps<6>(&v_filter, taps);
+    if (width == 4) {
+      FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps);
+    } else {
+      FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps);
+    }
+  } else if (filter_index == 2) {  // 8 tap.
+    SetupTaps<8>(&v_filter, taps);
+
+    if (width == 4) {
+      FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps);
+    } else {
+      FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps);
+    }
+  } else if (filter_index == 3) {  // 2 tap.
+    SetupTaps<2>(&v_filter, taps);
+
+    if (width == 4) {
+      FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps);
+    } else {
+      FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps);
+    }
+  } else if (filter_index == 4) {  // 4 tap.
+    SetupTaps<4>(&v_filter, taps);
+
+    if (width == 4) {
+      FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps);
+    } else {
+      FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps);
+    }
+  } else {
+    SetupTaps<4>(&v_filter, taps);
+
+    if (width == 4) {
+      FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps);
+    } else {
+      FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps);
+    }
+  }
+}
+
+void ConvolveHorizontal_SSE4_1(const void* const reference,
+                               const ptrdiff_t reference_stride,
+                               const int horizontal_filter_index,
+                               const int /*vertical_filter_index*/,
+                               const int subpixel_x, const int /*subpixel_y*/,
+                               const int width, const int height,
+                               void* prediction, const ptrdiff_t pred_stride) {
+  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
+  // Set |src| to the outermost tap.
+  const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
+  auto* dest = static_cast<uint8_t*>(prediction);
+
+  DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height,
+                   subpixel_x, filter_index);
+}
+
+void ConvolveCompoundHorizontal_SSE4_1(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int horizontal_filter_index, const int /*vertical_filter_index*/,
+    const int subpixel_x, const int /*subpixel_y*/, const int width,
+    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
+  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
+  auto* dest = static_cast<uint16_t*>(prediction);
+
+  DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>(
+      src, reference_stride, dest, width, width, height, subpixel_x,
+      filter_index);
+}
+
+void ConvolveCompound2D_SSE4_1(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int horizontal_filter_index, const int vertical_filter_index,
+    const int subpixel_x, const int subpixel_y, const int width,
+    const int height, void* prediction, const ptrdiff_t /*pred_stride*/) {
+  // The output of the horizontal filter, i.e. the intermediate_result, is
+  // guaranteed to fit in int16_t.
+  alignas(16) uint16_t
+      intermediate_result[kMaxSuperBlockSizeInPixels *
+                          (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
+
+  // Horizontal filter.
+  // Filter types used for width <= 4 are different from those for width > 4.
+  // When width > 4, the valid filter index range is always [0, 3].
+  // When width <= 4, the valid filter index range is always [4, 5].
+  // Similarly for height.
+  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
+  const int intermediate_height = height + vertical_taps - 1;
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* const src = static_cast<const uint8_t*>(reference) -
+                          (vertical_taps / 2 - 1) * src_stride -
+                          kHorizontalOffset;
+
+  DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>(
+      src, src_stride, intermediate_result, width, width, intermediate_height,
+      subpixel_x, horiz_filter_index);
+
+  // Vertical filter.
+  auto* dest = static_cast<uint16_t*>(prediction);
+  const int filter_id = ((subpixel_y & 1023) >> 6) & kSubPixelMask;
+  assert(filter_id != 0);
+
+  const ptrdiff_t dest_stride = width;
+  __m128i taps[4];
+  const __m128i v_filter =
+      LoadLo8(kHalfSubPixelFilters[vert_filter_index][filter_id]);
+
+  if (vertical_taps == 8) {
+    SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps);
+    if (width == 4) {
+      Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest,
+                                                   dest_stride, height, taps);
+    } else {
+      Filter2DVertical<8, /*is_compound=*/true>(
+          intermediate_result, dest, dest_stride, width, height, taps);
+    }
+  } else if (vertical_taps == 6) {
+    SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps);
+    if (width == 4) {
+      Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest,
+                                                   dest_stride, height, taps);
+    } else {
+      Filter2DVertical<6, /*is_compound=*/true>(
+          intermediate_result, dest, dest_stride, width, height, taps);
+    }
+  } else if (vertical_taps == 4) {
+    SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps);
+    if (width == 4) {
+      Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest,
+                                                   dest_stride, height, taps);
+    } else {
+      Filter2DVertical<4, /*is_compound=*/true>(
+          intermediate_result, dest, dest_stride, width, height, taps);
+    }
+  } else {  // |vertical_taps| == 2
+    SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps);
+    if (width == 4) {
+      Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest,
+                                                   dest_stride, height, taps);
+    } else {
+      Filter2DVertical<2, /*is_compound=*/true>(
+          intermediate_result, dest, dest_stride, width, height, taps);
+    }
+  }
+}
+
+// Pre-transposed filters.
+template <int filter_index>
+inline void GetHalfSubPixelFilter(__m128i* output) {
+  // Filter 0
+  alignas(
+      16) static constexpr int8_t kHalfSubPixel6TapSignedFilterColumns[6][16] =
+      {{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0},
+       {0, -3, -5, -6, -7, -7, -8, -7, -7, -6, -6, -6, -5, -4, -2, -1},
+       {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
+       {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
+       {0, -1, -2, -4, -5, -6, -6, -6, -7, -7, -8, -7, -7, -6, -5, -3},
+       {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
+  // Filter 1
+  alignas(16) static constexpr int8_t
+      kHalfSubPixel6TapMixedSignedFilterColumns[6][16] = {
+          {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0},
+          {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1},
+          {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
+          {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
+          {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14},
+          {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}};
+  // Filter 2
+  alignas(
+      16) static constexpr int8_t kHalfSubPixel8TapSignedFilterColumns[8][16] =
+      {{0, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, 0},
+       {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1},
+       {0, -3, -6, -9, -11, -11, -12, -12, -12, -11, -10, -9, -7, -5, -3, -1},
+       {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4},
+       {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63},
+       {0, -1, -3, -5, -7, -9, -10, -11, -12, -12, -12, -11, -11, -9, -6, -3},
+       {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1},
+       {0, 0, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1}};
+  // Filter 3
+  alignas(16) static constexpr uint8_t kHalfSubPixel2TapFilterColumns[2][16] = {
+      {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4},
+      {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}};
+  // Filter 4
+  alignas(
+      16) static constexpr int8_t kHalfSubPixel4TapSignedFilterColumns[4][16] =
+      {{0, -2, -4, -5, -6, -6, -7, -6, -6, -5, -5, -5, -4, -3, -2, -1},
+       {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
+       {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
+       {0, -1, -2, -3, -4, -5, -5, -5, -6, -6, -7, -6, -6, -5, -4, -2}};
+  // Filter 5
+  alignas(
+      16) static constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
+      {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1},
+      {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
+      {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
+      {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}};
+  switch (filter_index) {
+    case 0:
+      output[0] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[0]);
+      output[1] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[1]);
+      output[2] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[2]);
+      output[3] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[3]);
+      output[4] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[4]);
+      output[5] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[5]);
+      break;
+    case 1:
+      // The term "mixed" refers to the fact that the outer taps have a mix of
+      // negative and positive values.
+      output[0] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[0]);
+      output[1] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[1]);
+      output[2] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[2]);
+      output[3] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[3]);
+      output[4] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[4]);
+      output[5] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[5]);
+      break;
+    case 2:
+      output[0] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[0]);
+      output[1] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[1]);
+      output[2] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[2]);
+      output[3] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[3]);
+      output[4] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[4]);
+      output[5] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[5]);
+      output[6] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[6]);
+      output[7] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[7]);
+      break;
+    case 3:
+      output[0] = LoadAligned16(kHalfSubPixel2TapFilterColumns[0]);
+      output[1] = LoadAligned16(kHalfSubPixel2TapFilterColumns[1]);
+      break;
+    case 4:
+      output[0] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[0]);
+      output[1] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[1]);
+      output[2] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[2]);
+      output[3] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[3]);
+      break;
+    default:
+      assert(filter_index == 5);
+      output[0] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[0]);
+      output[1] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[1]);
+      output[2] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[2]);
+      output[3] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[3]);
+      break;
+  }
+}
+
+// There are many opportunities for overreading in scaled convolve, because
+// the range of starting points for filter windows is anywhere from 0 to 16
+// for 8 destination pixels, and the window sizes range from 2 to 8. To
+// accommodate this range concisely, we use |grade_x| to mean the most steps
+// in src that can be traversed in a single |step_x| increment, i.e. 1 or 2.
+// More importantly, |grade_x| answers the question "how many vector loads are
+// needed to cover the source values?"
+// When |grade_x| == 1, the maximum number of source values needed is 8 separate
+// starting positions plus 7 more to cover taps, all fitting into 16 bytes.
+// When |grade_x| > 1, we are guaranteed to exceed 8 whole steps in src for
+// every 8 |step_x| increments, on top of 8 possible taps. The first load covers
+// the starting sources for each kernel, while the final load covers the taps.
+// Since the offset value of src_x cannot exceed 8 and |num_taps| does not
+// exceed 4 when width <= 4, |grade_x| is set to 1 regardless of the value of
+// |step_x|.
+template <int num_taps, int grade_x>
+inline void PrepareSourceVectors(const uint8_t* src, const __m128i src_indices,
+                                 __m128i* const source /*[num_taps >> 1]*/) {
+  const __m128i src_vals = LoadUnaligned16(src);
+  source[0] = _mm_shuffle_epi8(src_vals, src_indices);
+  if (grade_x == 1) {
+    if (num_taps > 2) {
+      source[1] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 2), src_indices);
+    }
+    if (num_taps > 4) {
+      source[2] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 4), src_indices);
+    }
+    if (num_taps > 6) {
+      source[3] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 6), src_indices);
+    }
+  } else {
+    assert(grade_x > 1);
+    assert(num_taps != 4);
+    // grade_x > 1 also means width >= 8 && num_taps != 4
+    const __m128i src_vals_ext = LoadLo8(src + 16);
+    if (num_taps > 2) {
+      source[1] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 2),
+                                   src_indices);
+      source[2] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 4),
+                                   src_indices);
+    }
+    if (num_taps > 6) {
+      source[3] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 6),
+                                   src_indices);
+    }
+  }
+}
+
+template <int num_taps>
+inline void PrepareHorizontalTaps(const __m128i subpel_indices,
+                                  const __m128i* filter_taps,
+                                  __m128i* out_taps) {
+  const __m128i scale_index_offsets =
+      _mm_srli_epi16(subpel_indices, kFilterIndexShift);
+  const __m128i filter_index_mask = _mm_set1_epi8(kSubPixelMask);
+  const __m128i filter_indices =
+      _mm_and_si128(_mm_packus_epi16(scale_index_offsets, scale_index_offsets),
+                    filter_index_mask);
+  // Line up taps for maddubs_epi16.
+  // The unpack is also assumed to be lighter than shift+alignr.
+  for (int k = 0; k < (num_taps >> 1); ++k) {
+    const __m128i taps0 = _mm_shuffle_epi8(filter_taps[2 * k], filter_indices);
+    const __m128i taps1 =
+        _mm_shuffle_epi8(filter_taps[2 * k + 1], filter_indices);
+    out_taps[k] = _mm_unpacklo_epi8(taps0, taps1);
+  }
+}
+
+inline __m128i HorizontalScaleIndices(const __m128i subpel_indices) {
+  const __m128i src_indices16 =
+      _mm_srli_epi16(subpel_indices, kScaleSubPixelBits);
+  const __m128i src_indices = _mm_packus_epi16(src_indices16, src_indices16);
+  return _mm_unpacklo_epi8(src_indices,
+                           _mm_add_epi8(src_indices, _mm_set1_epi8(1)));
+}
+
+template <int grade_x, int filter_index, int num_taps>
+inline void ConvolveHorizontalScale(const uint8_t* src, ptrdiff_t src_stride,
+                                    int width, int subpixel_x, int step_x,
+                                    int intermediate_height,
+                                    int16_t* intermediate) {
+  // Account for the 0-taps that precede the 2 nonzero taps.
+  const int kernel_offset = (8 - num_taps) >> 1;
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const int step_x8 = step_x << 3;
+  __m128i filter_taps[num_taps];
+  GetHalfSubPixelFilter<filter_index>(filter_taps);
+  const __m128i index_steps =
+      _mm_mullo_epi16(_mm_set_epi16(7, 6, 5, 4, 3, 2, 1, 0),
+                      _mm_set1_epi16(static_cast<int16_t>(step_x)));
+
+  __m128i taps[num_taps >> 1];
+  __m128i source[num_taps >> 1];
+  int p = subpixel_x;
+  // Case when width <= 4 is possible.
+  if (filter_index >= 3) {
+    if (filter_index > 3 || width <= 4) {
+      const uint8_t* src_x =
+          &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+      // Only add steps to the 10-bit truncated p to avoid overflow.
+      const __m128i p_fraction = _mm_set1_epi16(p & 1023);
+      const __m128i subpel_indices = _mm_add_epi16(index_steps, p_fraction);
+      PrepareHorizontalTaps<num_taps>(subpel_indices, filter_taps, taps);
+      const __m128i packed_indices = HorizontalScaleIndices(subpel_indices);
+
+      int y = intermediate_height;
+      do {
+        // Load and line up source values with the taps. Width 4 means no need
+        // to load extended source.
+        PrepareSourceVectors<num_taps, /*grade_x=*/1>(src_x, packed_indices,
+                                                      source);
+
+        StoreLo8(intermediate, RightShiftWithRounding_S16(
+                                   SumOnePassTaps<filter_index>(source, taps),
+                                   kInterRoundBitsHorizontal - 1));
+        src_x += src_stride;
+        intermediate += kIntermediateStride;
+      } while (--y != 0);
+      return;
+    }
+  }
+
+  // |width| >= 8
+  int x = 0;
+  do {
+    const uint8_t* src_x =
+        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
+    int16_t* intermediate_x = intermediate + x;
+    // Only add steps to the 10-bit truncated p to avoid overflow.
+    const __m128i p_fraction = _mm_set1_epi16(p & 1023);
+    const __m128i subpel_indices = _mm_add_epi16(index_steps, p_fraction);
+    PrepareHorizontalTaps<num_taps>(subpel_indices, filter_taps, taps);
+    const __m128i packed_indices = HorizontalScaleIndices(subpel_indices);
+
+    int y = intermediate_height;
+    do {
+      // For each x, a lane of src_k[k] contains src_x[k].
+      PrepareSourceVectors<num_taps, grade_x>(src_x, packed_indices, source);
+
+      // Shift by one less because the taps are halved.
+      StoreAligned16(
+          intermediate_x,
+          RightShiftWithRounding_S16(SumOnePassTaps<filter_index>(source, taps),
+                                     kInterRoundBitsHorizontal - 1));
+      src_x += src_stride;
+      intermediate_x += kIntermediateStride;
+    } while (--y != 0);
+    x += 8;
+    p += step_x8;
+  } while (x < width);
+}
+
+template <int num_taps>
+inline void PrepareVerticalTaps(const int8_t* taps, __m128i* output) {
+  // Avoid overreading the filter due to starting at kernel_offset.
+  // The only danger of overread is in the final filter, which has 4 taps.
+  const __m128i filter =
+      _mm_cvtepi8_epi16((num_taps > 4) ? LoadLo8(taps) : Load4(taps));
+  output[0] = _mm_shuffle_epi32(filter, 0);
+  if (num_taps > 2) {
+    output[1] = _mm_shuffle_epi32(filter, 0x55);
+  }
+  if (num_taps > 4) {
+    output[2] = _mm_shuffle_epi32(filter, 0xAA);
+  }
+  if (num_taps > 6) {
+    output[3] = _mm_shuffle_epi32(filter, 0xFF);
+  }
+}
+
+// Process eight 16 bit inputs and output eight 16 bit values.
+template <int num_taps, bool is_compound>
+inline __m128i Sum2DVerticalTaps(const __m128i* const src,
+                                 const __m128i* taps) {
+  const __m128i src_lo_01 = _mm_unpacklo_epi16(src[0], src[1]);
+  __m128i sum_lo = _mm_madd_epi16(src_lo_01, taps[0]);
+  const __m128i src_hi_01 = _mm_unpackhi_epi16(src[0], src[1]);
+  __m128i sum_hi = _mm_madd_epi16(src_hi_01, taps[0]);
+  if (num_taps > 2) {
+    const __m128i src_lo_23 = _mm_unpacklo_epi16(src[2], src[3]);
+    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_23, taps[1]));
+    const __m128i src_hi_23 = _mm_unpackhi_epi16(src[2], src[3]);
+    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_23, taps[1]));
+  }
+  if (num_taps > 4) {
+    const __m128i src_lo_45 = _mm_unpacklo_epi16(src[4], src[5]);
+    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_45, taps[2]));
+    const __m128i src_hi_45 = _mm_unpackhi_epi16(src[4], src[5]);
+    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_45, taps[2]));
+  }
+  if (num_taps > 6) {
+    const __m128i src_lo_67 = _mm_unpacklo_epi16(src[6], src[7]);
+    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_67, taps[3]));
+    const __m128i src_hi_67 = _mm_unpackhi_epi16(src[6], src[7]);
+    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_67, taps[3]));
+  }
+  if (is_compound) {
+    return _mm_packs_epi32(
+        RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1),
+        RightShiftWithRounding_S32(sum_hi,
+                                   kInterRoundBitsCompoundVertical - 1));
+  }
+  return _mm_packs_epi32(
+      RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1),
+      RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1));
+}
+
+// Bottom half of each src[k] is the source for one filter, and the top half
+// is the source for the other filter, for the next destination row.
+template <int num_taps, bool is_compound>
+__m128i Sum2DVerticalTaps4x2(const __m128i* const src, const __m128i* taps_lo,
+                             const __m128i* taps_hi) {
+  const __m128i src_lo_01 = _mm_unpacklo_epi16(src[0], src[1]);
+  __m128i sum_lo = _mm_madd_epi16(src_lo_01, taps_lo[0]);
+  const __m128i src_hi_01 = _mm_unpackhi_epi16(src[0], src[1]);
+  __m128i sum_hi = _mm_madd_epi16(src_hi_01, taps_hi[0]);
+  if (num_taps > 2) {
+    const __m128i src_lo_23 = _mm_unpacklo_epi16(src[2], src[3]);
+    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_23, taps_lo[1]));
+    const __m128i src_hi_23 = _mm_unpackhi_epi16(src[2], src[3]);
+    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_23, taps_hi[1]));
+  }
+  if (num_taps > 4) {
+    const __m128i src_lo_45 = _mm_unpacklo_epi16(src[4], src[5]);
+    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_45, taps_lo[2]));
+    const __m128i src_hi_45 = _mm_unpackhi_epi16(src[4], src[5]);
+    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_45, taps_hi[2]));
+  }
+  if (num_taps > 6) {
+    const __m128i src_lo_67 = _mm_unpacklo_epi16(src[6], src[7]);
+    sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_67, taps_lo[3]));
+    const __m128i src_hi_67 = _mm_unpackhi_epi16(src[6], src[7]);
+    sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_67, taps_hi[3]));
+  }
+
+  if (is_compound) {
+    return _mm_packs_epi32(
+        RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1),
+        RightShiftWithRounding_S32(sum_hi,
+                                   kInterRoundBitsCompoundVertical - 1));
+  }
+  return _mm_packs_epi32(
+      RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1),
+      RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1));
+}
+
+// |width_class| is 2, 4, or 8, according to the Store function that should be
+// used.
+template <int num_taps, int width_class, bool is_compound>
+#if LIBGAV1_MSAN
+__attribute__((no_sanitize_memory)) void ConvolveVerticalScale(
+#else
+inline void ConvolveVerticalScale(
+#endif
+    const int16_t* src, const int width, const int subpixel_y,
+    const int filter_index, const int step_y, const int height, void* dest,
+    const ptrdiff_t dest_stride) {
+  constexpr ptrdiff_t src_stride = kIntermediateStride;
+  constexpr int kernel_offset = (8 - num_taps) / 2;
+  const int16_t* src_y = src;
+  // |dest| is 16-bit in compound mode, Pixel otherwise.
+  auto* dest16_y = static_cast<uint16_t*>(dest);
+  auto* dest_y = static_cast<uint8_t*>(dest);
+  __m128i s[num_taps];
+
+  int p = subpixel_y & 1023;
+  int y = height;
+  if (width_class <= 4) {
+    __m128i filter_taps_lo[num_taps >> 1];
+    __m128i filter_taps_hi[num_taps >> 1];
+    do {  // y > 0
+      for (int i = 0; i < num_taps; ++i) {
+        s[i] = LoadLo8(src_y + i * src_stride);
+      }
+      int filter_id = (p >> 6) & kSubPixelMask;
+      const int8_t* filter0 =
+          kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
+      PrepareVerticalTaps<num_taps>(filter0, filter_taps_lo);
+      p += step_y;
+      src_y = src + (p >> kScaleSubPixelBits) * src_stride;
+
+      for (int i = 0; i < num_taps; ++i) {
+        s[i] = LoadHi8(s[i], src_y + i * src_stride);
+      }
+      filter_id = (p >> 6) & kSubPixelMask;
+      const int8_t* filter1 =
+          kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
+      PrepareVerticalTaps<num_taps>(filter1, filter_taps_hi);
+      p += step_y;
+      src_y = src + (p >> kScaleSubPixelBits) * src_stride;
+
+      const __m128i sums = Sum2DVerticalTaps4x2<num_taps, is_compound>(
+          s, filter_taps_lo, filter_taps_hi);
+      if (is_compound) {
+        assert(width_class > 2);
+        StoreLo8(dest16_y, sums);
+        dest16_y += dest_stride;
+        StoreHi8(dest16_y, sums);
+        dest16_y += dest_stride;
+      } else {
+        const __m128i result = _mm_packus_epi16(sums, sums);
+        if (width_class == 2) {
+          Store2(dest_y, result);
+          dest_y += dest_stride;
+          Store2(dest_y, _mm_srli_si128(result, 4));
+        } else {
+          Store4(dest_y, result);
+          dest_y += dest_stride;
+          Store4(dest_y, _mm_srli_si128(result, 4));
+        }
+        dest_y += dest_stride;
+      }
+      y -= 2;
+    } while (y != 0);
+    return;
+  }
+
+  // |width_class| >= 8
+  __m128i filter_taps[num_taps >> 1];
+  do {  // y > 0
+    src_y = src + (p >> kScaleSubPixelBits) * src_stride;
+    const int filter_id = (p >> 6) & kSubPixelMask;
+    const int8_t* filter =
+        kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
+    PrepareVerticalTaps<num_taps>(filter, filter_taps);
+
+    int x = 0;
+    do {  // x < width
+      for (int i = 0; i < num_taps; ++i) {
+        s[i] = LoadUnaligned16(src_y + i * src_stride);
+      }
+
+      const __m128i sums =
+          Sum2DVerticalTaps<num_taps, is_compound>(s, filter_taps);
+      if (is_compound) {
+        StoreUnaligned16(dest16_y + x, sums);
+      } else {
+        StoreLo8(dest_y + x, _mm_packus_epi16(sums, sums));
+      }
+      x += 8;
+      src_y += 8;
+    } while (x < width);
+    p += step_y;
+    dest_y += dest_stride;
+    dest16_y += dest_stride;
+  } while (--y != 0);
+}
+
+template <bool is_compound>
+void ConvolveScale2D_SSE4_1(const void* const reference,
+                            const ptrdiff_t reference_stride,
+                            const int horizontal_filter_index,
+                            const int vertical_filter_index,
+                            const int subpixel_x, const int subpixel_y,
+                            const int step_x, const int step_y, const int width,
+                            const int height, void* prediction,
+                            const ptrdiff_t pred_stride) {
+  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+  assert(step_x <= 2048);
+  // The output of the horizontal filter, i.e. the intermediate_result, is
+  // guaranteed to fit in int16_t.
+  // TODO(petersonab): Reduce intermediate block stride to width to make smaller
+  // blocks faster.
+  alignas(16) int16_t
+      intermediate_result[kMaxSuperBlockSizeInPixels *
+                          (2 * kMaxSuperBlockSizeInPixels + kSubPixelTaps)];
+  const int num_vert_taps = GetNumTapsInFilter(vert_filter_index);
+  const int intermediate_height =
+      (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
+       kScaleSubPixelBits) +
+      num_vert_taps;
+
+  // Horizontal filter.
+  // Filter types used for width <= 4 are different from those for width > 4.
+  // When width > 4, the valid filter index range is always [0, 3].
+  // When width <= 4, the valid filter index range is always [3, 5].
+  // Similarly for height.
+  int16_t* intermediate = intermediate_result;
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* src = static_cast<const uint8_t*>(reference);
+  const int vert_kernel_offset = (8 - num_vert_taps) / 2;
+  src += vert_kernel_offset * src_stride;
+
+  // Derive the maximum value of |step_x| at which all source values fit in one
+  // 16-byte load. Final index is src_x + |num_taps| - 1 < 16
+  // step_x*7 is the final base sub-pixel index for the shuffle mask for filter
+  // inputs in each iteration on large blocks. When step_x is large, we need a
+  // second register and alignr in order to gather all filter inputs.
+  // |num_taps| - 1 is the offset for the shuffle of inputs to the final tap.
+  const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index);
+  const int kernel_start_ceiling = 16 - num_horiz_taps;
+  // This truncated quotient |grade_x_threshold| selects |step_x| such that:
+  // (step_x * 7) >> kScaleSubPixelBits < single load limit
+  const int grade_x_threshold =
+      (kernel_start_ceiling << kScaleSubPixelBits) / 7;
+  switch (horiz_filter_index) {
+    case 0:
+      if (step_x > grade_x_threshold) {
+        ConvolveHorizontalScale<2, 0, 6>(src, src_stride, width, subpixel_x,
+                                         step_x, intermediate_height,
+                                         intermediate);
+      } else {
+        ConvolveHorizontalScale<1, 0, 6>(src, src_stride, width, subpixel_x,
+                                         step_x, intermediate_height,
+                                         intermediate);
+      }
+      break;
+    case 1:
+      if (step_x > grade_x_threshold) {
+        ConvolveHorizontalScale<2, 1, 6>(src, src_stride, width, subpixel_x,
+                                         step_x, intermediate_height,
+                                         intermediate);
+
+      } else {
+        ConvolveHorizontalScale<1, 1, 6>(src, src_stride, width, subpixel_x,
+                                         step_x, intermediate_height,
+                                         intermediate);
+      }
+      break;
+    case 2:
+      if (step_x > grade_x_threshold) {
+        ConvolveHorizontalScale<2, 2, 8>(src, src_stride, width, subpixel_x,
+                                         step_x, intermediate_height,
+                                         intermediate);
+      } else {
+        ConvolveHorizontalScale<1, 2, 8>(src, src_stride, width, subpixel_x,
+                                         step_x, intermediate_height,
+                                         intermediate);
+      }
+      break;
+    case 3:
+      if (step_x > grade_x_threshold) {
+        ConvolveHorizontalScale<2, 3, 2>(src, src_stride, width, subpixel_x,
+                                         step_x, intermediate_height,
+                                         intermediate);
+      } else {
+        ConvolveHorizontalScale<1, 3, 2>(src, src_stride, width, subpixel_x,
+                                         step_x, intermediate_height,
+                                         intermediate);
+      }
+      break;
+    case 4:
+      assert(width <= 4);
+      ConvolveHorizontalScale<1, 4, 4>(src, src_stride, width, subpixel_x,
+                                       step_x, intermediate_height,
+                                       intermediate);
+      break;
+    default:
+      assert(horiz_filter_index == 5);
+      assert(width <= 4);
+      ConvolveHorizontalScale<1, 5, 4>(src, src_stride, width, subpixel_x,
+                                       step_x, intermediate_height,
+                                       intermediate);
+  }
+
+  // Vertical filter.
+  intermediate = intermediate_result;
+  switch (vert_filter_index) {
+    case 0:
+    case 1:
+      if (!is_compound && width == 2) {
+        ConvolveVerticalScale<6, 2, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      } else if (width == 4) {
+        ConvolveVerticalScale<6, 4, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      } else {
+        ConvolveVerticalScale<6, 8, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      }
+      break;
+    case 2:
+      if (!is_compound && width == 2) {
+        ConvolveVerticalScale<8, 2, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      } else if (width == 4) {
+        ConvolveVerticalScale<8, 4, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      } else {
+        ConvolveVerticalScale<8, 8, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      }
+      break;
+    case 3:
+      if (!is_compound && width == 2) {
+        ConvolveVerticalScale<2, 2, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      } else if (width == 4) {
+        ConvolveVerticalScale<2, 4, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      } else {
+        ConvolveVerticalScale<2, 8, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      }
+      break;
+    default:
+      assert(vert_filter_index == 4 || vert_filter_index == 5);
+      if (!is_compound && width == 2) {
+        ConvolveVerticalScale<4, 2, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      } else if (width == 4) {
+        ConvolveVerticalScale<4, 4, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      } else {
+        ConvolveVerticalScale<4, 8, is_compound>(
+            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
+            prediction, pred_stride);
+      }
+  }
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->convolve[0][0][0][1] = ConvolveHorizontal_SSE4_1;
+  dsp->convolve[0][0][1][0] = ConvolveVertical_SSE4_1;
+  dsp->convolve[0][0][1][1] = Convolve2D_SSE4_1;
+
+  dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_SSE4;
+  dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_SSE4_1;
+  dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_SSE4_1;
+  dsp->convolve[0][1][1][1] = ConvolveCompound2D_SSE4_1;
+
+  dsp->convolve_scale[0] = ConvolveScale2D_SSE4_1<false>;
+  dsp->convolve_scale[1] = ConvolveScale2D_SSE4_1<true>;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void ConvolveInit_SSE4_1() { low_bitdepth::Init8bpp(); }
 
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_SSE4_1
+#else  // !LIBGAV1_ENABLE_SSE4_1
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/x86/convolve_sse4.h b/libgav1/src/dsp/x86/convolve_sse4.h
index 33b74c9..e449a87 100644
--- a/libgav1/src/dsp/x86/convolve_sse4.h
+++ b/libgav1/src/dsp/x86/convolve_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_CONVOLVE_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_CONVOLVE_SSE4_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -34,8 +34,40 @@
 // optimization being enabled, signal the sse4 implementation should be used.
 #if LIBGAV1_ENABLE_SSE4_1
 
+#ifndef LIBGAV1_Dsp8bpp_ConvolveHorizontal
+#define LIBGAV1_Dsp8bpp_ConvolveHorizontal LIBGAV1_CPU_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveVertical
+#define LIBGAV1_Dsp8bpp_ConvolveVertical LIBGAV1_CPU_SSE4_1
+#endif
+
 #ifndef LIBGAV1_Dsp8bpp_Convolve2D
-// #define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_Convolve2D LIBGAV1_CPU_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundCopy
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundCopy LIBGAV1_CPU_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundVertical
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundVertical LIBGAV1_CPU_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompound2D
+#define LIBGAV1_Dsp8bpp_ConvolveCompound2D LIBGAV1_CPU_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveScale2D
+#define LIBGAV1_Dsp8bpp_ConvolveScale2D LIBGAV1_CPU_SSE4_1
+#endif
+
+#ifndef LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D
+#define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_CPU_SSE4_1
 #endif
 
 #endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc
index c8106bb..77517ee 100644
--- a/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc
+++ b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "src/dsp/distance_weighted_blend.h"
-#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -23,6 +23,8 @@
 #include <cstddef>
 #include <cstdint>
 
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/utils/common.h"
 
@@ -30,23 +32,19 @@
 namespace dsp {
 namespace {
 
-constexpr int kBitdepth8 = 8;
 constexpr int kInterPostRoundBit = 4;
 
 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
                                        const __m128i& pred1,
                                        const __m128i& weights) {
-  const __m128i compound_round_offset32 =
-      _mm_set1_epi32((16 << (kBitdepth8 + 4)) + (16 << (kBitdepth8 + 3)));
+  // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
   const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
-  const __m128i mult_lo =
-      _mm_sub_epi32(_mm_madd_epi16(preds_lo, weights), compound_round_offset32);
+  const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights);
   const __m128i result_lo =
       RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
 
   const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
-  const __m128i mult_hi =
-      _mm_sub_epi32(_mm_madd_epi16(preds_hi, weights), compound_round_offset32);
+  const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights);
   const __m128i result_hi =
       RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
 
@@ -55,34 +53,31 @@
 
 template <int height>
 inline void DistanceWeightedBlend4xH_SSE4_1(
-    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
-    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
-    const uint8_t weight_0, const uint8_t weight_1, void* const dest,
-    const ptrdiff_t dest_stride) {
+    const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
-  const uint16_t* pred_0 = prediction_0;
-  const uint16_t* pred_1 = prediction_1;
   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
 
   for (int y = 0; y < height; y += 4) {
+    // TODO(b/150326556): Use larger loads.
     const __m128i src_00 = LoadLo8(pred_0);
     const __m128i src_10 = LoadLo8(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+    pred_0 += 4;
+    pred_1 += 4;
     __m128i src_0 = LoadHi8(src_00, pred_0);
     __m128i src_1 = LoadHi8(src_10, pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+    pred_0 += 4;
+    pred_1 += 4;
     const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
 
     const __m128i src_01 = LoadLo8(pred_0);
     const __m128i src_11 = LoadLo8(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+    pred_0 += 4;
+    pred_1 += 4;
     src_0 = LoadHi8(src_01, pred_0);
     src_1 = LoadHi8(src_11, pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+    pred_0 += 4;
+    pred_1 += 4;
     const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
 
     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
@@ -102,26 +97,22 @@
 
 template <int height>
 inline void DistanceWeightedBlend8xH_SSE4_1(
-    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
-    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
-    const uint8_t weight_0, const uint8_t weight_1, void* const dest,
-    const ptrdiff_t dest_stride) {
+    const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
-  const uint16_t* pred_0 = prediction_0;
-  const uint16_t* pred_1 = prediction_1;
   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
 
   for (int y = 0; y < height; y += 2) {
-    const __m128i src_00 = LoadUnaligned16(pred_0);
-    const __m128i src_10 = LoadUnaligned16(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+    const __m128i src_00 = LoadAligned16(pred_0);
+    const __m128i src_10 = LoadAligned16(pred_1);
+    pred_0 += 8;
+    pred_1 += 8;
     const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
 
-    const __m128i src_01 = LoadUnaligned16(pred_0);
-    const __m128i src_11 = LoadUnaligned16(pred_1);
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+    const __m128i src_01 = LoadAligned16(pred_0);
+    const __m128i src_11 = LoadAligned16(pred_1);
+    pred_0 += 8;
+    pred_1 += 8;
     const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
 
     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
@@ -133,26 +124,23 @@
 }
 
 inline void DistanceWeightedBlendLarge_SSE4_1(
-    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
-    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
-    const uint8_t weight_0, const uint8_t weight_1, const int width,
-    const int height, void* const dest, const ptrdiff_t dest_stride) {
+    const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, const int width, const int height, void* const dest,
+    const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
-  const uint16_t* pred_0 = prediction_0;
-  const uint16_t* pred_1 = prediction_1;
   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
 
   int y = height;
   do {
     int x = 0;
     do {
-      const __m128i src_0_lo = LoadUnaligned16(pred_0 + x);
-      const __m128i src_1_lo = LoadUnaligned16(pred_1 + x);
+      const __m128i src_0_lo = LoadAligned16(pred_0 + x);
+      const __m128i src_1_lo = LoadAligned16(pred_1 + x);
       const __m128i res_lo =
           ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
 
-      const __m128i src_0_hi = LoadUnaligned16(pred_0 + x + 8);
-      const __m128i src_1_hi = LoadUnaligned16(pred_1 + x + 8);
+      const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
+      const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
       const __m128i res_hi =
           ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
 
@@ -160,30 +148,30 @@
       x += 16;
     } while (x < width);
     dst += dest_stride;
-    pred_0 += prediction_stride_0;
-    pred_1 += prediction_stride_1;
+    pred_0 += width;
+    pred_1 += width;
   } while (--y != 0);
 }
 
-void DistanceWeightedBlend_SSE4_1(
-    const uint16_t* prediction_0, const ptrdiff_t prediction_stride_0,
-    const uint16_t* prediction_1, const ptrdiff_t prediction_stride_1,
-    const uint8_t weight_0, const uint8_t weight_1, const int width,
-    const int height, void* const dest, const ptrdiff_t dest_stride) {
+void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
+                                  const void* prediction_1,
+                                  const uint8_t weight_0,
+                                  const uint8_t weight_1, const int width,
+                                  const int height, void* const dest,
+                                  const ptrdiff_t dest_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   if (width == 4) {
     if (height == 4) {
-      DistanceWeightedBlend4xH_SSE4_1<4>(prediction_0, prediction_stride_0,
-                                         prediction_1, prediction_stride_1,
-                                         weight_0, weight_1, dest, dest_stride);
+      DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
+                                         dest, dest_stride);
     } else if (height == 8) {
-      DistanceWeightedBlend4xH_SSE4_1<8>(prediction_0, prediction_stride_0,
-                                         prediction_1, prediction_stride_1,
-                                         weight_0, weight_1, dest, dest_stride);
+      DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
+                                         dest, dest_stride);
     } else {
       assert(height == 16);
-      DistanceWeightedBlend4xH_SSE4_1<16>(
-          prediction_0, prediction_stride_0, prediction_1, prediction_stride_1,
-          weight_0, weight_1, dest, dest_stride);
+      DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
+                                          dest, dest_stride);
     }
     return;
   }
@@ -191,37 +179,32 @@
   if (width == 8) {
     switch (height) {
       case 4:
-        DistanceWeightedBlend8xH_SSE4_1<4>(
-            prediction_0, prediction_stride_0, prediction_1,
-            prediction_stride_1, weight_0, weight_1, dest, dest_stride);
+        DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
+                                           dest, dest_stride);
         return;
       case 8:
-        DistanceWeightedBlend8xH_SSE4_1<8>(
-            prediction_0, prediction_stride_0, prediction_1,
-            prediction_stride_1, weight_0, weight_1, dest, dest_stride);
+        DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
+                                           dest, dest_stride);
         return;
       case 16:
-        DistanceWeightedBlend8xH_SSE4_1<16>(
-            prediction_0, prediction_stride_0, prediction_1,
-            prediction_stride_1, weight_0, weight_1, dest, dest_stride);
+        DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
+                                            dest, dest_stride);
         return;
       default:
         assert(height == 32);
-        DistanceWeightedBlend8xH_SSE4_1<32>(
-            prediction_0, prediction_stride_0, prediction_1,
-            prediction_stride_1, weight_0, weight_1, dest, dest_stride);
+        DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
+                                            dest, dest_stride);
 
         return;
     }
   }
 
-  DistanceWeightedBlendLarge_SSE4_1(prediction_0, prediction_stride_0,
-                                    prediction_1, prediction_stride_1, weight_0,
-                                    weight_1, width, height, dest, dest_stride);
+  DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
+                                    height, dest, dest_stride);
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
 #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
   dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
@@ -235,7 +218,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_SSE4_1
+#else  // !LIBGAV1_ENABLE_SSE4_1
 
 namespace libgav1 {
 namespace dsp {
diff --git a/libgav1/src/dsp/x86/distance_weighted_blend_sse4.h b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.h
index 34ed7a4..2831ded 100644
--- a/libgav1/src/dsp/x86/distance_weighted_blend_sse4.h
+++ b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_DISTANCE_WEIGHTED_BLEND_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_DISTANCE_WEIGHTED_BLEND_SSE4_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -33,7 +33,7 @@
 // optimization being enabled, signal the sse4 implementation should be used.
 #if LIBGAV1_ENABLE_SSE4_1
 #ifndef LIBGAV1_Dsp8bpp_DistanceWeightedBlend
-#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_DistanceWeightedBlend LIBGAV1_CPU_SSE4_1
 #endif
 
 #endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/intra_edge_sse4.cc b/libgav1/src/dsp/x86/intra_edge_sse4.cc
index 62b2bcd..3635ee1 100644
--- a/libgav1/src/dsp/x86/intra_edge_sse4.cc
+++ b/libgav1/src/dsp/x86/intra_edge_sse4.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intra_edge.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -25,6 +25,7 @@
 #include <cstring>  // memcpy
 
 #include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/utils/common.h"
 
@@ -117,7 +118,7 @@
 inline void ComputeKernel3Store8(uint8_t* dest, const uint8_t* source) {
   const __m128i edge_lo = LoadUnaligned16(source);
   const __m128i edge_hi = _mm_srli_si128(edge_lo, 4);
-  // Finish |edge_lo| lifecycle quickly.
+  // Finish |edge_lo| life cycle quickly.
   // Multiply for 2x.
   const __m128i source2_lo = _mm_slli_epi16(_mm_cvtepu8_epi16(edge_lo), 1);
   // Multiply 2x by 2 and align.
@@ -241,7 +242,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
 #if DSP_ENABLED_8BPP_SSE4_1(IntraEdgeFilter)
   dsp->intra_edge_filter = IntraEdgeFilter_SSE4_1;
@@ -258,7 +259,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_SSE4_1
+#else  // !LIBGAV1_ENABLE_SSE4_1
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/x86/intra_edge_sse4.h b/libgav1/src/dsp/x86/intra_edge_sse4.h
index f5b4166..d6c926e 100644
--- a/libgav1/src/dsp/x86/intra_edge_sse4.h
+++ b/libgav1/src/dsp/x86/intra_edge_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_INTRA_EDGE_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_INTRA_EDGE_SSE4_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -34,11 +34,11 @@
 // optimization being enabled, signal the sse4 implementation should be used.
 #if LIBGAV1_ENABLE_SSE4_1
 #ifndef LIBGAV1_Dsp8bpp_IntraEdgeFilter
-#define LIBGAV1_Dsp8bpp_IntraEdgeFilter LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_IntraEdgeFilter LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_IntraEdgeUpsampler
-#define LIBGAV1_Dsp8bpp_IntraEdgeUpsampler LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_IntraEdgeUpsampler LIBGAV1_CPU_SSE4_1
 #endif
 
 #endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc b/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
index 6e86b10..ddf3a95 100644
--- a/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -25,6 +25,7 @@
 #include <cstdint>
 
 #include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/utils/common.h"
 #include "src/utils/compiler_attributes.h"
@@ -51,7 +52,7 @@
     void* const dest, ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
-  auto* dst = reinterpret_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
   const __m128i alpha_sign = _mm_set1_epi16(alpha);
   const __m128i alpha_q12 = _mm_slli_epi16(_mm_abs_epi16(alpha_sign), 9);
   auto* row = reinterpret_cast<const __m128i*>(luma);
@@ -190,7 +191,8 @@
   do {
     __m128i samples0 = LoadLo8(src);
     if (!inside) {
-      const __m128i border0 = _mm_set1_epi8(src[visible_width - 1]);
+      const __m128i border0 =
+          _mm_set1_epi8(static_cast<int8_t>(src[visible_width - 1]));
       samples0 = _mm_blendv_epi8(samples0, border0, blend_mask);
     }
     src += stride;
@@ -202,7 +204,8 @@
 
     samples1 = LoadLo8(src);
     if (!inside) {
-      const __m128i border1 = _mm_set1_epi8(src[visible_width - 1]);
+      const __m128i border1 =
+          _mm_set1_epi8(static_cast<int8_t>(src[visible_width - 1]));
       samples1 = _mm_blendv_epi8(samples1, border1, blend_mask);
     }
     src += stride;
@@ -299,7 +302,7 @@
   __m128i inner_sum_lo, inner_sum_hi;
   int y = 0;
   do {
-#if LIBGAV1_MSAN  // We can load unintialized values here. Even though they are
+#if LIBGAV1_MSAN  // We can load uninitialized values here. Even though they are
                   // then masked off by blendv, MSAN isn't smart enough to
                   // understand that. So we switch to a C implementation here.
     uint16_t c_arr[16];
@@ -314,7 +317,8 @@
     __m128i samples01 = LoadUnaligned16(src);
 
     if (!inside) {
-      const __m128i border16 = _mm_set1_epi8(src[visible_width_16 - 1]);
+      const __m128i border16 =
+          _mm_set1_epi8(static_cast<int8_t>(src[visible_width_16 - 1]));
       samples01 = _mm_blendv_epi8(samples01, border16, blend_mask_16);
     }
     samples0 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples01), 3);
@@ -326,7 +330,7 @@
     __m128i inner_sum = _mm_add_epi16(samples0, samples1);
 
     if (block_width == 32) {
-#if LIBGAV1_MSAN  // We can load unintialized values here. Even though they are
+#if LIBGAV1_MSAN  // We can load uninitialized values here. Even though they are
                   // then masked off by blendv, MSAN isn't smart enough to
                   // understand that. So we switch to a C implementation here.
       uint16_t c_arr[16];
@@ -340,7 +344,8 @@
 #else
       __m128i samples23 = LoadUnaligned16(src + 16);
       if (!inside) {
-        const __m128i border32 = _mm_set1_epi8(src[visible_width_32 - 1]);
+        const __m128i border32 =
+            _mm_set1_epi8(static_cast<int8_t>(src[visible_width_32 - 1]));
         samples23 = _mm_blendv_epi8(samples23, border32, blend_mask_32);
       }
       samples2 = _mm_slli_epi16(_mm_cvtepu8_epi16(samples23), 3);
@@ -781,7 +786,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
 #if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_CflSubsampler420)
   dsp->cfl_subsamplers[kTransformSize4x4][kSubsamplingType420] =
diff --git a/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc b/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc
index a21b8df..a761813 100644
--- a/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -25,6 +25,7 @@
 #include <cstring>  // memcpy
 
 #include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/utils/common.h"
 
@@ -2408,7 +2409,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
 #if DSP_ENABLED_8BPP_SSE4_1(TransformSize4x4_IntraPredictorSmooth)
   dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] =
diff --git a/libgav1/src/dsp/x86/intrapred_sse4.cc b/libgav1/src/dsp/x86/intrapred_sse4.cc
index ec27eb7..11ba9aa 100644
--- a/libgav1/src/dsp/x86/intrapred_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_sse4.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/intrapred.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -26,6 +26,7 @@
 #include <cstring>  // memcpy
 
 #include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/dsp/x86/transpose_sse4.h"
 #include "src/utils/common.h"
@@ -1022,7 +1023,7 @@
 
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   const __m128i top_lefts16 = _mm_set1_epi16(top_ptr[-1]);
-  const __m128i top_lefts8 = _mm_set1_epi8(top_ptr[-1]);
+  const __m128i top_lefts8 = _mm_set1_epi8(static_cast<int8_t>(top_ptr[-1]));
 
   // Given that the spec defines "base" as top[x] + left[y] - top[-1],
   // pLeft = abs(base - left[y]) = abs(top[x] - top[-1])
@@ -1066,7 +1067,7 @@
   const __m128i top_hi = _mm_cvtepu8_epi16(_mm_srli_si128(top, 8));
 
   const __m128i top_lefts16 = _mm_set1_epi16(top_left);
-  const __m128i top_lefts8 = _mm_set1_epi8(top_left);
+  const __m128i top_lefts8 = _mm_set1_epi8(static_cast<int8_t>(top_left));
 
   // Given that the spec defines "base" as top[x] + left[y] - top_left,
   // pLeft = abs(base - left[y]) = abs(top[x] - top[-1])
@@ -1132,7 +1133,7 @@
   const __m128i top_hi = _mm_cvtepu8_epi16(_mm_srli_si128(top, 8));
 
   const __m128i top_lefts16 = _mm_set1_epi16(top_left);
-  const __m128i top_lefts8 = _mm_set1_epi8(top_left);
+  const __m128i top_lefts8 = _mm_set1_epi8(static_cast<int8_t>(top_left));
 
   // Given that the spec defines "base" as top[x] + left[y] - top[-1],
   // pLeft = abs(base - left[y]) = abs(top[x] - top[-1])
@@ -1994,7 +1995,7 @@
 // time. The values are found by inspection. By coincidence, all angles that
 // satisfy (ystep >> 6) == 2 map to the same value, so it is enough to look up
 // by ystep >> 6. The largest index for this lookup is 1023 >> 6 == 15.
-const int kDirectionalZone2ShuffleInvalidHeight[16] = {
+constexpr int kDirectionalZone2ShuffleInvalidHeight[16] = {
     1024, 1024, 16, 16, 16, 16, 0, 0, 18, 0, 0, 0, 0, 0, 0, 40};
 
 template <bool upsampled>
@@ -2688,7 +2689,6 @@
     dst += 4;
   }
 
-  left = _mm_setzero_si128();
   // Now we handle heights that reference previous blocks rather than top_row.
   for (int y = 4; y < height; y += 4) {
     // Leftmost 4x4 block for this height.
@@ -2740,7 +2740,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   static_cast<void>(dsp);
 // These guards check if this version of the function was not superseded by
@@ -3524,7 +3524,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_SSE4_1
+#else  // !LIBGAV1_ENABLE_SSE4_1
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/x86/intrapred_sse4.h b/libgav1/src/dsp/x86/intrapred_sse4.h
index f812ec1..eb3825d 100644
--- a/libgav1/src/dsp/x86/intrapred_sse4.h
+++ b/libgav1/src/dsp/x86/intrapred_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_INTRAPRED_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_INTRAPRED_SSE4_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -38,1021 +38,1021 @@
 // optimization being enabled, signal the sse4 implementation should be used.
 #if LIBGAV1_ENABLE_SSE4_1
 #ifndef LIBGAV1_Dsp8bpp_FilterIntraPredictor
-#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1
-#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2
-#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone2 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3
-#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_DirectionalIntraPredictorZone3 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcTop \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcTop \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcTop \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcTop \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcTop \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcTop \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop
 #define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcTop \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop
 #define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcTop \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcTop \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler420 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflSubsampler444 LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_CflIntraPredictor LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDcLeft LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDcLeft LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDcLeft LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDcLeft LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc
-#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorPaeth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorPaeth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorPaeth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth
-#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorPaeth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorPaeth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorPaeth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorPaeth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth
 #define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorPaeth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth
 #define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorPaeth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorPaeth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth
-#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth
-#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth
-#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth
-#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmooth \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize4x4_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize4x8_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize4x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize8x4_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize8x8_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize8x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize8x32_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x4_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x8_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x32_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize16x64_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize32x8_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize32x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize32x32_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize32x64_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize64x16_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize64x32_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 //------------------------------------------------------------------------------
 // 10bpp
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop
-#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcTop LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft
 #define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDcLeft \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc
-#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorHorizontal
 #define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/inverse_transform_sse4.cc b/libgav1/src/dsp/x86/inverse_transform_sse4.cc
index 56bace1..30ad436 100644
--- a/libgav1/src/dsp/x86/inverse_transform_sse4.cc
+++ b/libgav1/src/dsp/x86/inverse_transform_sse4.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/inverse_transform.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -24,6 +24,8 @@
 #include <cstdint>
 #include <cstring>
 
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/dsp/x86/transpose_sse4.h"
 #include "src/utils/array_2d.h"
@@ -199,9 +201,104 @@
 using ButterflyRotationFunc = void (*)(__m128i* a, __m128i* b, int angle,
                                        bool flip);
 
+LIBGAV1_ALWAYS_INLINE __m128i ShiftResidual(const __m128i residual,
+                                            const __m128i v_row_shift_add,
+                                            const __m128i v_row_shift) {
+  const __m128i k7ffd = _mm_set1_epi16(0x7ffd);
+  // The max row_shift is 2, so int16_t values greater than 0x7ffd may
+  // overflow.  Generate a mask for this case.
+  const __m128i mask = _mm_cmpgt_epi16(residual, k7ffd);
+  const __m128i x = _mm_add_epi16(residual, v_row_shift_add);
+  // Assume int16_t values.
+  const __m128i a = _mm_sra_epi16(x, v_row_shift);
+  // Assume uint16_t values.
+  const __m128i b = _mm_srl_epi16(x, v_row_shift);
+  // Select the correct shifted value.
+  return _mm_blendv_epi8(a, b, mask);
+}
+
 //------------------------------------------------------------------------------
 // Discrete Cosine Transforms (DCT).
 
+template <int width>
+LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, const void* source,
+                                     int non_zero_coeff_count,
+                                     bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  const __m128i v_src_lo = _mm_shufflelo_epi16(_mm_cvtsi32_si128(src[0]), 0);
+  const __m128i v_src =
+      (width == 4) ? v_src_lo : _mm_shuffle_epi32(v_src_lo, 0);
+  const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0);
+  const __m128i v_kTransformRowMultiplier =
+      _mm_set1_epi16(kTransformRowMultiplier << 3);
+  const __m128i v_src_round =
+      _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
+  const __m128i s0 = _mm_blendv_epi8(v_src, v_src_round, v_mask);
+  const int16_t cos128 = Cos128(32);
+  const __m128i xy = _mm_mulhrs_epi16(s0, _mm_set1_epi16(cos128 << 3));
+
+  // Expand to 32 bits to prevent int16_t overflows during the shift add.
+  const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
+  const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
+  const __m128i a = _mm_cvtepi16_epi32(xy);
+  const __m128i a1 = _mm_cvtepi16_epi32(_mm_srli_si128(xy, 8));
+  const __m128i b = _mm_add_epi32(a, v_row_shift_add);
+  const __m128i b1 = _mm_add_epi32(a1, v_row_shift_add);
+  const __m128i c = _mm_sra_epi32(b, v_row_shift);
+  const __m128i c1 = _mm_sra_epi32(b1, v_row_shift);
+  const __m128i xy_shifted = _mm_packs_epi32(c, c1);
+
+  if (width == 4) {
+    StoreLo8(dst, xy_shifted);
+  } else {
+    for (int i = 0; i < width; i += 8) {
+      StoreUnaligned16(dst, xy_shifted);
+      dst += 8;
+    }
+  }
+  return true;
+}
+
+template <int height>
+LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, const void* source,
+                                           int non_zero_coeff_count,
+                                           int width) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  const int16_t cos128 = Cos128(32);
+
+  // Calculate dc values for first row.
+  if (width == 4) {
+    const __m128i v_src = LoadLo8(src);
+    const __m128i xy = _mm_mulhrs_epi16(v_src, _mm_set1_epi16(cos128 << 3));
+    StoreLo8(dst, xy);
+  } else {
+    int i = 0;
+    do {
+      const __m128i v_src = LoadUnaligned16(&src[i]);
+      const __m128i xy = _mm_mulhrs_epi16(v_src, _mm_set1_epi16(cos128 << 3));
+      StoreUnaligned16(&dst[i], xy);
+      i += 8;
+    } while (i < width);
+  }
+
+  // Copy first row to the rest of the block.
+  for (int y = 1; y < height; ++y) {
+    memcpy(&dst[y * width], &src[(y - 1) * width], width * sizeof(dst[0]));
+  }
+  return true;
+}
+
 template <ButterflyRotationFunc bufferfly_rotation,
           bool is_fast_bufferfly = false>
 LIBGAV1_ALWAYS_INLINE void Dct4Stages(__m128i* s) {
@@ -949,6 +1046,82 @@
   }
 }
 
+constexpr int16_t kAdst4DcOnlyMultiplier[8] = {1321, 0, 2482, 0,
+                                               3344, 0, 2482, 1321};
+
+LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, const void* source,
+                                       int non_zero_coeff_count,
+                                       bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  const __m128i v_src =
+      _mm_shuffle_epi32(_mm_shufflelo_epi16(_mm_cvtsi32_si128(src[0]), 0), 0);
+  const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0);
+  const __m128i v_kTransformRowMultiplier =
+      _mm_set1_epi16(kTransformRowMultiplier << 3);
+  const __m128i v_src_round =
+      _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
+  const __m128i s0 = _mm_blendv_epi8(v_src, v_src_round, v_mask);
+  const __m128i v_kAdst4DcOnlyMultipliers =
+      LoadUnaligned16(kAdst4DcOnlyMultiplier);
+  // s0*k0 s0*k1 s0*k2 s0*k1
+  // +
+  // s0*0  s0*0  s0*0  s0*k0
+  const __m128i x3 = _mm_madd_epi16(s0, v_kAdst4DcOnlyMultipliers);
+  const __m128i dst_0 = RightShiftWithRounding_S32(x3, 12);
+  const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
+  const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
+  const __m128i a = _mm_add_epi32(dst_0, v_row_shift_add);
+  const __m128i b = _mm_sra_epi32(a, v_row_shift);
+  const __m128i c = _mm_packs_epi32(b, b);
+  StoreLo8(dst, c);
+
+  return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, const void* source,
+                                             int non_zero_coeff_count,
+                                             int width) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  int i = 0;
+  do {
+    const __m128i v_src = _mm_cvtepi16_epi32(LoadLo8(&src[i]));
+    const __m128i kAdst4Multiplier_0 = _mm_set1_epi32(kAdst4Multiplier[0]);
+    const __m128i kAdst4Multiplier_1 = _mm_set1_epi32(kAdst4Multiplier[1]);
+    const __m128i kAdst4Multiplier_2 = _mm_set1_epi32(kAdst4Multiplier[2]);
+    const __m128i s0 = _mm_mullo_epi32(kAdst4Multiplier_0, v_src);
+    const __m128i s1 = _mm_mullo_epi32(kAdst4Multiplier_1, v_src);
+    const __m128i s2 = _mm_mullo_epi32(kAdst4Multiplier_2, v_src);
+    const __m128i x0 = s0;
+    const __m128i x1 = s1;
+    const __m128i x2 = s2;
+    const __m128i x3 = _mm_add_epi32(s0, s1);
+    const __m128i dst_0 = RightShiftWithRounding_S32(x0, 12);
+    const __m128i dst_1 = RightShiftWithRounding_S32(x1, 12);
+    const __m128i dst_2 = RightShiftWithRounding_S32(x2, 12);
+    const __m128i dst_3 = RightShiftWithRounding_S32(x3, 12);
+    const __m128i dst_0_1 = _mm_packs_epi32(dst_0, dst_1);
+    const __m128i dst_2_3 = _mm_packs_epi32(dst_2, dst_3);
+    StoreLo8(&dst[i], dst_0_1);
+    StoreHi8(&dst[i + width * 1], dst_0_1);
+    StoreLo8(&dst[i + width * 2], dst_2_3);
+    StoreHi8(&dst[i + width * 3], dst_2_3);
+    i += 4;
+  } while (i < width);
+
+  return true;
+}
+
 template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
 LIBGAV1_ALWAYS_INLINE void Adst8_SSE4_1(void* dest, const void* source,
                                         int32_t step, bool transpose) {
@@ -1040,6 +1213,135 @@
   }
 }
 
+LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, const void* source,
+                                       int non_zero_coeff_count,
+                                       bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  __m128i s[8];
+
+  const __m128i v_src = _mm_shufflelo_epi16(_mm_cvtsi32_si128(src[0]), 0);
+  const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0);
+  const __m128i v_kTransformRowMultiplier =
+      _mm_set1_epi16(kTransformRowMultiplier << 3);
+  const __m128i v_src_round =
+      _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
+  // stage 1.
+  s[1] = _mm_blendv_epi8(v_src, v_src_round, v_mask);
+
+  // stage 2.
+  ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
+
+  // stage 3.
+  s[4] = s[0];
+  s[5] = s[1];
+
+  // stage 4.
+  ButterflyRotation_4(&s[4], &s[5], 48, true);
+
+  // stage 5.
+  s[2] = s[0];
+  s[3] = s[1];
+  s[6] = s[4];
+  s[7] = s[5];
+
+  // stage 6.
+  ButterflyRotation_4(&s[2], &s[3], 32, true);
+  ButterflyRotation_4(&s[6], &s[7], 32, true);
+
+  // stage 7.
+  __m128i x[8];
+  const __m128i v_zero = _mm_setzero_si128();
+  x[0] = s[0];
+  x[1] = _mm_subs_epi16(v_zero, s[4]);
+  x[2] = s[6];
+  x[3] = _mm_subs_epi16(v_zero, s[2]);
+  x[4] = s[3];
+  x[5] = _mm_subs_epi16(v_zero, s[7]);
+  x[6] = s[5];
+  x[7] = _mm_subs_epi16(v_zero, s[1]);
+
+  const __m128i x1_x0 = _mm_unpacklo_epi16(x[0], x[1]);
+  const __m128i x3_x2 = _mm_unpacklo_epi16(x[2], x[3]);
+  const __m128i x5_x4 = _mm_unpacklo_epi16(x[4], x[5]);
+  const __m128i x7_x6 = _mm_unpacklo_epi16(x[6], x[7]);
+  const __m128i x3_x0 = _mm_unpacklo_epi32(x1_x0, x3_x2);
+  const __m128i x7_x4 = _mm_unpacklo_epi32(x5_x4, x7_x6);
+
+  const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
+  const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
+  const __m128i a = _mm_add_epi32(_mm_cvtepi16_epi32(x3_x0), v_row_shift_add);
+  const __m128i a1 = _mm_add_epi32(_mm_cvtepi16_epi32(x7_x4), v_row_shift_add);
+  const __m128i b = _mm_sra_epi32(a, v_row_shift);
+  const __m128i b1 = _mm_sra_epi32(a1, v_row_shift);
+  StoreUnaligned16(dst, _mm_packs_epi32(b, b1));
+
+  return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, const void* source,
+                                             int non_zero_coeff_count,
+                                             int width) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  __m128i s[8];
+
+  int i = 0;
+  do {
+    const __m128i v_src = LoadLo8(&src[i]);
+    // stage 1.
+    s[1] = v_src;
+
+    // stage 2.
+    ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
+
+    // stage 3.
+    s[4] = s[0];
+    s[5] = s[1];
+
+    // stage 4.
+    ButterflyRotation_4(&s[4], &s[5], 48, true);
+
+    // stage 5.
+    s[2] = s[0];
+    s[3] = s[1];
+    s[6] = s[4];
+    s[7] = s[5];
+
+    // stage 6.
+    ButterflyRotation_4(&s[2], &s[3], 32, true);
+    ButterflyRotation_4(&s[6], &s[7], 32, true);
+
+    // stage 7.
+    __m128i x[8];
+    const __m128i v_zero = _mm_setzero_si128();
+    x[0] = s[0];
+    x[1] = _mm_subs_epi16(v_zero, s[4]);
+    x[2] = s[6];
+    x[3] = _mm_subs_epi16(v_zero, s[2]);
+    x[4] = s[3];
+    x[5] = _mm_subs_epi16(v_zero, s[7]);
+    x[6] = s[5];
+    x[7] = _mm_subs_epi16(v_zero, s[1]);
+
+    for (int j = 0; j < 8; ++j) {
+      StoreLo8(&dst[j * width], x[j]);
+    }
+    i += 4;
+    dst += 4;
+  } while (i < width);
+
+  return true;
+}
+
 template <ButterflyRotationFunc bufferfly_rotation, bool stage_is_rectangular>
 LIBGAV1_ALWAYS_INLINE void Adst16_SSE4_1(void* dest, const void* source,
                                          int32_t step, bool transpose) {
@@ -1187,6 +1489,136 @@
   }
 }
 
+LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(__m128i* s, __m128i* x) {
+  // stage 2.
+  ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true);
+
+  // stage 3.
+  s[8] = s[0];
+  s[9] = s[1];
+
+  // stage 4.
+  ButterflyRotation_4(&s[8], &s[9], 56, true);
+
+  // stage 5.
+  s[4] = s[0];
+  s[12] = s[8];
+  s[5] = s[1];
+  s[13] = s[9];
+
+  // stage 6.
+  ButterflyRotation_4(&s[4], &s[5], 48, true);
+  ButterflyRotation_4(&s[12], &s[13], 48, true);
+
+  // stage 7.
+  s[2] = s[0];
+  s[6] = s[4];
+  s[10] = s[8];
+  s[14] = s[12];
+  s[3] = s[1];
+  s[7] = s[5];
+  s[11] = s[9];
+  s[15] = s[13];
+
+  // stage 8.
+  ButterflyRotation_4(&s[2], &s[3], 32, true);
+  ButterflyRotation_4(&s[6], &s[7], 32, true);
+  ButterflyRotation_4(&s[10], &s[11], 32, true);
+  ButterflyRotation_4(&s[14], &s[15], 32, true);
+
+  // stage 9.
+  const __m128i v_zero = _mm_setzero_si128();
+  x[0] = s[0];
+  x[1] = _mm_subs_epi16(v_zero, s[8]);
+  x[2] = s[12];
+  x[3] = _mm_subs_epi16(v_zero, s[4]);
+  x[4] = s[6];
+  x[5] = _mm_subs_epi16(v_zero, s[14]);
+  x[6] = s[10];
+  x[7] = _mm_subs_epi16(v_zero, s[2]);
+  x[8] = s[3];
+  x[9] = _mm_subs_epi16(v_zero, s[11]);
+  x[10] = s[15];
+  x[11] = _mm_subs_epi16(v_zero, s[7]);
+  x[12] = s[5];
+  x[13] = _mm_subs_epi16(v_zero, s[13]);
+  x[14] = s[9];
+  x[15] = _mm_subs_epi16(v_zero, s[1]);
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, const void* source,
+                                        int non_zero_coeff_count,
+                                        bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+  __m128i s[16];
+  __m128i x[16];
+
+  const __m128i v_src = _mm_shufflelo_epi16(_mm_cvtsi32_si128(src[0]), 0);
+  const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0);
+  const __m128i v_kTransformRowMultiplier =
+      _mm_set1_epi16(kTransformRowMultiplier << 3);
+  const __m128i v_src_round =
+      _mm_mulhrs_epi16(v_src, v_kTransformRowMultiplier);
+  // stage 1.
+  s[1] = _mm_blendv_epi8(v_src, v_src_round, v_mask);
+
+  Adst16DcOnlyInternal(s, x);
+
+  for (int i = 0; i < 2; ++i) {
+    const __m128i x1_x0 = _mm_unpacklo_epi16(x[0 + i * 8], x[1 + i * 8]);
+    const __m128i x3_x2 = _mm_unpacklo_epi16(x[2 + i * 8], x[3 + i * 8]);
+    const __m128i x5_x4 = _mm_unpacklo_epi16(x[4 + i * 8], x[5 + i * 8]);
+    const __m128i x7_x6 = _mm_unpacklo_epi16(x[6 + i * 8], x[7 + i * 8]);
+    const __m128i x3_x0 = _mm_unpacklo_epi32(x1_x0, x3_x2);
+    const __m128i x7_x4 = _mm_unpacklo_epi32(x5_x4, x7_x6);
+
+    const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
+    const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
+    const __m128i a = _mm_add_epi32(_mm_cvtepi16_epi32(x3_x0), v_row_shift_add);
+    const __m128i a1 =
+        _mm_add_epi32(_mm_cvtepi16_epi32(x7_x4), v_row_shift_add);
+    const __m128i b = _mm_sra_epi32(a, v_row_shift);
+    const __m128i b1 = _mm_sra_epi32(a1, v_row_shift);
+    StoreUnaligned16(&dst[i * 8], _mm_packs_epi32(b, b1));
+  }
+  return true;
+}
+
+LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest, const void* source,
+                                              int non_zero_coeff_count,
+                                              int width) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  int i = 0;
+  do {
+    __m128i s[16];
+    __m128i x[16];
+    const __m128i v_src = LoadUnaligned16(&src[i]);
+    // stage 1.
+    s[1] = v_src;
+
+    Adst16DcOnlyInternal(s, x);
+
+    for (int j = 0; j < 16; ++j) {
+      StoreLo8(&dst[j * width], x[j]);
+    }
+    i += 4;
+    dst += 4;
+  } while (i < width);
+
+  return true;
+}
+
 //------------------------------------------------------------------------------
 // Identity Transforms.
 
@@ -1223,6 +1655,35 @@
   }
 }
 
+LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, const void* source,
+                                           int non_zero_coeff_count,
+                                           bool should_round, int tx_height) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  const __m128i v_src0 = _mm_cvtsi32_si128(src[0]);
+  const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0);
+  const __m128i v_kTransformRowMultiplier =
+      _mm_set1_epi16(kTransformRowMultiplier << 3);
+  const __m128i v_src_round =
+      _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier);
+  const __m128i v_src = _mm_blendv_epi8(v_src0, v_src_round, v_mask);
+
+  const int shift = (tx_height < 16) ? 0 : 1;
+  const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11);
+  const __m128i v_multiplier_one =
+      _mm_set1_epi32((kIdentity4Multiplier << 16) | 0x0001);
+  const __m128i v_src_round_lo = _mm_unpacklo_epi16(v_dual_round, v_src);
+  const __m128i a = _mm_madd_epi16(v_src_round_lo, v_multiplier_one);
+  const __m128i b = _mm_srai_epi32(a, 12 + shift);
+  dst[0] = _mm_extract_epi16(_mm_packs_epi32(b, b), 0);
+  return true;
+}
+
 LIBGAV1_ALWAYS_INLINE void Identity4ColumnStoreToFrame(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
     const int tx_width, const int tx_height, const int16_t* source) {
@@ -1234,7 +1695,8 @@
   const __m128i v_eight = _mm_set1_epi16(8);
 
   if (tx_width == 4) {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const __m128i v_src = LoadLo8(&source[i * tx_width]);
       const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier_fraction);
       const __m128i frame_data = Load4(dst);
@@ -1245,9 +1707,10 @@
       const __m128i d = _mm_adds_epi16(c, b);
       Store4(dst, _mm_packus_epi16(d, d));
       dst += stride;
-    }
+    } while (++i < tx_height);
   } else {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const int row = i * tx_width;
       int j = 0;
       do {
@@ -1264,7 +1727,7 @@
         j += 8;
       } while (j < tx_width);
       dst += stride;
-    }
+    } while (++i < tx_height);
   }
 }
 
@@ -1281,7 +1744,8 @@
       _mm_set1_epi16(kTransformRowMultiplier << 3);
 
   if (tx_width == 4) {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const __m128i v_src = LoadLo8(&source[i * tx_width]);
       const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier_fraction);
       const __m128i frame_data = Load4(dst);
@@ -1295,9 +1759,10 @@
       const __m128i c = _mm_adds_epi16(frame_data16, b);
       Store4(dst, _mm_packus_epi16(c, c));
       dst += stride;
-    }
+    } while (++i < tx_height);
   } else {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const int row = i * tx_width;
       int j = 0;
       do {
@@ -1317,7 +1782,7 @@
         j += 8;
       } while (j < tx_width);
       dst += stride;
-    }
+    } while (++i < tx_height);
   }
 }
 
@@ -1351,6 +1816,33 @@
   }
 }
 
+LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, const void* source,
+                                           int non_zero_coeff_count,
+                                           bool should_round, int row_shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  const __m128i v_src0 = _mm_cvtsi32_si128(src[0]);
+  const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0);
+  const __m128i v_kTransformRowMultiplier =
+      _mm_set1_epi16(kTransformRowMultiplier << 3);
+  const __m128i v_src_round =
+      _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier);
+  const __m128i v_src =
+      _mm_cvtepi16_epi32(_mm_blendv_epi8(v_src0, v_src_round, v_mask));
+  const __m128i v_srcx2 = _mm_add_epi32(v_src, v_src);
+  const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
+  const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
+  const __m128i a = _mm_add_epi32(v_srcx2, v_row_shift_add);
+  const __m128i b = _mm_sra_epi32(a, v_row_shift);
+  dst[0] = _mm_extract_epi16(_mm_packs_epi32(b, b), 0);
+  return true;
+}
+
 LIBGAV1_ALWAYS_INLINE void Identity8ColumnStoreToFrame_SSE4_1(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
     const int tx_width, const int tx_height, const int16_t* source) {
@@ -1358,7 +1850,8 @@
   uint8_t* dst = frame[start_y] + start_x;
   const __m128i v_eight = _mm_set1_epi16(8);
   if (tx_width == 4) {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const int row = i * tx_width;
       const __m128i v_src = LoadLo8(&source[row]);
       const __m128i v_dst_i = _mm_adds_epi16(v_src, v_src);
@@ -1369,9 +1862,10 @@
       const __m128i d = _mm_adds_epi16(c, b);
       Store4(dst, _mm_packus_epi16(d, d));
       dst += stride;
-    }
+    } while (++i < tx_height);
   } else {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const int row = i * tx_width;
       int j = 0;
       do {
@@ -1386,7 +1880,7 @@
         j += 8;
       } while (j < tx_width);
       dst += stride;
-    }
+    } while (++i < tx_height);
   }
 }
 
@@ -1420,6 +1914,34 @@
   }
 }
 
+LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, const void* source,
+                                            int non_zero_coeff_count,
+                                            bool should_round, int shift) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  const __m128i v_src0 = _mm_cvtsi32_si128(src[0]);
+  const __m128i v_mask = _mm_set1_epi16(should_round ? 0xffff : 0);
+  const __m128i v_kTransformRowMultiplier =
+      _mm_set1_epi16(kTransformRowMultiplier << 3);
+  const __m128i v_src_round0 =
+      _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier);
+  const __m128i v_src = _mm_blendv_epi8(v_src0, v_src_round0, v_mask);
+  const __m128i v_dual_round = _mm_set1_epi16((1 + (shift << 1)) << 11);
+  const __m128i v_multiplier_one =
+      _mm_set1_epi32((kIdentity16Multiplier << 16) | 0x0001);
+  const __m128i v_shift = _mm_set_epi64x(0, 12 + shift);
+  const __m128i v_src_round = _mm_unpacklo_epi16(v_dual_round, v_src);
+  const __m128i a = _mm_madd_epi16(v_src_round, v_multiplier_one);
+  const __m128i b = _mm_sra_epi32(a, v_shift);
+  dst[0] = _mm_extract_epi16(_mm_packs_epi32(b, b), 0);
+  return true;
+}
+
 LIBGAV1_ALWAYS_INLINE void Identity16ColumnStoreToFrame_SSE4_1(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
     const int tx_width, const int tx_height, const int16_t* source) {
@@ -1430,7 +1952,8 @@
       _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 4));
 
   if (tx_width == 4) {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const __m128i v_src = LoadLo8(&source[i * tx_width]);
       const __m128i v_src_mult = _mm_mulhrs_epi16(v_src, v_multiplier);
       const __m128i frame_data = Load4(dst);
@@ -1442,9 +1965,10 @@
       const __m128i d = _mm_adds_epi16(c, b);
       Store4(dst, _mm_packus_epi16(d, d));
       dst += stride;
-    }
+    } while (++i < tx_height);
   } else {
-    for (int i = 0; i < tx_height; ++i) {
+    int i = 0;
+    do {
       const int row = i * tx_width;
       int j = 0;
       do {
@@ -1461,7 +1985,7 @@
         j += 8;
       } while (j < tx_width);
       dst += stride;
-    }
+    } while (++i < tx_height);
   }
 }
 
@@ -1485,6 +2009,28 @@
   }
 }
 
+LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest, const void* source,
+                                            int non_zero_coeff_count) {
+  if (non_zero_coeff_count > 1) {
+    return false;
+  }
+
+  auto* dst = static_cast<int16_t*>(dest);
+  const auto* const src = static_cast<const int16_t*>(source);
+
+  const __m128i v_src0 = _mm_cvtsi32_si128(src[0]);
+  const __m128i v_kTransformRowMultiplier =
+      _mm_set1_epi16(kTransformRowMultiplier << 3);
+  const __m128i v_src = _mm_mulhrs_epi16(v_src0, v_kTransformRowMultiplier);
+
+  // When combining the identity32 multiplier with the row shift, the
+  // calculation for tx_height equal to 16 can be simplified from
+  // ((A * 4) + 1) >> 1) to (A * 2).
+  const __m128i v_dst_0 = _mm_adds_epi16(v_src, v_src);
+  dst[0] = _mm_extract_epi16(v_dst_0, 0);
+  return true;
+}
+
 LIBGAV1_ALWAYS_INLINE void Identity32ColumnStoreToFrame(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
     const int tx_width, const int tx_height, const int16_t* source) {
@@ -1492,7 +2038,8 @@
   uint8_t* dst = frame[start_y] + start_x;
   const __m128i v_two = _mm_set1_epi16(2);
 
-  for (int i = 0; i < tx_height; ++i) {
+  int i = 0;
+  do {
     const int row = i * tx_width;
     int j = 0;
     do {
@@ -1506,7 +2053,7 @@
       j += 8;
     } while (j < tx_width);
     dst += stride;
-  }
+  } while (++i < tx_height);
 }
 
 //------------------------------------------------------------------------------
@@ -1720,36 +2267,26 @@
 template <int tx_width>
 LIBGAV1_ALWAYS_INLINE void RowShift(int16_t* source, int num_rows,
                                     int row_shift) {
-  const __m128i v_row_shift_add = _mm_set1_epi32(row_shift);
-  const __m128i v_row_shift = _mm_cvtepu32_epi64(v_row_shift_add);
+  const __m128i v_row_shift_add = _mm_set1_epi16(row_shift);
+  const __m128i v_row_shift = _mm_cvtepu16_epi64(v_row_shift_add);
   if (tx_width == 4) {
     // Process two rows per iteration.
     int i = 0;
     do {
-      // Expand to 32 bits to prevent int16_t overflows during the shift add.
       const __m128i residual = LoadUnaligned16(&source[i]);
-      const __m128i a = _mm_cvtepi16_epi32(residual);
-      const __m128i a1 = _mm_cvtepi16_epi32(_mm_srli_si128(residual, 8));
-      const __m128i b = _mm_add_epi32(a, v_row_shift_add);
-      const __m128i b1 = _mm_add_epi32(a1, v_row_shift_add);
-      const __m128i c = _mm_sra_epi32(b, v_row_shift);
-      const __m128i c1 = _mm_sra_epi32(b1, v_row_shift);
-      StoreUnaligned16(&source[i], _mm_packs_epi32(c, c1));
+      const __m128i shifted_residual =
+          ShiftResidual(residual, v_row_shift_add, v_row_shift);
+      StoreUnaligned16(&source[i], shifted_residual);
       i += 8;
     } while (i < tx_width * num_rows);
   } else {
     int i = 0;
     do {
       for (int j = 0; j < tx_width; j += 8) {
-        // Expand to 32 bits to prevent int16_t overflows during the shift add.
         const __m128i residual = LoadUnaligned16(&source[i * tx_width + j]);
-        const __m128i a = _mm_cvtepi16_epi32(residual);
-        const __m128i a1 = _mm_cvtepi16_epi32(_mm_srli_si128(residual, 8));
-        const __m128i b = _mm_add_epi32(a, v_row_shift_add);
-        const __m128i b1 = _mm_add_epi32(a1, v_row_shift_add);
-        const __m128i c = _mm_sra_epi32(b, v_row_shift);
-        const __m128i c1 = _mm_sra_epi32(b1, v_row_shift);
-        StoreUnaligned16(&source[i * tx_width + j], _mm_packs_epi32(c, c1));
+        const __m128i shifted_residual =
+            ShiftResidual(residual, v_row_shift_add, v_row_shift);
+        StoreUnaligned16(&source[i * tx_width + j], shifted_residual);
       }
     } while (++i < num_rows);
   }
@@ -1759,14 +2296,22 @@
                               void* src_buffer, int start_x, int start_y,
                               void* dst_frame, bool is_row,
                               int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     const bool should_round = (tx_height == 8);
+    const int row_shift = static_cast<int>(tx_height == 16);
+
+    if (DctDcOnly<4>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                     row_shift)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<4>(tx_type, tx_height, non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<4>(src, num_rows);
     }
@@ -1795,18 +2340,20 @@
     FlipColumns<4>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d dct4 columns in parallel.
-    Dct4_SSE4_1<ButterflyRotation_4, false>(&src[0], &src[0], tx_width,
-                                            /*transpose=*/false);
-  } else {
-    // Process 8 1d dct4 columns in parallel per iteration.
-    int i = 0;
-    do {
-      Dct4_SSE4_1<ButterflyRotation_8, true>(&src[i], &src[i], tx_width,
-                                             /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!DctDcOnlyColumn<4>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d dct4 columns in parallel.
+      Dct4_SSE4_1<ButterflyRotation_4, false>(&src[0], &src[0], tx_width,
+                                              /*transpose=*/false);
+    } else {
+      // Process 8 1d dct4 columns in parallel per iteration.
+      int i = 0;
+      do {
+        Dct4_SSE4_1<ButterflyRotation_8, true>(&src[i], &src[i], tx_width,
+                                               /*transpose=*/false);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 4, src, tx_type);
 }
@@ -1815,14 +2362,23 @@
                               void* src_buffer, int start_x, int start_y,
                               void* dst_frame, bool is_row,
                               int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
-    if (kShouldRound[tx_size]) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (DctDcOnly<8>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                     row_shift)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<8>(tx_type, tx_height, non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<8>(src, num_rows);
     }
 
@@ -1839,7 +2395,6 @@
         i += 8;
       } while (i < num_rows);
     }
-    const uint8_t row_shift = kTransformRowShift[tx_size];
     if (row_shift > 0) {
       RowShift<8>(src, num_rows, row_shift);
     }
@@ -1851,18 +2406,20 @@
     FlipColumns<8>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d dct8 columns in parallel.
-    Dct8_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
-                                           /*transpose=*/false);
-  } else {
-    // Process 8 1d dct8 columns in parallel per iteration.
-    int i = 0;
-    do {
-      Dct8_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
-                                              /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!DctDcOnlyColumn<8>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d dct8 columns in parallel.
+      Dct8_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                             /*transpose=*/false);
+    } else {
+      // Process 8 1d dct8 columns in parallel per iteration.
+      int i = 0;
+      do {
+        Dct8_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                                /*transpose=*/false);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 8, src, tx_type);
 }
@@ -1871,15 +2428,23 @@
                                void* src_buffer, int start_x, int start_y,
                                void* dst_frame, bool is_row,
                                int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (DctDcOnly<16>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                      row_shift)) {
+      return;
+    }
+
     const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+        GetNumRows<16>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<16>(src, num_rows);
     }
 
@@ -1896,7 +2461,6 @@
         i += 8;
       } while (i < num_rows);
     }
-    const uint8_t row_shift = kTransformRowShift[tx_size];
     // row_shift is always non zero here.
     RowShift<16>(src, num_rows, row_shift);
 
@@ -1908,18 +2472,20 @@
     FlipColumns<16>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d dct16 columns in parallel.
-    Dct16_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
-                                            /*transpose=*/false);
-  } else {
-    int i = 0;
-    do {
-      // Process 8 1d dct16 columns in parallel per iteration.
-      Dct16_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
-                                               /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!DctDcOnlyColumn<16>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d dct16 columns in parallel.
+      Dct16_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                              /*transpose=*/false);
+    } else {
+      int i = 0;
+      do {
+        // Process 8 1d dct16 columns in parallel per iteration.
+        Dct16_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                                 /*transpose=*/false);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 16, src, tx_type);
 }
@@ -1928,15 +2494,23 @@
                                void* src_buffer, int start_x, int start_y,
                                void* dst_frame, bool is_row,
                                int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (DctDcOnly<32>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                      row_shift)) {
+      return;
+    }
+
     const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+        GetNumRows<32>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<32>(src, num_rows);
     }
     // Process 8 1d dct32 rows in parallel per iteration.
@@ -1945,7 +2519,6 @@
       Dct32_SSE4_1(&src[i * 32], &src[i * 32], 32, /*transpose=*/true);
       i += 8;
     } while (i < num_rows);
-    const uint8_t row_shift = kTransformRowShift[tx_size];
     // row_shift is always non zero here.
     RowShift<32>(src, num_rows, row_shift);
 
@@ -1953,12 +2526,14 @@
   }
 
   assert(!is_row);
-  // Process 8 1d dct32 columns in parallel per iteration.
-  int i = 0;
-  do {
-    Dct32_SSE4_1(&src[i], &src[i], tx_width, /*transpose=*/false);
-    i += 8;
-  } while (i < tx_width);
+  if (!DctDcOnlyColumn<32>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    // Process 8 1d dct32 columns in parallel per iteration.
+    int i = 0;
+    do {
+      Dct32_SSE4_1(&src[i], &src[i], tx_width, /*transpose=*/false);
+      i += 8;
+    } while (i < tx_width);
+  }
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 32, src, tx_type);
 }
 
@@ -1966,15 +2541,23 @@
                                void* src_buffer, int start_x, int start_y,
                                void* dst_frame, bool is_row,
                                int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (DctDcOnly<64>(&src[0], &src[0], non_zero_coeff_count, should_round,
+                      row_shift)) {
+      return;
+    }
+
     const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+        GetNumRows<32>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<64>(src, num_rows);
     }
     // Process 8 1d dct64 rows in parallel per iteration.
@@ -1983,7 +2566,6 @@
       Dct64_SSE4_1(&src[i * 64], &src[i * 64], 64, /*transpose=*/true);
       i += 8;
     } while (i < num_rows);
-    const uint8_t row_shift = kTransformRowShift[tx_size];
     // row_shift is always non zero here.
     RowShift<64>(src, num_rows, row_shift);
 
@@ -1991,12 +2573,14 @@
   }
 
   assert(!is_row);
-  // Process 8 1d dct64 columns in parallel per iteration.
-  int i = 0;
-  do {
-    Dct64_SSE4_1(&src[i], &src[i], tx_width, /*transpose=*/false);
-    i += 8;
-  } while (i < tx_width);
+  if (!DctDcOnlyColumn<64>(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    // Process 8 1d dct64 columns in parallel per iteration.
+    int i = 0;
+    do {
+      Dct64_SSE4_1(&src[i], &src[i], tx_width, /*transpose=*/false);
+      i += 8;
+    } while (i < tx_width);
+  }
   StoreToFrameWithRound(frame, start_x, start_y, tx_width, 64, src, tx_type);
 }
 
@@ -2004,14 +2588,22 @@
                                void* src_buffer, int start_x, int start_y,
                                void* dst_frame, bool is_row,
                                int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
+    const uint8_t row_shift = static_cast<uint8_t>(tx_height == 16);
     const bool should_round = (tx_height == 8);
+
+    if (Adst4DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                    row_shift)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<4>(tx_type, tx_height, non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<4>(src, num_rows);
     }
@@ -2024,7 +2616,7 @@
       i += 4;
     } while (i < num_rows);
 
-    if (tx_height == 16) {
+    if (row_shift != 0u) {
       RowShift<4>(src, num_rows, 1);
     }
     return;
@@ -2035,13 +2627,14 @@
     FlipColumns<4>(src, tx_width);
   }
 
-  // Process 4 1d adst4 columns in parallel per iteration.
-  int i = 0;
-  do {
-    Adst4_SSE4_1<false>(&src[i], &src[i], tx_width, /*transpose=*/false);
-    i += 4;
-  } while (i < tx_width);
-
+  if (!Adst4DcOnlyColumn(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    // Process 4 1d adst4 columns in parallel per iteration.
+    int i = 0;
+    do {
+      Adst4_SSE4_1<false>(&src[i], &src[i], tx_width, /*transpose=*/false);
+      i += 4;
+    } while (i < tx_width);
+  }
   StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
                                                    tx_width, 4, src, tx_type);
 }
@@ -2050,14 +2643,23 @@
                                void* src_buffer, int start_x, int start_y,
                                void* dst_frame, bool is_row,
                                int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
-    if (kShouldRound[tx_size]) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (Adst8DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                    row_shift)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<8>(tx_type, tx_height, non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<8>(src, num_rows);
     }
 
@@ -2075,7 +2677,6 @@
         i += 8;
       } while (i < num_rows);
     }
-    const uint8_t row_shift = kTransformRowShift[tx_size];
     if (row_shift > 0) {
       RowShift<8>(src, num_rows, row_shift);
     }
@@ -2087,18 +2688,20 @@
     FlipColumns<8>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d adst8 columns in parallel.
-    Adst8_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
-                                            /*transpose=*/false);
-  } else {
-    // Process 8 1d adst8 columns in parallel per iteration.
-    int i = 0;
-    do {
-      Adst8_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
-                                               /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!Adst8DcOnlyColumn(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d adst8 columns in parallel.
+      Adst8_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                              /*transpose=*/false);
+    } else {
+      // Process 8 1d adst8 columns in parallel per iteration.
+      int i = 0;
+      do {
+        Adst8_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                                 /*transpose=*/false);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
   StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
                                                    tx_width, 8, src, tx_type);
@@ -2108,15 +2711,23 @@
                                 void* src_buffer, int start_x, int start_y,
                                 void* dst_frame, bool is_row,
                                 int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+
+    if (Adst16DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                     row_shift)) {
+      return;
+    }
+
     const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+        GetNumRows<16>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<16>(src, num_rows);
     }
 
@@ -2133,7 +2744,6 @@
         i += 8;
       } while (i < num_rows);
     }
-    const uint8_t row_shift = kTransformRowShift[tx_size];
     // row_shift is always non zero here.
     RowShift<16>(src, num_rows, row_shift);
 
@@ -2145,18 +2755,20 @@
     FlipColumns<16>(src, tx_width);
   }
 
-  if (tx_width == 4) {
-    // Process 4 1d adst16 columns in parallel.
-    Adst16_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
-                                             /*transpose=*/false);
-  } else {
-    int i = 0;
-    do {
-      // Process 8 1d adst16 columns in parallel per iteration.
-      Adst16_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
-                                                /*transpose=*/false);
-      i += 8;
-    } while (i < tx_width);
+  if (!Adst16DcOnlyColumn(&src[0], &src[0], non_zero_coeff_count, tx_width)) {
+    if (tx_width == 4) {
+      // Process 4 1d adst16 columns in parallel.
+      Adst16_SSE4_1<ButterflyRotation_4, true>(&src[0], &src[0], 4,
+                                               /*transpose=*/false);
+    } else {
+      int i = 0;
+      do {
+        // Process 8 1d adst16 columns in parallel per iteration.
+        Adst16_SSE4_1<ButterflyRotation_8, false>(&src[i], &src[i], tx_width,
+                                                  /*transpose=*/false);
+        i += 8;
+      } while (i < tx_width);
+    }
   }
   StoreToFrameWithRound</*enable_flip_rows=*/true>(frame, start_x, start_y,
                                                    tx_width, 16, src, tx_type);
@@ -2166,7 +2778,7 @@
                                    void* src_buffer, int start_x, int start_y,
                                    void* dst_frame, bool is_row,
                                    int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
@@ -2179,8 +2791,14 @@
       return;
     }
 
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
     const bool should_round = (tx_height == 8);
+    if (Identity4DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                        tx_height)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<4>(tx_type, tx_height, non_zero_coeff_count);
     if (should_round) {
       ApplyRounding<4>(src, num_rows);
     }
@@ -2200,10 +2818,12 @@
     return;
   }
   assert(!is_row);
+  const int height = (non_zero_coeff_count == 1) ? 1 : tx_height;
   // Special case: Process row calculations during column transform call.
   if (tx_type == kTransformTypeIdentityIdentity &&
       (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) {
-    Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width, 4, src);
+    Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width, height,
+                                   src);
     return;
   }
 
@@ -2211,15 +2831,14 @@
     FlipColumns<4>(src, tx_width);
   }
 
-  Identity4ColumnStoreToFrame(frame, start_x, start_y, tx_width,
-                              /*tx_height=*/4, src);
+  Identity4ColumnStoreToFrame(frame, start_x, start_y, tx_width, height, src);
 }
 
 void Identity8TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
                                    void* src_buffer, int start_x, int start_y,
                                    void* dst_frame, bool is_row,
                                    int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
@@ -2231,8 +2850,17 @@
         tx_size == kTransformSize8x4) {
       return;
     }
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
-    if (kShouldRound[tx_size]) {
+
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+    if (Identity8DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                        row_shift)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<8>(tx_type, tx_height, non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<8>(src, num_rows);
     }
 
@@ -2266,23 +2894,31 @@
     FlipColumns<8>(src, tx_width);
   }
 
-  Identity8ColumnStoreToFrame_SSE4_1(frame, start_x, start_y, tx_width,
-                                     /*tx_height=*/8, src);
+  const int height = (non_zero_coeff_count == 1) ? 1 : tx_height;
+  Identity8ColumnStoreToFrame_SSE4_1(frame, start_x, start_y, tx_width, height,
+                                     src);
 }
 
 void Identity16TransformLoop_SSE4_1(TransformType tx_type,
                                     TransformSize tx_size, void* src_buffer,
                                     int start_x, int start_y, void* dst_frame,
                                     bool is_row, int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
+    const bool should_round = kShouldRound[tx_size];
+    const uint8_t row_shift = kTransformRowShift[tx_size];
+    if (Identity16DcOnly(&src[0], &src[0], non_zero_coeff_count, should_round,
+                         row_shift)) {
+      return;
+    }
+
     const int num_rows =
-        (non_zero_coeff_count == 1) ? 1 : std::min(tx_height, 32);
-    if (kShouldRound[tx_size]) {
+        GetNumRows<16>(tx_type, std::min(tx_height, 32), non_zero_coeff_count);
+    if (should_round) {
       ApplyRounding<16>(src, num_rows);
     }
     int i = 0;
@@ -2298,22 +2934,21 @@
   if (kTransformFlipColumnsMask.Contains(tx_type)) {
     FlipColumns<16>(src, tx_width);
   }
-  Identity16ColumnStoreToFrame_SSE4_1(frame, start_x, start_y, tx_width,
-                                      /*tx_height=*/16, src);
+  const int height = (non_zero_coeff_count == 1) ? 1 : tx_height;
+  Identity16ColumnStoreToFrame_SSE4_1(frame, start_x, start_y, tx_width, height,
+                                      src);
 }
 
-void Identity32TransformLoop_SSE4_1(TransformType /*tx_type*/,
+void Identity32TransformLoop_SSE4_1(TransformType tx_type,
                                     TransformSize tx_size, void* src_buffer,
                                     int start_x, int start_y, void* dst_frame,
                                     bool is_row, int non_zero_coeff_count) {
-  auto& frame = *reinterpret_cast<Array2DView<uint8_t>*>(dst_frame);
+  auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
 
   if (is_row) {
-    const int num_rows = (non_zero_coeff_count == 1) ? 1 : tx_height;
-
     // When combining the identity32 multiplier with the row shift, the
     // calculations for tx_height == 8 and tx_height == 32 can be simplified
     // from ((A * 4) + 2) >> 2) to A.
@@ -2321,6 +2956,16 @@
       return;
     }
 
+    // Process kTransformSize32x16. The src is always rounded before the
+    // identity transform and shifted by 1 afterwards.
+
+    if (Identity32DcOnly(&src[0], &src[0], non_zero_coeff_count)) {
+      return;
+    }
+
+    const int num_rows =
+        GetNumRows<32>(tx_type, tx_height, non_zero_coeff_count);
+
     // Process kTransformSize32x16
     assert(tx_size == kTransformSize32x16);
     ApplyRounding<32>(src, num_rows);
@@ -2333,8 +2978,8 @@
   }
 
   assert(!is_row);
-  Identity32ColumnStoreToFrame(frame, start_x, start_y, tx_width,
-                               /*tx_height=*/32, src);
+  const int height = (non_zero_coeff_count == 1) ? 1 : tx_height;
+  Identity32ColumnStoreToFrame(frame, start_x, start_y, tx_width, height, src);
 }
 
 void Wht4TransformLoop_SSE4_1(TransformType tx_type, TransformSize tx_size,
@@ -2397,7 +3042,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   InitAll<int16_t, uint8_t>(dsp);
@@ -2464,7 +3109,7 @@
 
 }  // namespace dsp
 }  // namespace libgav1
-#else   // !LIBGAV1_ENABLE_SSE4_1
+#else  // !LIBGAV1_ENABLE_SSE4_1
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/x86/inverse_transform_sse4.h b/libgav1/src/dsp/x86/inverse_transform_sse4.h
index c1eab25..423173b 100644
--- a/libgav1/src/dsp/x86/inverse_transform_sse4.h
+++ b/libgav1/src/dsp/x86/inverse_transform_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_INVERSE_TRANSFORM_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_INVERSE_TRANSFORM_SSE4_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -35,55 +35,55 @@
 #if LIBGAV1_ENABLE_SSE4_1
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity
-#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_CPU_SSE4_1
 #endif
 #endif  // LIBGAV1_ENABLE_SSE4_1
 #endif  // LIBGAV1_SRC_DSP_X86_INVERSE_TRANSFORM_SSE4_H_
diff --git a/libgav1/src/dsp/x86/loop_filter_sse4.cc b/libgav1/src/dsp/x86/loop_filter_sse4.cc
index 60684c5..462b885 100644
--- a/libgav1/src/dsp/x86/loop_filter_sse4.cc
+++ b/libgav1/src/dsp/x86/loop_filter_sse4.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/loop_filter.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -24,6 +24,8 @@
 #include <cstdint>
 #include <cstring>
 
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 
 namespace libgav1 {
@@ -49,7 +51,7 @@
 
 inline __m128i CheckOuterThreshF4(const __m128i& q1q0, const __m128i& p1p0,
                                   const __m128i& outer_thresh) {
-  const __m128i fe = _mm_set1_epi8(0xfe);
+  const __m128i fe = _mm_set1_epi8(static_cast<int8_t>(0xfe));
   //  abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh;
   const __m128i abs_pmq = AbsDiff(p1p0, q1q0);
   const __m128i a = _mm_adds_epu8(abs_pmq, abs_pmq);
@@ -103,7 +105,7 @@
 
 inline void Filter4(const __m128i& qp1, const __m128i& qp0, __m128i* oqp1,
                     __m128i* oqp0, const __m128i& mask, const __m128i& hev) {
-  const __m128i t80 = _mm_set1_epi8(0x80);
+  const __m128i t80 = _mm_set1_epi8(static_cast<int8_t>(0x80));
   const __m128i t1 = _mm_set1_epi8(0x1);
   const __m128i qp1qp0 = _mm_unpacklo_epi64(qp0, qp1);
   const __m128i qps1qps0 = _mm_xor_si128(qp1qp0, t80);
@@ -1099,7 +1101,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   static_cast<void>(dsp);
 #if DSP_ENABLED_8BPP_SSE4_1(LoopFilterSize4_LoopFilterTypeHorizontal)
@@ -1141,7 +1143,7 @@
 struct LoopFilterFuncs_SSE4_1 {
   LoopFilterFuncs_SSE4_1() = delete;
 
-  static const int kThreshShift = bitdepth - 8;
+  static constexpr int kThreshShift = bitdepth - 8;
 
   static void Vertical4(void* dest, ptrdiff_t stride, int outer_thresh,
                         int inner_thresh, int hev_thresh);
@@ -2190,10 +2192,10 @@
   StoreUnaligned16(dst - 8 + 8 + 3 * stride, x3);
 }
 
-using Defs10bpp = LoopFilterFuncs_SSE4_1<10>;
+using Defs10bpp = LoopFilterFuncs_SSE4_1<kBitdepth10>;
 
 void Init10bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
   assert(dsp != nullptr);
   static_cast<void>(dsp);
 #if DSP_ENABLED_10BPP_SSE4_1(LoopFilterSize4_LoopFilterTypeHorizontal)
@@ -2243,7 +2245,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_SSE4_1
+#else  // !LIBGAV1_ENABLE_SSE4_1
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/x86/loop_filter_sse4.h b/libgav1/src/dsp/x86/loop_filter_sse4.h
index 20ccdef..b8c1fe5 100644
--- a/libgav1/src/dsp/x86/loop_filter_sse4.h
+++ b/libgav1/src/dsp/x86/loop_filter_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_LOOP_FILTER_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_LOOP_FILTER_SSE4_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -36,82 +36,82 @@
 
 #ifndef LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeHorizontal
 #define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeHorizontal
 #define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeHorizontal
 #define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeHorizontal
 #define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical
 #define LIBGAV1_Dsp8bpp_LoopFilterSize4_LoopFilterTypeVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical
 #define LIBGAV1_Dsp8bpp_LoopFilterSize6_LoopFilterTypeVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical
 #define LIBGAV1_Dsp8bpp_LoopFilterSize8_LoopFilterTypeVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical
 #define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeHorizontal
 #define LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeHorizontal
 #define LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeHorizontal
 #define LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeHorizontal
 #define LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeHorizontal \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeVertical
 #define LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeVertical
 #define LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeVertical
 #define LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeVertical
 #define LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeVertical \
-  LIBGAV1_DSP_SSE4_1
+  LIBGAV1_CPU_SSE4_1
 #endif
 
 #endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/loop_restoration_sse4.cc b/libgav1/src/dsp/x86/loop_restoration_sse4.cc
index 29c97ca..34f4ae8 100644
--- a/libgav1/src/dsp/x86/loop_restoration_sse4.cc
+++ b/libgav1/src/dsp/x86/loop_restoration_sse4.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/loop_restoration.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 #include <smmintrin.h>
@@ -25,9 +25,9 @@
 
 #include "src/dsp/common.h"
 #include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/utils/common.h"
-#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
 
 namespace libgav1 {
@@ -35,700 +35,1743 @@
 namespace low_bitdepth {
 namespace {
 
-// Precision of a division table (mtable)
-constexpr int kSgrProjScaleBits = 20;
-constexpr int kSgrProjReciprocalBits = 12;
-// Core selfguided restoration precision bits.
-constexpr int kSgrProjSgrBits = 8;
-// Precision bits of generated values higher than source before projection.
-constexpr int kSgrProjRestoreBits = 4;
-
-// Note: range of wiener filter coefficients.
-// Wiener filter coefficients are symmetric, and their sum is 1 (128).
-// The range of each coefficient:
-// filter[0] = filter[6], 4 bits, min = -5, max = 10.
-// filter[1] = filter[5], 5 bits, min = -23, max = 8.
-// filter[2] = filter[4], 6 bits, min = -17, max = 46.
-// filter[3] = 128 - (filter[0] + filter[1] + filter[2]) * 2.
-// int8_t is used for the sse4 code, so in order to fit in an int8_t, the 128
-// offset must be removed from filter[3].
-// filter[3] = 0 - (filter[0] + filter[1] + filter[2]) * 2.
-// The 128 offset will be added back in the loop.
-inline void PopulateWienerCoefficients(
-    const RestorationUnitInfo& restoration_info, int direction,
-    int8_t* const filter) {
-  filter[3] = 0;
-  for (int i = 0; i < 3; ++i) {
-    const int8_t coeff = restoration_info.wiener_info.filter[direction][i];
-    filter[i] = coeff;
-    filter[6 - i] = coeff;
-    filter[3] -= MultiplyBy2(coeff);
-  }
-
-  // The Wiener filter has only 7 coefficients, but we run it as an 8-tap
-  // filter in SIMD. The 8th coefficient of the filter must be set to 0.
-  filter[7] = 0;
-}
-
-// This function calls LoadUnaligned16() to read 10 bytes from the |source|
-// buffer. Since the LoadUnaligned16() call over-reads 6 bytes, the |source|
-// buffer must be at least (height + kSubPixelTaps - 2) * source_stride + 6
-// bytes long.
-void WienerFilter_SSE4_1(const void* source, void* const dest,
-                         const RestorationUnitInfo& restoration_info,
-                         ptrdiff_t source_stride, ptrdiff_t dest_stride,
-                         int width, int height,
-                         RestorationBuffer* const buffer) {
-  int8_t filter[kSubPixelTaps];
+inline void WienerHorizontalTap7Kernel(const __m128i s[2],
+                                       const __m128i filter[4],
+                                       int16_t* const wiener_buffer) {
   const int limit =
       (1 << (8 + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1;
-  const auto* src = static_cast<const uint8_t*>(source);
-  auto* dst = static_cast<uint8_t*>(dest);
-  const ptrdiff_t buffer_stride = buffer->wiener_buffer_stride;
-  auto* wiener_buffer = buffer->wiener_buffer;
-  // horizontal filtering.
-  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal, filter);
-  const int center_tap = 3;
-  src -= center_tap * source_stride + center_tap;
-
-  const int horizontal_rounding =
+  const int offset =
       1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
-  const __m128i v_horizontal_rounding =
-      _mm_shufflelo_epi16(_mm_cvtsi32_si128(horizontal_rounding), 0);
-  const __m128i v_limit = _mm_shufflelo_epi16(_mm_cvtsi32_si128(limit), 0);
-  const __m128i v_horizontal_filter = LoadLo8(filter);
-  __m128i v_k1k0 = _mm_shufflelo_epi16(v_horizontal_filter, 0x0);
-  __m128i v_k3k2 = _mm_shufflelo_epi16(v_horizontal_filter, 0x55);
-  __m128i v_k5k4 = _mm_shufflelo_epi16(v_horizontal_filter, 0xaa);
-  __m128i v_k7k6 = _mm_shufflelo_epi16(v_horizontal_filter, 0xff);
-  const __m128i v_round_0 = _mm_shufflelo_epi16(
-      _mm_cvtsi32_si128(1 << (kInterRoundBitsHorizontal - 1)), 0);
-  const __m128i v_round_0_shift = _mm_cvtsi32_si128(kInterRoundBitsHorizontal);
-  const __m128i v_offset_shift =
-      _mm_cvtsi32_si128(7 - kInterRoundBitsHorizontal);
+  const __m128i offsets = _mm_set1_epi16(-offset);
+  const __m128i limits = _mm_set1_epi16(limit - offset);
+  const __m128i round = _mm_set1_epi16(1 << (kInterRoundBitsHorizontal - 1));
+  const auto s01 = _mm_alignr_epi8(s[1], s[0], 1);
+  const auto s23 = _mm_alignr_epi8(s[1], s[0], 5);
+  const auto s45 = _mm_alignr_epi8(s[1], s[0], 9);
+  const auto s67 = _mm_alignr_epi8(s[1], s[0], 13);
+  const __m128i madd01 = _mm_maddubs_epi16(s01, filter[0]);
+  const __m128i madd23 = _mm_maddubs_epi16(s23, filter[1]);
+  const __m128i madd45 = _mm_maddubs_epi16(s45, filter[2]);
+  const __m128i madd67 = _mm_maddubs_epi16(s67, filter[3]);
+  const __m128i madd0123 = _mm_add_epi16(madd01, madd23);
+  const __m128i madd4567 = _mm_add_epi16(madd45, madd67);
+  // The sum range here is [-128 * 255, 90 * 255].
+  const __m128i madd = _mm_add_epi16(madd0123, madd4567);
+  const __m128i sum = _mm_add_epi16(madd, round);
+  const __m128i rounded_sum0 = _mm_srai_epi16(sum, kInterRoundBitsHorizontal);
+  // Calculate scaled down offset correction, and add to sum here to prevent
+  // signed 16 bit outranging.
+  const __m128i s_3x128 =
+      _mm_slli_epi16(_mm_srli_epi16(s23, 8), 7 - kInterRoundBitsHorizontal);
+  const __m128i rounded_sum1 = _mm_add_epi16(rounded_sum0, s_3x128);
+  const __m128i d0 = _mm_max_epi16(rounded_sum1, offsets);
+  const __m128i d1 = _mm_min_epi16(d0, limits);
+  StoreAligned16(wiener_buffer, d1);
+}
 
-  int y = 0;
+inline void WienerHorizontalTap5Kernel(const __m128i s[2],
+                                       const __m128i filter[3],
+                                       int16_t* const wiener_buffer) {
+  const int limit =
+      (1 << (8 + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1;
+  const int offset =
+      1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
+  const __m128i offsets = _mm_set1_epi16(-offset);
+  const __m128i limits = _mm_set1_epi16(limit - offset);
+  const __m128i round = _mm_set1_epi16(1 << (kInterRoundBitsHorizontal - 1));
+  const auto s01 = _mm_alignr_epi8(s[1], s[0], 1);
+  const auto s23 = _mm_alignr_epi8(s[1], s[0], 5);
+  const auto s45 = _mm_alignr_epi8(s[1], s[0], 9);
+  const __m128i madd01 = _mm_maddubs_epi16(s01, filter[0]);
+  const __m128i madd23 = _mm_maddubs_epi16(s23, filter[1]);
+  const __m128i madd45 = _mm_maddubs_epi16(s45, filter[2]);
+  const __m128i madd0123 = _mm_add_epi16(madd01, madd23);
+  // The sum range here is [-128 * 255, 90 * 255].
+  const __m128i madd = _mm_add_epi16(madd0123, madd45);
+  const __m128i sum = _mm_add_epi16(madd, round);
+  const __m128i rounded_sum0 = _mm_srai_epi16(sum, kInterRoundBitsHorizontal);
+  // Calculate scaled down offset correction, and add to sum here to prevent
+  // signed 16 bit outranging.
+  const __m128i s_3x128 =
+      _mm_srli_epi16(_mm_slli_epi16(s23, 8), kInterRoundBitsHorizontal + 1);
+  const __m128i rounded_sum1 = _mm_add_epi16(rounded_sum0, s_3x128);
+  const __m128i d0 = _mm_max_epi16(rounded_sum1, offsets);
+  const __m128i d1 = _mm_min_epi16(d0, limits);
+  StoreAligned16(wiener_buffer, d1);
+}
+
+inline void WienerHorizontalTap3Kernel(const __m128i s[2],
+                                       const __m128i filter[2],
+                                       int16_t* const wiener_buffer) {
+  const int limit =
+      (1 << (8 + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1;
+  const int offset =
+      1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
+  const __m128i offsets = _mm_set1_epi16(-offset);
+  const __m128i limits = _mm_set1_epi16(limit - offset);
+  const __m128i round = _mm_set1_epi16(1 << (kInterRoundBitsHorizontal - 1));
+  const auto s01 = _mm_alignr_epi8(s[1], s[0], 1);
+  const auto s23 = _mm_alignr_epi8(s[1], s[0], 5);
+  const __m128i madd01 = _mm_maddubs_epi16(s01, filter[0]);
+  const __m128i madd23 = _mm_maddubs_epi16(s23, filter[1]);
+  // The sum range here is [-128 * 255, 90 * 255].
+  const __m128i madd = _mm_add_epi16(madd01, madd23);
+  const __m128i sum = _mm_add_epi16(madd, round);
+  const __m128i rounded_sum0 = _mm_srai_epi16(sum, kInterRoundBitsHorizontal);
+  // Calculate scaled down offset correction, and add to sum here to prevent
+  // signed 16 bit outranging.
+  const __m128i s_3x128 =
+      _mm_slli_epi16(_mm_srli_epi16(s01, 8), 7 - kInterRoundBitsHorizontal);
+  const __m128i rounded_sum1 = _mm_add_epi16(rounded_sum0, s_3x128);
+  const __m128i d0 = _mm_max_epi16(rounded_sum1, offsets);
+  const __m128i d1 = _mm_min_epi16(d0, limits);
+  StoreAligned16(wiener_buffer, d1);
+}
+
+inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 const __m128i coefficients,
+                                 int16_t** const wiener_buffer) {
+  __m128i filter[4];
+  filter[0] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0200));
+  filter[1] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0604));
+  filter[2] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0204));
+  filter[3] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x8000));
+  int y = height;
   do {
-    int x = 0;
+    const __m128i s0 = LoadUnaligned16(src);
+    __m128i ss[4];
+    ss[0] = _mm_unpacklo_epi8(s0, s0);
+    ss[1] = _mm_unpackhi_epi8(s0, s0);
+    ptrdiff_t x = 0;
     do {
-      // Run the Wiener filter on four sets of source samples at a time:
-      //   src[x + 0] ... src[x + 6]
-      //   src[x + 1] ... src[x + 7]
-      //   src[x + 2] ... src[x + 8]
-      //   src[x + 3] ... src[x + 9]
-
-      // Read 10 bytes (from src[x] to src[x + 9]). We over-read 6 bytes but
-      // their results are discarded.
-      const __m128i v_src = LoadUnaligned16(&src[x]);
-      const __m128i v_src_dup_lo = _mm_unpacklo_epi8(v_src, v_src);
-      const __m128i v_src_dup_hi = _mm_unpackhi_epi8(v_src, v_src);
-      const __m128i v_src_10 = _mm_alignr_epi8(v_src_dup_hi, v_src_dup_lo, 1);
-      const __m128i v_src_32 = _mm_alignr_epi8(v_src_dup_hi, v_src_dup_lo, 5);
-      const __m128i v_src_54 = _mm_alignr_epi8(v_src_dup_hi, v_src_dup_lo, 9);
-      // Shift right by 12 bytes instead of 13 bytes so that src[x + 10] is not
-      // shifted into the low 8 bytes of v_src_66.
-      const __m128i v_src_66 = _mm_alignr_epi8(v_src_dup_hi, v_src_dup_lo, 12);
-      const __m128i v_madd_10 = _mm_maddubs_epi16(v_src_10, v_k1k0);
-      const __m128i v_madd_32 = _mm_maddubs_epi16(v_src_32, v_k3k2);
-      const __m128i v_madd_54 = _mm_maddubs_epi16(v_src_54, v_k5k4);
-      const __m128i v_madd_76 = _mm_maddubs_epi16(v_src_66, v_k7k6);
-      const __m128i v_sum_3210 = _mm_add_epi16(v_madd_10, v_madd_32);
-      const __m128i v_sum_7654 = _mm_add_epi16(v_madd_54, v_madd_76);
-      // The sum range here is [-128 * 255, 90 * 255].
-      const __m128i v_sum_76543210 = _mm_add_epi16(v_sum_7654, v_sum_3210);
-      const __m128i v_sum = _mm_add_epi16(v_sum_76543210, v_round_0);
-      const __m128i v_rounded_sum0 = _mm_sra_epi16(v_sum, v_round_0_shift);
-      // Add scaled down horizontal round here to prevent signed 16 bit
-      // outranging
-      const __m128i v_rounded_sum1 =
-          _mm_add_epi16(v_rounded_sum0, v_horizontal_rounding);
-      // Zero out the even bytes, calculate scaled down offset correction, and
-      // add to sum here to prevent signed 16 bit outranging.
-      // (src[3] * 128) >> kInterRoundBitsHorizontal
-      const __m128i v_src_3x128 =
-          _mm_sll_epi16(_mm_srli_epi16(v_src_32, 8), v_offset_shift);
-      const __m128i v_rounded_sum = _mm_add_epi16(v_rounded_sum1, v_src_3x128);
-      const __m128i v_a = _mm_max_epi16(v_rounded_sum, _mm_setzero_si128());
-      const __m128i v_b = _mm_min_epi16(v_a, v_limit);
-      StoreLo8(&wiener_buffer[x], v_b);
-      x += 4;
+      const __m128i s1 = LoadUnaligned16(src + x + 16);
+      ss[2] = _mm_unpacklo_epi8(s1, s1);
+      ss[3] = _mm_unpackhi_epi8(s1, s1);
+      WienerHorizontalTap7Kernel(ss + 0, filter, *wiener_buffer + x + 0);
+      WienerHorizontalTap7Kernel(ss + 1, filter, *wiener_buffer + x + 8);
+      ss[0] = ss[2];
+      ss[1] = ss[3];
+      x += 16;
     } while (x < width);
-    src += source_stride;
-    wiener_buffer += buffer_stride;
-  } while (++y < height + kSubPixelTaps - 2);
+    src += src_stride;
+    *wiener_buffer += width;
+  } while (--y != 0);
+}
 
-  wiener_buffer = buffer->wiener_buffer;
-  // vertical filtering.
-  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical, filter);
-
-  const int vertical_rounding = -(1 << (8 + kInterRoundBitsVertical - 1));
-  const __m128i v_vertical_rounding =
-      _mm_shuffle_epi32(_mm_cvtsi32_si128(vertical_rounding), 0);
-  const __m128i v_offset_correction = _mm_set_epi16(0, 0, 0, 0, 128, 0, 0, 0);
-  const __m128i v_round_1 = _mm_shuffle_epi32(
-      _mm_cvtsi32_si128(1 << (kInterRoundBitsVertical - 1)), 0);
-  const __m128i v_round_1_shift = _mm_cvtsi32_si128(kInterRoundBitsVertical);
-  const __m128i v_vertical_filter0 = _mm_cvtepi8_epi16(LoadLo8(filter));
-  const __m128i v_vertical_filter =
-      _mm_add_epi16(v_vertical_filter0, v_offset_correction);
-  v_k1k0 = _mm_shuffle_epi32(v_vertical_filter, 0x0);
-  v_k3k2 = _mm_shuffle_epi32(v_vertical_filter, 0x55);
-  v_k5k4 = _mm_shuffle_epi32(v_vertical_filter, 0xaa);
-  v_k7k6 = _mm_shuffle_epi32(v_vertical_filter, 0xff);
-  y = 0;
+inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 const __m128i coefficients,
+                                 int16_t** const wiener_buffer) {
+  __m128i filter[3];
+  filter[0] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0402));
+  filter[1] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0406));
+  filter[2] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x8002));
+  int y = height;
   do {
-    int x = 0;
+    const __m128i s0 = LoadUnaligned16(src);
+    __m128i ss[4];
+    ss[0] = _mm_unpacklo_epi8(s0, s0);
+    ss[1] = _mm_unpackhi_epi8(s0, s0);
+    ptrdiff_t x = 0;
     do {
-      const __m128i v_wb_0 = LoadLo8(&wiener_buffer[0 * buffer_stride + x]);
-      const __m128i v_wb_1 = LoadLo8(&wiener_buffer[1 * buffer_stride + x]);
-      const __m128i v_wb_2 = LoadLo8(&wiener_buffer[2 * buffer_stride + x]);
-      const __m128i v_wb_3 = LoadLo8(&wiener_buffer[3 * buffer_stride + x]);
-      const __m128i v_wb_4 = LoadLo8(&wiener_buffer[4 * buffer_stride + x]);
-      const __m128i v_wb_5 = LoadLo8(&wiener_buffer[5 * buffer_stride + x]);
-      const __m128i v_wb_6 = LoadLo8(&wiener_buffer[6 * buffer_stride + x]);
-      const __m128i v_wb_10 = _mm_unpacklo_epi16(v_wb_0, v_wb_1);
-      const __m128i v_wb_32 = _mm_unpacklo_epi16(v_wb_2, v_wb_3);
-      const __m128i v_wb_54 = _mm_unpacklo_epi16(v_wb_4, v_wb_5);
-      const __m128i v_wb_76 = _mm_unpacklo_epi16(v_wb_6, _mm_setzero_si128());
-      const __m128i v_madd_10 = _mm_madd_epi16(v_wb_10, v_k1k0);
-      const __m128i v_madd_32 = _mm_madd_epi16(v_wb_32, v_k3k2);
-      const __m128i v_madd_54 = _mm_madd_epi16(v_wb_54, v_k5k4);
-      const __m128i v_madd_76 = _mm_madd_epi16(v_wb_76, v_k7k6);
-      const __m128i v_sum_3210 = _mm_add_epi32(v_madd_10, v_madd_32);
-      const __m128i v_sum_7654 = _mm_add_epi32(v_madd_54, v_madd_76);
-      const __m128i v_sum_76543210 = _mm_add_epi32(v_sum_7654, v_sum_3210);
-      const __m128i v_sum = _mm_add_epi32(v_sum_76543210, v_vertical_rounding);
-      const __m128i v_rounded_sum =
-          _mm_sra_epi32(_mm_add_epi32(v_sum, v_round_1), v_round_1_shift);
-      const __m128i v_a = _mm_packs_epi32(v_rounded_sum, v_rounded_sum);
-      const __m128i v_b = _mm_packus_epi16(v_a, v_a);
-      Store4(&dst[x], v_b);
-      x += 4;
+      const __m128i s1 = LoadUnaligned16(src + x + 16);
+      ss[2] = _mm_unpacklo_epi8(s1, s1);
+      ss[3] = _mm_unpackhi_epi8(s1, s1);
+      WienerHorizontalTap5Kernel(ss + 0, filter, *wiener_buffer + x + 0);
+      WienerHorizontalTap5Kernel(ss + 1, filter, *wiener_buffer + x + 8);
+      ss[0] = ss[2];
+      ss[1] = ss[3];
+      x += 16;
     } while (x < width);
-    dst += dest_stride;
-    wiener_buffer += buffer_stride;
-  } while (++y < height);
+    src += src_stride;
+    *wiener_buffer += width;
+  } while (--y != 0);
 }
 
-// Section 7.17.3.
-// a2: range [1, 256].
-// if (z >= 255)
-//   a2 = 256;
-// else if (z == 0)
-//   a2 = 1;
-// else
-//   a2 = ((z << kSgrProjSgrBits) + (z >> 1)) / (z + 1);
-constexpr int kXByXPlus1[256] = {
-    1,   128, 171, 192, 205, 213, 219, 224, 228, 230, 233, 235, 236, 238, 239,
-    240, 241, 242, 243, 243, 244, 244, 245, 245, 246, 246, 247, 247, 247, 247,
-    248, 248, 248, 248, 249, 249, 249, 249, 249, 250, 250, 250, 250, 250, 250,
-    250, 251, 251, 251, 251, 251, 251, 251, 251, 251, 251, 252, 252, 252, 252,
-    252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 252, 253, 253,
-    253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253,
-    253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 253, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
-    254, 254, 254, 254, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
-    256};
-
-inline __m128i HorizontalAddVerticalSumsRadius1(const uint32_t* vert_sums) {
-  // Horizontally add vertical sums to get total box sum.
-  const __m128i v_sums_3210 = LoadUnaligned16(&vert_sums[0]);
-  const __m128i v_sums_7654 = LoadUnaligned16(&vert_sums[4]);
-  const __m128i v_sums_4321 = _mm_alignr_epi8(v_sums_7654, v_sums_3210, 4);
-  const __m128i v_sums_5432 = _mm_alignr_epi8(v_sums_7654, v_sums_3210, 8);
-  const __m128i v_s0 = _mm_add_epi32(v_sums_3210, v_sums_4321);
-  const __m128i v_s1 = _mm_add_epi32(v_s0, v_sums_5432);
-  return v_s1;
-}
-
-inline __m128i HorizontalAddVerticalSumsRadius2(const uint32_t* vert_sums) {
-  // Horizontally add vertical sums to get total box sum.
-  const __m128i v_sums_3210 = LoadUnaligned16(&vert_sums[0]);
-  const __m128i v_sums_7654 = LoadUnaligned16(&vert_sums[4]);
-  const __m128i v_sums_4321 = _mm_alignr_epi8(v_sums_7654, v_sums_3210, 4);
-  const __m128i v_sums_5432 = _mm_alignr_epi8(v_sums_7654, v_sums_3210, 8);
-  const __m128i v_sums_6543 = _mm_alignr_epi8(v_sums_7654, v_sums_3210, 12);
-  const __m128i v_s0 = _mm_add_epi32(v_sums_3210, v_sums_4321);
-  const __m128i v_s1 = _mm_add_epi32(v_s0, v_sums_5432);
-  const __m128i v_s2 = _mm_add_epi32(v_s1, v_sums_6543);
-  const __m128i v_s3 = _mm_add_epi32(v_s2, v_sums_7654);
-  return v_s3;
-}
-
-void BoxFilterPreProcessRadius1_SSE4_1(
-    const uint8_t* const src, ptrdiff_t stride, int width, int height,
-    uint32_t s, uint32_t* intermediate_result[2], ptrdiff_t array_stride,
-    uint32_t* vertical_sums, uint32_t* vertical_sum_of_squares) {
-  assert(s != 0);
-  const uint32_t n = 9;
-  const uint32_t one_over_n = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
-  const __m128i v_one_over_n =
-      _mm_shuffle_epi32(_mm_cvtsi32_si128(one_over_n), 0);
-  const __m128i v_sgrbits =
-      _mm_shuffle_epi32(_mm_cvtsi32_si128(1 << kSgrProjSgrBits), 0);
-
-#if LIBGAV1_MSAN
-  // Over-reads occur in the x loop, so set to a known value.
-  memset(&vertical_sums[width], 0, 8 * sizeof(vertical_sums[0]));
-  memset(&vertical_sum_of_squares[width], 0,
-         8 * sizeof(vertical_sum_of_squares[0]));
-#endif
-
-  // Calculate intermediate results, including one-pixel border, for example,
-  // if unit size is 64x64, we calculate 66x66 pixels.
-  int y = -1;
+inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 const __m128i coefficients,
+                                 int16_t** const wiener_buffer) {
+  __m128i filter[2];
+  filter[0] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x0604));
+  filter[1] = _mm_shuffle_epi8(coefficients, _mm_set1_epi16(0x8004));
+  int y = height;
   do {
-    const uint8_t* top_left = &src[(y - 1) * stride - 2];
-    // Calculate the box vertical sums for each x position.
-    int vsx = -2;
+    const __m128i s0 = LoadUnaligned16(src);
+    __m128i ss[4];
+    ss[0] = _mm_unpacklo_epi8(s0, s0);
+    ss[1] = _mm_unpackhi_epi8(s0, s0);
+    ptrdiff_t x = 0;
     do {
-      const __m128i v_box0 = _mm_cvtepu8_epi32(Load4(top_left));
-      const __m128i v_box1 = _mm_cvtepu8_epi32(Load4(top_left + stride));
-      const __m128i v_box2 = _mm_cvtepu8_epi32(Load4(top_left + stride * 2));
-      const __m128i v_sqr0 = _mm_mullo_epi32(v_box0, v_box0);
-      const __m128i v_sqr1 = _mm_mullo_epi32(v_box1, v_box1);
-      const __m128i v_sqr2 = _mm_mullo_epi32(v_box2, v_box2);
-      const __m128i v_a01 = _mm_add_epi32(v_sqr0, v_sqr1);
-      const __m128i v_a012 = _mm_add_epi32(v_a01, v_sqr2);
-      const __m128i v_b01 = _mm_add_epi32(v_box0, v_box1);
-      const __m128i v_b012 = _mm_add_epi32(v_b01, v_box2);
-      StoreUnaligned16(&vertical_sum_of_squares[vsx], v_a012);
-      StoreUnaligned16(&vertical_sums[vsx], v_b012);
-      top_left += 4;
-      vsx += 4;
-    } while (vsx <= width + 1);
-
-    int x = -1;
-    do {
-      const __m128i v_a =
-          HorizontalAddVerticalSumsRadius1(&vertical_sum_of_squares[x - 1]);
-      const __m128i v_b =
-          HorizontalAddVerticalSumsRadius1(&vertical_sums[x - 1]);
-      // -----------------------
-      // calc p, z, a2
-      // -----------------------
-      const __m128i v_255 = _mm_shuffle_epi32(_mm_cvtsi32_si128(255), 0);
-      const __m128i v_n = _mm_shuffle_epi32(_mm_cvtsi32_si128(n), 0);
-      const __m128i v_s = _mm_shuffle_epi32(_mm_cvtsi32_si128(s), 0);
-      const __m128i v_dxd = _mm_mullo_epi32(v_b, v_b);
-      const __m128i v_axn = _mm_mullo_epi32(v_a, v_n);
-      const __m128i v_p = _mm_sub_epi32(v_axn, v_dxd);
-      const __m128i v_z = _mm_min_epi32(
-          v_255, RightShiftWithRounding_U32(_mm_mullo_epi32(v_p, v_s),
-                                            kSgrProjScaleBits));
-      const __m128i v_a2 = _mm_set_epi32(kXByXPlus1[_mm_extract_epi32(v_z, 3)],
-                                         kXByXPlus1[_mm_extract_epi32(v_z, 2)],
-                                         kXByXPlus1[_mm_extract_epi32(v_z, 1)],
-                                         kXByXPlus1[_mm_extract_epi32(v_z, 0)]);
-      // -----------------------
-      // calc b2 and store
-      // -----------------------
-      const __m128i v_sgrbits_sub_a2 = _mm_sub_epi32(v_sgrbits, v_a2);
-      const __m128i v_b2 =
-          _mm_mullo_epi32(v_sgrbits_sub_a2, _mm_mullo_epi32(v_b, v_one_over_n));
-      StoreUnaligned16(&intermediate_result[0][x], v_a2);
-      StoreUnaligned16(
-          &intermediate_result[1][x],
-          RightShiftWithRounding_U32(v_b2, kSgrProjReciprocalBits));
-      x += 4;
-    } while (x <= width);
-    intermediate_result[0] += array_stride;
-    intermediate_result[1] += array_stride;
-  } while (++y <= height);
+      const __m128i s1 = LoadUnaligned16(src + x + 16);
+      ss[2] = _mm_unpacklo_epi8(s1, s1);
+      ss[3] = _mm_unpackhi_epi8(s1, s1);
+      WienerHorizontalTap3Kernel(ss + 0, filter, *wiener_buffer + x + 0);
+      WienerHorizontalTap3Kernel(ss + 1, filter, *wiener_buffer + x + 8);
+      ss[0] = ss[2];
+      ss[1] = ss[3];
+      x += 16;
+    } while (x < width);
+    src += src_stride;
+    *wiener_buffer += width;
+  } while (--y != 0);
 }
 
-void BoxFilterPreProcessRadius2_SSE4_1(
-    const uint8_t* const src, ptrdiff_t stride, int width, int height,
-    uint32_t s, uint32_t* intermediate_result[2], ptrdiff_t array_stride,
-    uint32_t* vertical_sums, uint32_t* vertical_sum_of_squares) {
-  assert(s != 0);
-  const uint32_t n = 25;
-  const uint32_t one_over_n = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
-  const __m128i v_one_over_n =
-      _mm_shuffle_epi32(_mm_cvtsi32_si128(one_over_n), 0);
-  const __m128i v_sgrbits =
-      _mm_shuffle_epi32(_mm_cvtsi32_si128(1 << kSgrProjSgrBits), 0);
-
-  // Calculate intermediate results, including one-pixel border, for example,
-  // if unit size is 64x64, we calculate 66x66 pixels.
-  int y = -1;
+inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 int16_t** const wiener_buffer) {
+  int y = height;
   do {
-    // Calculate the box vertical sums for each x position.
-    const uint8_t* top_left = &src[(y - 2) * stride - 3];
-    int vsx = -3;
+    ptrdiff_t x = 0;
     do {
-      const __m128i v_box0 = _mm_cvtepu8_epi32(Load4(top_left));
-      const __m128i v_box1 = _mm_cvtepu8_epi32(Load4(top_left + stride));
-      const __m128i v_box2 = _mm_cvtepu8_epi32(Load4(top_left + stride * 2));
-      const __m128i v_box3 = _mm_cvtepu8_epi32(Load4(top_left + stride * 3));
-      const __m128i v_box4 = _mm_cvtepu8_epi32(Load4(top_left + stride * 4));
-      const __m128i v_sqr0 = _mm_mullo_epi32(v_box0, v_box0);
-      const __m128i v_sqr1 = _mm_mullo_epi32(v_box1, v_box1);
-      const __m128i v_sqr2 = _mm_mullo_epi32(v_box2, v_box2);
-      const __m128i v_sqr3 = _mm_mullo_epi32(v_box3, v_box3);
-      const __m128i v_sqr4 = _mm_mullo_epi32(v_box4, v_box4);
-      const __m128i v_a01 = _mm_add_epi32(v_sqr0, v_sqr1);
-      const __m128i v_a012 = _mm_add_epi32(v_a01, v_sqr2);
-      const __m128i v_a0123 = _mm_add_epi32(v_a012, v_sqr3);
-      const __m128i v_a01234 = _mm_add_epi32(v_a0123, v_sqr4);
-      const __m128i v_b01 = _mm_add_epi32(v_box0, v_box1);
-      const __m128i v_b012 = _mm_add_epi32(v_b01, v_box2);
-      const __m128i v_b0123 = _mm_add_epi32(v_b012, v_box3);
-      const __m128i v_b01234 = _mm_add_epi32(v_b0123, v_box4);
-      StoreUnaligned16(&vertical_sum_of_squares[vsx], v_a01234);
-      StoreUnaligned16(&vertical_sums[vsx], v_b01234);
-      top_left += 4;
-      vsx += 4;
-    } while (vsx <= width + 2);
-
-    int x = -1;
-    do {
-      const __m128i v_a =
-          HorizontalAddVerticalSumsRadius2(&vertical_sum_of_squares[x - 2]);
-      const __m128i v_b =
-          HorizontalAddVerticalSumsRadius2(&vertical_sums[x - 2]);
-      // -----------------------
-      // calc p, z, a2
-      // -----------------------
-      const __m128i v_255 = _mm_shuffle_epi32(_mm_cvtsi32_si128(255), 0);
-      const __m128i v_n = _mm_shuffle_epi32(_mm_cvtsi32_si128(n), 0);
-      const __m128i v_s = _mm_shuffle_epi32(_mm_cvtsi32_si128(s), 0);
-      const __m128i v_dxd = _mm_mullo_epi32(v_b, v_b);
-      const __m128i v_axn = _mm_mullo_epi32(v_a, v_n);
-      const __m128i v_p = _mm_sub_epi32(v_axn, v_dxd);
-      const __m128i v_z = _mm_min_epi32(
-          v_255, RightShiftWithRounding_U32(_mm_mullo_epi32(v_p, v_s),
-                                            kSgrProjScaleBits));
-      const __m128i v_a2 = _mm_set_epi32(kXByXPlus1[_mm_extract_epi32(v_z, 3)],
-                                         kXByXPlus1[_mm_extract_epi32(v_z, 2)],
-                                         kXByXPlus1[_mm_extract_epi32(v_z, 1)],
-                                         kXByXPlus1[_mm_extract_epi32(v_z, 0)]);
-      // -----------------------
-      // calc b2 and store
-      // -----------------------
-      const __m128i v_sgrbits_sub_a2 = _mm_sub_epi32(v_sgrbits, v_a2);
-      const __m128i v_b2 =
-          _mm_mullo_epi32(v_sgrbits_sub_a2, _mm_mullo_epi32(v_b, v_one_over_n));
-      StoreUnaligned16(&intermediate_result[0][x], v_a2);
-      StoreUnaligned16(
-          &intermediate_result[1][x],
-          RightShiftWithRounding_U32(v_b2, kSgrProjReciprocalBits));
-      x += 4;
-    } while (x <= width);
-    intermediate_result[0] += 2 * array_stride;
-    intermediate_result[1] += 2 * array_stride;
-    y += 2;
-  } while (y <= height);
+      const __m128i s = LoadUnaligned16(src + x);
+      const __m128i s0 = _mm_unpacklo_epi8(s, _mm_setzero_si128());
+      const __m128i s1 = _mm_unpackhi_epi8(s, _mm_setzero_si128());
+      const __m128i d0 = _mm_slli_epi16(s0, 4);
+      const __m128i d1 = _mm_slli_epi16(s1, 4);
+      StoreAligned16(*wiener_buffer + x + 0, d0);
+      StoreAligned16(*wiener_buffer + x + 8, d1);
+      x += 16;
+    } while (x < width);
+    src += src_stride;
+    *wiener_buffer += width;
+  } while (--y != 0);
 }
 
-void BoxFilterPreProcess_SSE4_1(const RestorationUnitInfo& restoration_info,
-                                const uint8_t* const src, ptrdiff_t stride,
-                                int width, int height, int pass,
-                                RestorationBuffer* const buffer) {
-  uint32_t vertical_sums_buf[kRestorationProcessingUnitSize +
-                             2 * kRestorationBorder + kRestorationPadding];
-  uint32_t vertical_sum_of_squares_buf[kRestorationProcessingUnitSize +
-                                       2 * kRestorationBorder +
-                                       kRestorationPadding];
-  uint32_t* vertical_sums = &vertical_sums_buf[4];
-  uint32_t* vertical_sum_of_squares = &vertical_sum_of_squares_buf[4];
-  const ptrdiff_t array_stride = buffer->box_filter_process_intermediate_stride;
-  // The size of the intermediate result buffer is the size of the filter area
-  // plus horizontal (3) and vertical (3) padding. The processing start point
-  // is the filter area start point -1 row and -1 column. Therefore we need to
-  // set offset and use the intermediate_result as the start point for
-  // processing.
-  const ptrdiff_t intermediate_buffer_offset =
-      kRestorationBorder * array_stride + kRestorationBorder;
-  uint32_t* intermediate_result[2] = {
-      buffer->box_filter_process_intermediate[0] + intermediate_buffer_offset -
-          array_stride,
-      buffer->box_filter_process_intermediate[1] + intermediate_buffer_offset -
-          array_stride};
-  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
-  if (pass == 0) {
-    assert(kSgrProjParams[sgr_proj_index][0] == 2);
-    BoxFilterPreProcessRadius2_SSE4_1(src, stride, width, height,
-                                      kSgrScaleParameter[sgr_proj_index][0],
-                                      intermediate_result, array_stride,
-                                      vertical_sums, vertical_sum_of_squares);
-  } else {
-    assert(kSgrProjParams[sgr_proj_index][2] == 1);
-    BoxFilterPreProcessRadius1_SSE4_1(src, stride, width, height,
-                                      kSgrScaleParameter[sgr_proj_index][1],
-                                      intermediate_result, array_stride,
-                                      vertical_sums, vertical_sum_of_squares);
+inline __m128i WienerVertical7(const __m128i a[2], const __m128i filter[2]) {
+  const __m128i round = _mm_set1_epi32(1 << (kInterRoundBitsVertical - 1));
+  const __m128i madd0 = _mm_madd_epi16(a[0], filter[0]);
+  const __m128i madd1 = _mm_madd_epi16(a[1], filter[1]);
+  const __m128i sum0 = _mm_add_epi32(round, madd0);
+  const __m128i sum1 = _mm_add_epi32(sum0, madd1);
+  return _mm_srai_epi32(sum1, kInterRoundBitsVertical);
+}
+
+inline __m128i WienerVertical5(const __m128i a[2], const __m128i filter[2]) {
+  const __m128i madd0 = _mm_madd_epi16(a[0], filter[0]);
+  const __m128i madd1 = _mm_madd_epi16(a[1], filter[1]);
+  const __m128i sum = _mm_add_epi32(madd0, madd1);
+  return _mm_srai_epi32(sum, kInterRoundBitsVertical);
+}
+
+inline __m128i WienerVertical3(const __m128i a, const __m128i filter) {
+  const __m128i round = _mm_set1_epi32(1 << (kInterRoundBitsVertical - 1));
+  const __m128i madd = _mm_madd_epi16(a, filter);
+  const __m128i sum = _mm_add_epi32(round, madd);
+  return _mm_srai_epi32(sum, kInterRoundBitsVertical);
+}
+
+inline __m128i WienerVerticalFilter7(const __m128i a[7],
+                                     const __m128i filter[2]) {
+  __m128i b[2];
+  const __m128i a06 = _mm_add_epi16(a[0], a[6]);
+  const __m128i a15 = _mm_add_epi16(a[1], a[5]);
+  const __m128i a24 = _mm_add_epi16(a[2], a[4]);
+  b[0] = _mm_unpacklo_epi16(a06, a15);
+  b[1] = _mm_unpacklo_epi16(a24, a[3]);
+  const __m128i sum0 = WienerVertical7(b, filter);
+  b[0] = _mm_unpackhi_epi16(a06, a15);
+  b[1] = _mm_unpackhi_epi16(a24, a[3]);
+  const __m128i sum1 = WienerVertical7(b, filter);
+  return _mm_packs_epi32(sum0, sum1);
+}
+
+inline __m128i WienerVerticalFilter5(const __m128i a[5],
+                                     const __m128i filter[2]) {
+  const __m128i round = _mm_set1_epi16(1 << (kInterRoundBitsVertical - 1));
+  __m128i b[2];
+  const __m128i a04 = _mm_add_epi16(a[0], a[4]);
+  const __m128i a13 = _mm_add_epi16(a[1], a[3]);
+  b[0] = _mm_unpacklo_epi16(a04, a13);
+  b[1] = _mm_unpacklo_epi16(a[2], round);
+  const __m128i sum0 = WienerVertical5(b, filter);
+  b[0] = _mm_unpackhi_epi16(a04, a13);
+  b[1] = _mm_unpackhi_epi16(a[2], round);
+  const __m128i sum1 = WienerVertical5(b, filter);
+  return _mm_packs_epi32(sum0, sum1);
+}
+
+inline __m128i WienerVerticalFilter3(const __m128i a[3], const __m128i filter) {
+  __m128i b;
+  const __m128i a02 = _mm_add_epi16(a[0], a[2]);
+  b = _mm_unpacklo_epi16(a02, a[1]);
+  const __m128i sum0 = WienerVertical3(b, filter);
+  b = _mm_unpackhi_epi16(a02, a[1]);
+  const __m128i sum1 = WienerVertical3(b, filter);
+  return _mm_packs_epi32(sum0, sum1);
+}
+
+inline __m128i WienerVerticalTap7Kernel(const int16_t* wiener_buffer,
+                                        const ptrdiff_t wiener_stride,
+                                        const __m128i filter[2], __m128i a[7]) {
+  a[0] = LoadAligned16(wiener_buffer + 0 * wiener_stride);
+  a[1] = LoadAligned16(wiener_buffer + 1 * wiener_stride);
+  a[2] = LoadAligned16(wiener_buffer + 2 * wiener_stride);
+  a[3] = LoadAligned16(wiener_buffer + 3 * wiener_stride);
+  a[4] = LoadAligned16(wiener_buffer + 4 * wiener_stride);
+  a[5] = LoadAligned16(wiener_buffer + 5 * wiener_stride);
+  a[6] = LoadAligned16(wiener_buffer + 6 * wiener_stride);
+  return WienerVerticalFilter7(a, filter);
+}
+
+inline __m128i WienerVerticalTap5Kernel(const int16_t* wiener_buffer,
+                                        const ptrdiff_t wiener_stride,
+                                        const __m128i filter[2], __m128i a[5]) {
+  a[0] = LoadAligned16(wiener_buffer + 0 * wiener_stride);
+  a[1] = LoadAligned16(wiener_buffer + 1 * wiener_stride);
+  a[2] = LoadAligned16(wiener_buffer + 2 * wiener_stride);
+  a[3] = LoadAligned16(wiener_buffer + 3 * wiener_stride);
+  a[4] = LoadAligned16(wiener_buffer + 4 * wiener_stride);
+  return WienerVerticalFilter5(a, filter);
+}
+
+inline __m128i WienerVerticalTap3Kernel(const int16_t* wiener_buffer,
+                                        const ptrdiff_t wiener_stride,
+                                        const __m128i filter, __m128i a[3]) {
+  a[0] = LoadAligned16(wiener_buffer + 0 * wiener_stride);
+  a[1] = LoadAligned16(wiener_buffer + 1 * wiener_stride);
+  a[2] = LoadAligned16(wiener_buffer + 2 * wiener_stride);
+  return WienerVerticalFilter3(a, filter);
+}
+
+inline void WienerVerticalTap7Kernel2(const int16_t* wiener_buffer,
+                                      const ptrdiff_t wiener_stride,
+                                      const __m128i filter[2], __m128i d[2]) {
+  __m128i a[8];
+  d[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
+  a[7] = LoadAligned16(wiener_buffer + 7 * wiener_stride);
+  d[1] = WienerVerticalFilter7(a + 1, filter);
+}
+
+inline void WienerVerticalTap5Kernel2(const int16_t* wiener_buffer,
+                                      const ptrdiff_t wiener_stride,
+                                      const __m128i filter[2], __m128i d[2]) {
+  __m128i a[6];
+  d[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
+  a[5] = LoadAligned16(wiener_buffer + 5 * wiener_stride);
+  d[1] = WienerVerticalFilter5(a + 1, filter);
+}
+
+inline void WienerVerticalTap3Kernel2(const int16_t* wiener_buffer,
+                                      const ptrdiff_t wiener_stride,
+                                      const __m128i filter, __m128i d[2]) {
+  __m128i a[4];
+  d[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
+  a[3] = LoadAligned16(wiener_buffer + 3 * wiener_stride);
+  d[1] = WienerVerticalFilter3(a + 1, filter);
+}
+
+inline void WienerVerticalTap7(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               const int16_t coefficients[4], uint8_t* dst,
+                               const ptrdiff_t dst_stride) {
+  const __m128i c = LoadLo8(coefficients);
+  __m128i filter[2];
+  filter[0] = _mm_shuffle_epi32(c, 0x0);
+  filter[1] = _mm_shuffle_epi32(c, 0x55);
+  for (int y = height >> 1; y > 0; --y) {
+    ptrdiff_t x = 0;
+    do {
+      __m128i d[2][2];
+      WienerVerticalTap7Kernel2(wiener_buffer + x + 0, width, filter, d[0]);
+      WienerVerticalTap7Kernel2(wiener_buffer + x + 8, width, filter, d[1]);
+      StoreAligned16(dst + x, _mm_packus_epi16(d[0][0], d[1][0]));
+      StoreAligned16(dst + dst_stride + x, _mm_packus_epi16(d[0][1], d[1][1]));
+      x += 16;
+    } while (x < width);
+    dst += 2 * dst_stride;
+    wiener_buffer += 2 * width;
+  }
+
+  if ((height & 1) != 0) {
+    ptrdiff_t x = 0;
+    do {
+      __m128i a[7];
+      const __m128i d0 =
+          WienerVerticalTap7Kernel(wiener_buffer + x + 0, width, filter, a);
+      const __m128i d1 =
+          WienerVerticalTap7Kernel(wiener_buffer + x + 8, width, filter, a);
+      StoreAligned16(dst + x, _mm_packus_epi16(d0, d1));
+      x += 16;
+    } while (x < width);
   }
 }
 
-inline __m128i Sum565Row(const __m128i v_DBCA, const __m128i v_XXFE) {
-  __m128i v_sum = v_DBCA;
-  const __m128i v_EDCB = _mm_alignr_epi8(v_XXFE, v_DBCA, 4);
-  v_sum = _mm_add_epi32(v_sum, v_EDCB);
-  const __m128i v_FEDC = _mm_alignr_epi8(v_XXFE, v_DBCA, 8);
-  v_sum = _mm_add_epi32(v_sum, v_FEDC);
-  //   D C B A x4
-  // + E D C B x4
-  // + F E D C x4
-  v_sum = _mm_slli_epi32(v_sum, 2);
-  // + D C B A
-  v_sum = _mm_add_epi32(v_sum, v_DBCA);  // 5
-  // + E D C B x2
-  v_sum = _mm_add_epi32(v_sum, _mm_slli_epi32(v_EDCB, 1));  // 6
-  // + F E D C
-  return _mm_add_epi32(v_sum, v_FEDC);  // 5
+inline void WienerVerticalTap5(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               const int16_t coefficients[3], uint8_t* dst,
+                               const ptrdiff_t dst_stride) {
+  const __m128i c = Load4(coefficients);
+  __m128i filter[2];
+  filter[0] = _mm_shuffle_epi32(c, 0);
+  filter[1] =
+      _mm_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[2]));
+  for (int y = height >> 1; y > 0; --y) {
+    ptrdiff_t x = 0;
+    do {
+      __m128i d[2][2];
+      WienerVerticalTap5Kernel2(wiener_buffer + x + 0, width, filter, d[0]);
+      WienerVerticalTap5Kernel2(wiener_buffer + x + 8, width, filter, d[1]);
+      StoreAligned16(dst + x, _mm_packus_epi16(d[0][0], d[1][0]));
+      StoreAligned16(dst + dst_stride + x, _mm_packus_epi16(d[0][1], d[1][1]));
+      x += 16;
+    } while (x < width);
+    dst += 2 * dst_stride;
+    wiener_buffer += 2 * width;
+  }
+
+  if ((height & 1) != 0) {
+    ptrdiff_t x = 0;
+    do {
+      __m128i a[5];
+      const __m128i d0 =
+          WienerVerticalTap5Kernel(wiener_buffer + x + 0, width, filter, a);
+      const __m128i d1 =
+          WienerVerticalTap5Kernel(wiener_buffer + x + 8, width, filter, a);
+      StoreAligned16(dst + x, _mm_packus_epi16(d0, d1));
+      x += 16;
+    } while (x < width);
+  }
 }
 
-inline __m128i Process3x3Block_565_Odd(const uint32_t* src, ptrdiff_t stride) {
-  // 0 0 0
-  // 5 6 5
-  // 0 0 0
-  const uint32_t* top_left = src - 1;
-  const __m128i v_src1_lo = LoadUnaligned16(top_left + stride);
-  const __m128i v_src1_hi = LoadLo8(top_left + stride + 4);
-  return Sum565Row(v_src1_lo, v_src1_hi);
+inline void WienerVerticalTap3(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               const int16_t coefficients[2], uint8_t* dst,
+                               const ptrdiff_t dst_stride) {
+  const __m128i filter =
+      _mm_set1_epi32(*reinterpret_cast<const int32_t*>(coefficients));
+  for (int y = height >> 1; y > 0; --y) {
+    ptrdiff_t x = 0;
+    do {
+      __m128i d[2][2];
+      WienerVerticalTap3Kernel2(wiener_buffer + x + 0, width, filter, d[0]);
+      WienerVerticalTap3Kernel2(wiener_buffer + x + 8, width, filter, d[1]);
+      StoreAligned16(dst + x, _mm_packus_epi16(d[0][0], d[1][0]));
+      StoreAligned16(dst + dst_stride + x, _mm_packus_epi16(d[0][1], d[1][1]));
+      x += 16;
+    } while (x < width);
+    dst += 2 * dst_stride;
+    wiener_buffer += 2 * width;
+  }
+
+  if ((height & 1) != 0) {
+    ptrdiff_t x = 0;
+    do {
+      __m128i a[3];
+      const __m128i d0 =
+          WienerVerticalTap3Kernel(wiener_buffer + x + 0, width, filter, a);
+      const __m128i d1 =
+          WienerVerticalTap3Kernel(wiener_buffer + x + 8, width, filter, a);
+      StoreAligned16(dst + x, _mm_packus_epi16(d0, d1));
+      x += 16;
+    } while (x < width);
+  }
 }
 
-inline __m128i Process3x3Block_565_Even(const uint32_t* src, ptrdiff_t stride) {
-  // 5 6 5
-  // 0 0 0
-  // 5 6 5
-  const uint32_t* top_left = src - 1;
-  const __m128i v_src0_lo = LoadUnaligned16(top_left);
-  const __m128i v_src0_hi = LoadLo8(top_left + 4);
-  const __m128i v_src2_lo = LoadUnaligned16(top_left + stride * 2);
-  const __m128i v_src2_hi = LoadLo8(top_left + stride * 2 + 4);
-  const __m128i v_a0 = Sum565Row(v_src0_lo, v_src0_hi);
-  const __m128i v_a2 = Sum565Row(v_src2_lo, v_src2_hi);
-  return _mm_add_epi32(v_a0, v_a2);
+inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
+                                     uint8_t* const dst) {
+  const __m128i a0 = LoadAligned16(wiener_buffer + 0);
+  const __m128i a1 = LoadAligned16(wiener_buffer + 8);
+  const __m128i b0 = _mm_add_epi16(a0, _mm_set1_epi16(8));
+  const __m128i b1 = _mm_add_epi16(a1, _mm_set1_epi16(8));
+  const __m128i c0 = _mm_srai_epi16(b0, 4);
+  const __m128i c1 = _mm_srai_epi16(b1, 4);
+  const __m128i d = _mm_packus_epi16(c0, c1);
+  StoreAligned16(dst, d);
 }
 
-inline __m128i Sum343Row(const __m128i v_DBCA, const __m128i v_XXFE) {
-  __m128i v_sum = v_DBCA;
-  const __m128i v_EDCB = _mm_alignr_epi8(v_XXFE, v_DBCA, 4);
-  v_sum = _mm_add_epi32(v_sum, v_EDCB);
-  const __m128i v_FEDC = _mm_alignr_epi8(v_XXFE, v_DBCA, 8);
-  v_sum = _mm_add_epi32(v_sum, v_FEDC);
-  //   D C B A x4
-  // + E D C B x4
-  // + F E D C x4
-  v_sum = _mm_slli_epi32(v_sum, 2);  // 4
-  // - D C B A
-  v_sum = _mm_sub_epi32(v_sum, v_DBCA);  // 3
-  // - F E D C
-  return _mm_sub_epi32(v_sum, v_FEDC);  // 3
+inline void WienerVerticalTap1(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               uint8_t* dst, const ptrdiff_t dst_stride) {
+  for (int y = height >> 1; y > 0; --y) {
+    ptrdiff_t x = 0;
+    do {
+      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
+      WienerVerticalTap1Kernel(wiener_buffer + width + x, dst + dst_stride + x);
+      x += 16;
+    } while (x < width);
+    dst += 2 * dst_stride;
+    wiener_buffer += 2 * width;
+  }
+
+  if ((height & 1) != 0) {
+    ptrdiff_t x = 0;
+    do {
+      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
+      x += 16;
+    } while (x < width);
+  }
 }
 
-inline __m128i Sum444Row(const __m128i v_DBCA, const __m128i v_XXFE) {
-  __m128i v_sum = v_DBCA;
-  const __m128i v_EDCB = _mm_alignr_epi8(v_XXFE, v_DBCA, 4);
-  v_sum = _mm_add_epi32(v_sum, v_EDCB);
-  const __m128i v_FEDC = _mm_alignr_epi8(v_XXFE, v_DBCA, 8);
-  v_sum = _mm_add_epi32(v_sum, v_FEDC);
-  //   D C B A x4
-  // + E D C B x4
-  // + F E D C x4
-  return _mm_slli_epi32(v_sum, 2);  // 4
+void WienerFilter_SSE4_1(const void* const source, void* const dest,
+                         const RestorationUnitInfo& restoration_info,
+                         const ptrdiff_t source_stride,
+                         const ptrdiff_t dest_stride, const int width,
+                         const int height, RestorationBuffer* const buffer) {
+  constexpr int kCenterTap = kWienerFilterTaps / 2;
+  const int16_t* const number_leading_zero_coefficients =
+      restoration_info.wiener_info.number_leading_zero_coefficients;
+  const int number_rows_to_skip = std::max(
+      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
+      1);
+  const ptrdiff_t wiener_stride = Align(width, 16);
+  int16_t* const wiener_buffer_vertical = buffer->wiener_buffer;
+  // The values are saturated to 13 bits before storing.
+  int16_t* wiener_buffer_horizontal =
+      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;
+
+  // horizontal filtering.
+  // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
+  const int height_horizontal =
+      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
+  const auto* const src = static_cast<const uint8_t*>(source) -
+                          (kCenterTap - number_rows_to_skip) * source_stride;
+  const __m128i c =
+      LoadLo8(restoration_info.wiener_info.filter[WienerInfo::kHorizontal]);
+  // In order to keep the horizontal pass intermediate values within 16 bits we
+  // offset |filter[3]| by 128. The 128 offset will be added back in the loop.
+  const __m128i coefficients_horizontal =
+      _mm_sub_epi16(c, _mm_setr_epi16(0, 0, 0, 128, 0, 0, 0, 0));
+  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
+    WienerHorizontalTap7(src - 3, source_stride, wiener_stride,
+                         height_horizontal, coefficients_horizontal,
+                         &wiener_buffer_horizontal);
+  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
+    WienerHorizontalTap5(src - 2, source_stride, wiener_stride,
+                         height_horizontal, coefficients_horizontal,
+                         &wiener_buffer_horizontal);
+  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
+    // The maximum over-reads happen here.
+    WienerHorizontalTap3(src - 1, source_stride, wiener_stride,
+                         height_horizontal, coefficients_horizontal,
+                         &wiener_buffer_horizontal);
+  } else {
+    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
+    WienerHorizontalTap1(src, source_stride, wiener_stride, height_horizontal,
+                         &wiener_buffer_horizontal);
+  }
+
+  // vertical filtering.
+  // Over-writes up to 15 values.
+  const int16_t* const filter_vertical =
+      restoration_info.wiener_info.filter[WienerInfo::kVertical];
+  auto* dst = static_cast<uint8_t*>(dest);
+  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
+    // Because the top row of |source| is a duplicate of the second row, and the
+    // bottom row of |source| is a duplicate of its above row, we can duplicate
+    // the top and bottom row of |wiener_buffer| accordingly.
+    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
+           sizeof(*wiener_buffer_horizontal) * wiener_stride);
+    memcpy(buffer->wiener_buffer, buffer->wiener_buffer + wiener_stride,
+           sizeof(*buffer->wiener_buffer) * wiener_stride);
+    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
+                       filter_vertical, dst, dest_stride);
+  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
+    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
+                       height, filter_vertical + 1, dst, dest_stride);
+  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
+    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
+                       wiener_stride, height, filter_vertical + 2, dst,
+                       dest_stride);
+  } else {
+    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
+    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
+                       wiener_stride, height, dst, dest_stride);
+  }
 }
 
-inline __m128i Process3x3Block_343(const uint32_t* src, ptrdiff_t stride) {
-  const uint32_t* top_left = src - 1;
-  const __m128i v_ir0_lo = LoadUnaligned16(top_left);
-  const __m128i v_ir0_hi = LoadLo8(top_left + 4);
-  const __m128i v_ir1_lo = LoadUnaligned16(top_left + stride);
-  const __m128i v_ir1_hi = LoadLo8(top_left + stride + 4);
-  const __m128i v_ir2_lo = LoadUnaligned16(top_left + stride * 2);
-  const __m128i v_ir2_hi = LoadLo8(top_left + stride * 2 + 4);
-  const __m128i v_a0 = Sum343Row(v_ir0_lo, v_ir0_hi);
-  const __m128i v_a1 = Sum444Row(v_ir1_lo, v_ir1_hi);
-  const __m128i v_a2 = Sum343Row(v_ir2_lo, v_ir2_hi);
-  return _mm_add_epi32(v_a0, _mm_add_epi32(v_a1, v_a2));
+//------------------------------------------------------------------------------
+// SGR
+
+// Don't use _mm_cvtepu8_epi16() or _mm_cvtepu16_epi32() in the following
+// functions. Some compilers may generate super inefficient code and the whole
+// decoder could be 15% slower.
+
+inline __m128i VaddlLo8(const __m128i src0, const __m128i src1) {
+  const __m128i s0 = _mm_unpacklo_epi8(src0, _mm_setzero_si128());
+  const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128());
+  return _mm_add_epi16(s0, s1);
 }
 
-void BoxFilterProcess_SSE4_1(const RestorationUnitInfo& restoration_info,
-                             const uint8_t* src, ptrdiff_t stride, int width,
-                             int height, RestorationBuffer* const buffer) {
-  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
-  for (int pass = 0; pass < 2; ++pass) {
-    const uint8_t radius = kSgrProjParams[sgr_proj_index][pass * 2];
-    const uint8_t* src_ptr = src;
-    if (radius == 0) continue;
+inline __m128i VaddlHi8(const __m128i src0, const __m128i src1) {
+  const __m128i s0 = _mm_unpackhi_epi8(src0, _mm_setzero_si128());
+  const __m128i s1 = _mm_unpackhi_epi8(src1, _mm_setzero_si128());
+  return _mm_add_epi16(s0, s1);
+}
 
-    BoxFilterPreProcess_SSE4_1(restoration_info, src_ptr, stride, width, height,
-                               pass, buffer);
+inline __m128i VaddlLo16(const __m128i src0, const __m128i src1) {
+  const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128());
+  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
+  return _mm_add_epi32(s0, s1);
+}
 
-    int* filtered_output = buffer->box_filter_process_output[pass];
-    const ptrdiff_t filtered_output_stride =
-        buffer->box_filter_process_output_stride;
-    const ptrdiff_t intermediate_stride =
-        buffer->box_filter_process_intermediate_stride;
-    // Set intermediate buffer start point to the actual start point of
-    // filtering.
-    const ptrdiff_t intermediate_buffer_offset =
-        kRestorationBorder * intermediate_stride + kRestorationBorder;
+inline __m128i VaddlHi16(const __m128i src0, const __m128i src1) {
+  const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128());
+  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
+  return _mm_add_epi32(s0, s1);
+}
 
-    if (pass == 0) {
-      int y = 0;
-      do {
-        const int shift = ((y & 1) != 0) ? 4 : 5;
-        uint32_t* const array_start[2] = {
-            buffer->box_filter_process_intermediate[0] +
-                intermediate_buffer_offset + y * intermediate_stride,
-            buffer->box_filter_process_intermediate[1] +
-                intermediate_buffer_offset + y * intermediate_stride};
-        uint32_t* intermediate_result2[2] = {
-            array_start[0] - intermediate_stride,
-            array_start[1] - intermediate_stride};
-        if ((y & 1) == 0) {  // even row
-          int x = 0;
-          do {
-            // 5 6 5
-            // 0 0 0
-            // 5 6 5
-            const __m128i v_A = Process3x3Block_565_Even(
-                &intermediate_result2[0][x], intermediate_stride);
-            const __m128i v_B = Process3x3Block_565_Even(
-                &intermediate_result2[1][x], intermediate_stride);
-            const __m128i v_src = _mm_cvtepu8_epi32(Load4(src_ptr + x));
-            const __m128i v_v0 = _mm_mullo_epi32(v_A, v_src);
-            const __m128i v_v = _mm_add_epi32(v_v0, v_B);
-            const __m128i v_filtered = RightShiftWithRounding_U32(
-                v_v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
+inline __m128i VaddwLo8(const __m128i src0, const __m128i src1) {
+  const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128());
+  return _mm_add_epi16(src0, s1);
+}
 
-            StoreUnaligned16(&filtered_output[x], v_filtered);
-            x += 4;
-          } while (x < width);
-        } else {
-          int x = 0;
-          do {
-            // 0 0 0
-            // 5 6 5
-            // 0 0 0
-            const __m128i v_A = Process3x3Block_565_Odd(
-                &intermediate_result2[0][x], intermediate_stride);
-            const __m128i v_B = Process3x3Block_565_Odd(
-                &intermediate_result2[1][x], intermediate_stride);
-            const __m128i v_src = _mm_cvtepu8_epi32(Load4(src_ptr + x));
-            const __m128i v_v0 = _mm_mullo_epi32(v_A, v_src);
-            const __m128i v_v = _mm_add_epi32(v_v0, v_B);
-            const __m128i v_filtered = RightShiftWithRounding_U32(
-                v_v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
+inline __m128i VaddwHi8(const __m128i src0, const __m128i src1) {
+  const __m128i s1 = _mm_unpackhi_epi8(src1, _mm_setzero_si128());
+  return _mm_add_epi16(src0, s1);
+}
 
-            StoreUnaligned16(&filtered_output[x], v_filtered);
-            x += 4;
-          } while (x < width);
-        }
-        src_ptr += stride;
-        filtered_output += filtered_output_stride;
-      } while (++y < height);
+inline __m128i VaddwLo16(const __m128i src0, const __m128i src1) {
+  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
+  return _mm_add_epi32(src0, s1);
+}
+
+inline __m128i VaddwHi16(const __m128i src0, const __m128i src1) {
+  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
+  return _mm_add_epi32(src0, s1);
+}
+
+// Using VgetLane16() can save a sign extension instruction.
+template <int n>
+inline int VgetLane16(const __m128i src) {
+  return _mm_extract_epi16(src, n);
+}
+
+inline __m128i VmullLo8(const __m128i src0, const __m128i src1) {
+  const __m128i s0 = _mm_unpacklo_epi8(src0, _mm_setzero_si128());
+  const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128());
+  return _mm_mullo_epi16(s0, s1);
+}
+
+inline __m128i VmullHi8(const __m128i src0, const __m128i src1) {
+  const __m128i s0 = _mm_unpackhi_epi8(src0, _mm_setzero_si128());
+  const __m128i s1 = _mm_unpackhi_epi8(src1, _mm_setzero_si128());
+  return _mm_mullo_epi16(s0, s1);
+}
+
+inline __m128i VmullNLo8(const __m128i src0, const int src1) {
+  const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128());
+  return _mm_madd_epi16(s0, _mm_set1_epi32(src1));
+}
+
+inline __m128i VmullNHi8(const __m128i src0, const int src1) {
+  const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128());
+  return _mm_madd_epi16(s0, _mm_set1_epi32(src1));
+}
+
+inline __m128i VmullLo16(const __m128i src0, const __m128i src1) {
+  const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128());
+  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
+  return _mm_madd_epi16(s0, s1);
+}
+
+inline __m128i VmullHi16(const __m128i src0, const __m128i src1) {
+  const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128());
+  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
+  return _mm_madd_epi16(s0, s1);
+}
+
+inline __m128i VrshrS32(const __m128i src0, const int src1) {
+  const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1)));
+  return _mm_srai_epi32(sum, src1);
+}
+
+inline __m128i VrshrU32(const __m128i src0, const int src1) {
+  const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1)));
+  return _mm_srli_epi32(sum, src1);
+}
+
+template <int n>
+inline __m128i CalcAxN(const __m128i a) {
+  static_assert(n == 9 || n == 25, "");
+  // _mm_mullo_epi32() has high latency. Using shifts and additions instead.
+  // Some compilers could do this for us but we make this explicit.
+  // return _mm_mullo_epi32(a, _mm_set1_epi32(n));
+  const __m128i ax9 = _mm_add_epi32(a, _mm_slli_epi32(a, 3));
+  if (n == 9) return ax9;
+  if (n == 25) return _mm_add_epi32(ax9, _mm_slli_epi32(a, 4));
+}
+
+template <int n>
+inline __m128i CalculateMa(const __m128i sum_sq, const __m128i sum,
+                           const uint32_t s) {
+  // a = |sum_sq|
+  // d = |sum|
+  // p = (a * n < d * d) ? 0 : a * n - d * d;
+  const __m128i dxd = _mm_madd_epi16(sum, sum);
+  const __m128i axn = CalcAxN<n>(sum_sq);
+  const __m128i sub = _mm_sub_epi32(axn, dxd);
+  const __m128i p = _mm_max_epi32(sub, _mm_setzero_si128());
+
+  // z = RightShiftWithRounding(p * s, kSgrProjScaleBits);
+  const __m128i pxs = _mm_mullo_epi32(p, _mm_set1_epi32(s));
+  return VrshrU32(pxs, kSgrProjScaleBits);
+}
+
+// b = ma * b * one_over_n
+// |ma| = [0, 255]
+// |sum| is a box sum with radius 1 or 2.
+// For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
+// For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
+// |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
+// When radius is 2 |n| is 25. |one_over_n| is 164.
+// When radius is 1 |n| is 9. |one_over_n| is 455.
+// |kSgrProjReciprocalBits| is 12.
+// Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
+// Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
+inline __m128i CalculateIntermediate4(const __m128i ma, const __m128i sum,
+                                      const uint32_t one_over_n) {
+  const __m128i maq = _mm_unpacklo_epi8(ma, _mm_setzero_si128());
+  const __m128i s = _mm_unpackhi_epi16(maq, _mm_setzero_si128());
+  const __m128i m = _mm_madd_epi16(s, sum);
+  const __m128i b = _mm_mullo_epi32(m, _mm_set1_epi32(one_over_n));
+  const __m128i truncate_u32 = VrshrU32(b, kSgrProjReciprocalBits);
+  return _mm_packus_epi32(truncate_u32, truncate_u32);
+}
+
+inline __m128i CalculateIntermediate8(const __m128i ma, const __m128i sum,
+                                      const uint32_t one_over_n) {
+  const __m128i maq = _mm_unpackhi_epi8(ma, _mm_setzero_si128());
+  const __m128i m0 = VmullLo16(maq, sum);
+  const __m128i m1 = VmullHi16(maq, sum);
+  const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n));
+  const __m128i m3 = _mm_mullo_epi32(m1, _mm_set1_epi32(one_over_n));
+  const __m128i b_lo = VrshrU32(m2, kSgrProjReciprocalBits);
+  const __m128i b_hi = VrshrU32(m3, kSgrProjReciprocalBits);
+  return _mm_packus_epi32(b_lo, b_hi);
+}
+
+inline __m128i Sum3_16(const __m128i left, const __m128i middle,
+                       const __m128i right) {
+  const __m128i sum = _mm_add_epi16(left, middle);
+  return _mm_add_epi16(sum, right);
+}
+
+inline __m128i Sum3_32(const __m128i left, const __m128i middle,
+                       const __m128i right) {
+  const __m128i sum = _mm_add_epi32(left, middle);
+  return _mm_add_epi32(sum, right);
+}
+
+inline __m128i Sum3W_16(const __m128i left, const __m128i middle,
+                        const __m128i right) {
+  const __m128i sum = VaddlLo8(left, middle);
+  return VaddwLo8(sum, right);
+}
+
+inline __m128i Sum3WLo_16(const __m128i src[3]) {
+  return Sum3W_16(src[0], src[1], src[2]);
+}
+
+inline __m128i Sum3WHi_16(const __m128i src[3]) {
+  const __m128i sum = VaddlHi8(src[0], src[1]);
+  return VaddwHi8(sum, src[2]);
+}
+
+inline __m128i Sum3WLo_32(const __m128i left, const __m128i middle,
+                          const __m128i right) {
+  const __m128i sum = VaddlLo16(left, middle);
+  return VaddwLo16(sum, right);
+}
+
+inline __m128i Sum3WHi_32(const __m128i left, const __m128i middle,
+                          const __m128i right) {
+  const __m128i sum = VaddlHi16(left, middle);
+  return VaddwHi16(sum, right);
+}
+
+inline __m128i* Sum3W_16x2(const __m128i src[3], __m128i sum[2]) {
+  sum[0] = Sum3WLo_16(src);
+  sum[1] = Sum3WHi_16(src);
+  return sum;
+}
+
+inline __m128i* Sum3W(const __m128i src[3], __m128i sum[2]) {
+  sum[0] = Sum3WLo_32(src[0], src[1], src[2]);
+  sum[1] = Sum3WHi_32(src[0], src[1], src[2]);
+  return sum;
+}
+
+template <int index>
+inline __m128i Sum3WLo(const __m128i src[3][2]) {
+  return Sum3WLo_32(src[0][index], src[1][index], src[2][index]);
+}
+
+inline __m128i Sum3WHi(const __m128i src[3][2]) {
+  return Sum3WHi_32(src[0][0], src[1][0], src[2][0]);
+}
+
+inline __m128i* Sum3W(const __m128i src[3][2], __m128i sum[3]) {
+  sum[0] = Sum3WLo<0>(src);
+  sum[1] = Sum3WHi(src);
+  sum[2] = Sum3WLo<1>(src);
+  return sum;
+}
+
+inline __m128i Sum5_16(const __m128i src[5]) {
+  const __m128i sum01 = _mm_add_epi16(src[0], src[1]);
+  const __m128i sum23 = _mm_add_epi16(src[2], src[3]);
+  const __m128i sum = _mm_add_epi16(sum01, sum23);
+  return _mm_add_epi16(sum, src[4]);
+}
+
+inline __m128i Sum5_32(const __m128i src[5]) {
+  const __m128i sum01 = _mm_add_epi32(src[0], src[1]);
+  const __m128i sum23 = _mm_add_epi32(src[2], src[3]);
+  const __m128i sum = _mm_add_epi32(sum01, sum23);
+  return _mm_add_epi32(sum, src[4]);
+}
+
+inline __m128i Sum5WLo_16(const __m128i src[5]) {
+  const __m128i sum01 = VaddlLo8(src[0], src[1]);
+  const __m128i sum23 = VaddlLo8(src[2], src[3]);
+  const __m128i sum = _mm_add_epi16(sum01, sum23);
+  return VaddwLo8(sum, src[4]);
+}
+
+inline __m128i Sum5WHi_16(const __m128i src[5]) {
+  const __m128i sum01 = VaddlHi8(src[0], src[1]);
+  const __m128i sum23 = VaddlHi8(src[2], src[3]);
+  const __m128i sum = _mm_add_epi16(sum01, sum23);
+  return VaddwHi8(sum, src[4]);
+}
+
+inline __m128i Sum5WLo_32(const __m128i src[5]) {
+  const __m128i sum01 = VaddlLo16(src[0], src[1]);
+  const __m128i sum23 = VaddlLo16(src[2], src[3]);
+  const __m128i sum0123 = _mm_add_epi32(sum01, sum23);
+  return VaddwLo16(sum0123, src[4]);
+}
+
+inline __m128i Sum5WHi_32(const __m128i src[5]) {
+  const __m128i sum01 = VaddlHi16(src[0], src[1]);
+  const __m128i sum23 = VaddlHi16(src[2], src[3]);
+  const __m128i sum0123 = _mm_add_epi32(sum01, sum23);
+  return VaddwHi16(sum0123, src[4]);
+}
+
+inline __m128i* Sum5W_16D(const __m128i src[5], __m128i sum[2]) {
+  sum[0] = Sum5WLo_16(src);
+  sum[1] = Sum5WHi_16(src);
+  return sum;
+}
+
+inline __m128i* Sum5W_32x2(const __m128i src[5], __m128i sum[2]) {
+  sum[0] = Sum5WLo_32(src);
+  sum[1] = Sum5WHi_32(src);
+  return sum;
+}
+
+template <int index>
+inline __m128i Sum5WLo(const __m128i src[5][2]) {
+  __m128i s[5];
+  s[0] = src[0][index];
+  s[1] = src[1][index];
+  s[2] = src[2][index];
+  s[3] = src[3][index];
+  s[4] = src[4][index];
+  return Sum5WLo_32(s);
+}
+
+inline __m128i Sum5WHi(const __m128i src[5][2]) {
+  __m128i s[5];
+  s[0] = src[0][0];
+  s[1] = src[1][0];
+  s[2] = src[2][0];
+  s[3] = src[3][0];
+  s[4] = src[4][0];
+  return Sum5WHi_32(s);
+}
+
+inline __m128i* Sum5W_32x3(const __m128i src[5][2], __m128i sum[3]) {
+  sum[0] = Sum5WLo<0>(src);
+  sum[1] = Sum5WHi(src);
+  sum[2] = Sum5WLo<1>(src);
+  return sum;
+}
+
+inline __m128i Sum3Horizontal(const __m128i src) {
+  const auto left = src;
+  const auto middle = _mm_srli_si128(src, 2);
+  const auto right = _mm_srli_si128(src, 4);
+  return Sum3_16(left, middle, right);
+}
+
+inline __m128i Sum3Horizontal_32(const __m128i src[2]) {
+  const auto left = src[0];
+  const auto middle = _mm_alignr_epi8(src[1], src[0], 4);
+  const auto right = _mm_alignr_epi8(src[1], src[0], 8);
+  return Sum3_32(left, middle, right);
+}
+
+inline __m128i Sum3HorizontalOffset1(const __m128i src) {
+  const auto left = _mm_srli_si128(src, 2);
+  const auto middle = _mm_srli_si128(src, 4);
+  const auto right = _mm_srli_si128(src, 6);
+  return Sum3_16(left, middle, right);
+}
+
+inline __m128i Sum3HorizontalOffset1_16(const __m128i src[2]) {
+  const auto left = _mm_alignr_epi8(src[1], src[0], 2);
+  const auto middle = _mm_alignr_epi8(src[1], src[0], 4);
+  const auto right = _mm_alignr_epi8(src[1], src[0], 6);
+  return Sum3_16(left, middle, right);
+}
+
+inline __m128i Sum3HorizontalOffset1_32(const __m128i src[2]) {
+  const auto left = _mm_alignr_epi8(src[1], src[0], 4);
+  const auto middle = _mm_alignr_epi8(src[1], src[0], 8);
+  const auto right = _mm_alignr_epi8(src[1], src[0], 12);
+  return Sum3_32(left, middle, right);
+}
+
+inline void Sum3HorizontalOffset1_32x2(const __m128i src[3], __m128i sum[2]) {
+  sum[0] = Sum3HorizontalOffset1_32(src + 0);
+  sum[1] = Sum3HorizontalOffset1_32(src + 1);
+}
+
+inline __m128i Sum5Horizontal(const __m128i src) {
+  __m128i s[5];
+  s[0] = src;
+  s[1] = _mm_srli_si128(src, 2);
+  s[2] = _mm_srli_si128(src, 4);
+  s[3] = _mm_srli_si128(src, 6);
+  s[4] = _mm_srli_si128(src, 8);
+  return Sum5_16(s);
+}
+
+inline __m128i Sum5Horizontal_16(const __m128i src[2]) {
+  __m128i s[5];
+  s[0] = src[0];
+  s[1] = _mm_alignr_epi8(src[1], src[0], 2);
+  s[2] = _mm_alignr_epi8(src[1], src[0], 4);
+  s[3] = _mm_alignr_epi8(src[1], src[0], 6);
+  s[4] = _mm_alignr_epi8(src[1], src[0], 8);
+  return Sum5_16(s);
+}
+
+inline __m128i Sum5Horizontal_32(const __m128i src[2]) {
+  __m128i s[5];
+  s[0] = src[0];
+  s[1] = _mm_alignr_epi8(src[1], src[0], 4);
+  s[2] = _mm_alignr_epi8(src[1], src[0], 8);
+  s[3] = _mm_alignr_epi8(src[1], src[0], 12);
+  s[4] = src[1];
+  return Sum5_32(s);
+}
+
+inline __m128i* Sum5Horizontal_32x2(const __m128i src[3], __m128i sum[2]) {
+  __m128i s[5];
+  s[0] = src[0];
+  s[1] = _mm_alignr_epi8(src[1], src[0], 4);
+  s[2] = _mm_alignr_epi8(src[1], src[0], 8);
+  s[3] = _mm_alignr_epi8(src[1], src[0], 12);
+  s[4] = src[1];
+  sum[0] = Sum5_32(s);
+  s[0] = src[1];
+  s[1] = _mm_alignr_epi8(src[2], src[1], 4);
+  s[2] = _mm_alignr_epi8(src[2], src[1], 8);
+  s[3] = _mm_alignr_epi8(src[2], src[1], 12);
+  s[4] = src[2];
+  sum[1] = Sum5_32(s);
+  return sum;
+}
+
+template <int size, int offset>
+inline void BoxFilterPreProcess4(const __m128i* const row,
+                                 const __m128i* const row_sq, const uint32_t s,
+                                 uint16_t* const dst) {
+  static_assert(size == 3 || size == 5, "");
+  static_assert(offset == 0 || offset == 1, "");
+  // Number of elements in the box being summed.
+  constexpr uint32_t n = size * size;
+  constexpr uint32_t one_over_n =
+      ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
+  __m128i sum, sum_sq;
+  if (size == 3) {
+    __m128i temp32[2];
+    if (offset == 0) {
+      sum = Sum3Horizontal(Sum3WLo_16(row));
+      sum_sq = Sum3Horizontal_32(Sum3W(row_sq, temp32));
     } else {
-      int y = 0;
-      do {
-        const int shift = 5;
-        uint32_t* const array_start[2] = {
-            buffer->box_filter_process_intermediate[0] +
-                intermediate_buffer_offset + y * intermediate_stride,
-            buffer->box_filter_process_intermediate[1] +
-                intermediate_buffer_offset + y * intermediate_stride};
-        uint32_t* intermediate_result2[2] = {
-            array_start[0] - intermediate_stride,
-            array_start[1] - intermediate_stride};
-        int x = 0;
-        do {
-          const __m128i v_A = Process3x3Block_343(&intermediate_result2[0][x],
-                                                  intermediate_stride);
-          const __m128i v_B = Process3x3Block_343(&intermediate_result2[1][x],
-                                                  intermediate_stride);
-          const __m128i v_src = _mm_cvtepu8_epi32(Load4(src_ptr + x));
-          const __m128i v_v0 = _mm_mullo_epi32(v_A, v_src);
-          const __m128i v_v = _mm_add_epi32(v_v0, v_B);
-          const __m128i v_filtered = RightShiftWithRounding_U32(
-              v_v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
-
-          StoreUnaligned16(&filtered_output[x], v_filtered);
-          x += 4;
-        } while (x < width);
-        src_ptr += stride;
-        filtered_output += filtered_output_stride;
-      } while (++y < height);
+      sum = Sum3HorizontalOffset1(Sum3WLo_16(row));
+      sum_sq = Sum3HorizontalOffset1_32(Sum3W(row_sq, temp32));
     }
   }
+  if (size == 5) {
+    __m128i temp[2];
+    sum = Sum5Horizontal(Sum5WLo_16(row));
+    sum_sq = Sum5Horizontal_32(Sum5W_32x2(row_sq, temp));
+  }
+  const __m128i sum_32 = _mm_unpacklo_epi16(sum, _mm_setzero_si128());
+  const __m128i z0 = CalculateMa<n>(sum_sq, sum_32, s);
+  const __m128i z1 = _mm_packus_epi32(z0, z0);
+  const __m128i z = _mm_min_epu16(z1, _mm_set1_epi16(255));
+  __m128i ma = _mm_setzero_si128();
+  ma = _mm_insert_epi8(ma, kSgrMaLookup[VgetLane16<0>(z)], 4);
+  ma = _mm_insert_epi8(ma, kSgrMaLookup[VgetLane16<1>(z)], 5);
+  ma = _mm_insert_epi8(ma, kSgrMaLookup[VgetLane16<2>(z)], 6);
+  ma = _mm_insert_epi8(ma, kSgrMaLookup[VgetLane16<3>(z)], 7);
+  const __m128i b = CalculateIntermediate4(ma, sum_32, one_over_n);
+  const __m128i ma_b = _mm_unpacklo_epi64(ma, b);
+  StoreAligned16(dst, ma_b);
 }
 
-void SelfGuidedFilter_SSE4_1(const void* source, void* dest,
+template <int size>
+inline void BoxFilterPreProcess8(const __m128i* const row,
+                                 const __m128i row_sq[][2], const uint32_t s,
+                                 __m128i* const ma, __m128i* const b,
+                                 uint16_t* const dst) {
+  static_assert(size == 3 || size == 5, "");
+  // Number of elements in the box being summed.
+  constexpr uint32_t n = size * size;
+  constexpr uint32_t one_over_n =
+      ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
+  __m128i sum, sum_sq[2];
+  if (size == 3) {
+    __m128i temp16[2], temp32[3];
+    sum = Sum3HorizontalOffset1_16(Sum3W_16x2(row, temp16));
+    Sum3HorizontalOffset1_32x2(Sum3W(row_sq, temp32), sum_sq);
+  }
+  if (size == 5) {
+    __m128i temp16[2], temp32[3];
+    sum = Sum5Horizontal_16(Sum5W_16D(row, temp16));
+    Sum5Horizontal_32x2(Sum5W_32x3(row_sq, temp32), sum_sq);
+  }
+  const __m128i sum_lo = _mm_unpacklo_epi16(sum, _mm_setzero_si128());
+  const __m128i z0 = CalculateMa<n>(sum_sq[0], sum_lo, s);
+  const __m128i sum_hi = _mm_unpackhi_epi16(sum, _mm_setzero_si128());
+  const __m128i z1 = CalculateMa<n>(sum_sq[1], sum_hi, s);
+  const __m128i z01 = _mm_packus_epi32(z0, z1);
+  const __m128i z = _mm_min_epu16(z01, _mm_set1_epi16(255));
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[VgetLane16<0>(z)], 8);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[VgetLane16<1>(z)], 9);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[VgetLane16<2>(z)], 10);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[VgetLane16<3>(z)], 11);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[VgetLane16<4>(z)], 12);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[VgetLane16<5>(z)], 13);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[VgetLane16<6>(z)], 14);
+  *ma = _mm_insert_epi8(*ma, kSgrMaLookup[VgetLane16<7>(z)], 15);
+  *b = CalculateIntermediate8(*ma, sum, one_over_n);
+  const __m128i ma_b = _mm_unpackhi_epi64(*ma, *b);
+  StoreAligned16(dst, ma_b);
+}
+
+inline void Prepare3_8(const __m128i src, __m128i* const left,
+                       __m128i* const middle, __m128i* const right) {
+  *left = _mm_srli_si128(src, 5);
+  *middle = _mm_srli_si128(src, 6);
+  *right = _mm_srli_si128(src, 7);
+}
+
+inline void Prepare3_16(const __m128i src[2], __m128i* const left,
+                        __m128i* const middle, __m128i* const right) {
+  *left = _mm_alignr_epi8(src[1], src[0], 10);
+  *middle = _mm_alignr_epi8(src[1], src[0], 12);
+  *right = _mm_alignr_epi8(src[1], src[0], 14);
+}
+
+inline __m128i Sum343(const __m128i src) {
+  __m128i left, middle, right;
+  Prepare3_8(src, &left, &middle, &right);
+  const auto sum = Sum3W_16(left, middle, right);
+  const auto sum3 = Sum3_16(sum, sum, sum);
+  return VaddwLo8(sum3, middle);
+}
+
+inline void Sum343_444(const __m128i src, __m128i* const sum343,
+                       __m128i* const sum444) {
+  __m128i left, middle, right;
+  Prepare3_8(src, &left, &middle, &right);
+  const auto sum111 = Sum3W_16(left, middle, right);
+  *sum444 = _mm_slli_epi16(sum111, 2);
+  const __m128i sum333 = _mm_sub_epi16(*sum444, sum111);
+  *sum343 = VaddwLo8(sum333, middle);
+}
+
+inline __m128i* Sum343W(const __m128i src[2], __m128i d[2]) {
+  __m128i left, middle, right;
+  Prepare3_16(src, &left, &middle, &right);
+  d[0] = Sum3WLo_32(left, middle, right);
+  d[1] = Sum3WHi_32(left, middle, right);
+  d[0] = Sum3_32(d[0], d[0], d[0]);
+  d[1] = Sum3_32(d[1], d[1], d[1]);
+  d[0] = VaddwLo16(d[0], middle);
+  d[1] = VaddwHi16(d[1], middle);
+  return d;
+}
+
+inline void Sum343_444W(const __m128i src[2], __m128i sum343[2],
+                        __m128i sum444[2]) {
+  __m128i left, middle, right, sum111[2];
+  Prepare3_16(src, &left, &middle, &right);
+  sum111[0] = Sum3WLo_32(left, middle, right);
+  sum111[1] = Sum3WHi_32(left, middle, right);
+  sum444[0] = _mm_slli_epi32(sum111[0], 2);
+  sum444[1] = _mm_slli_epi32(sum111[1], 2);
+  sum343[0] = _mm_sub_epi32(sum444[0], sum111[0]);
+  sum343[1] = _mm_sub_epi32(sum444[1], sum111[1]);
+  sum343[0] = VaddwLo16(sum343[0], middle);
+  sum343[1] = VaddwHi16(sum343[1], middle);
+}
+
+inline __m128i Sum565(const __m128i src) {
+  __m128i left, middle, right;
+  Prepare3_8(src, &left, &middle, &right);
+  const auto sum = Sum3W_16(left, middle, right);
+  const auto sum4 = _mm_slli_epi16(sum, 2);
+  const auto sum5 = _mm_add_epi16(sum4, sum);
+  return VaddwLo8(sum5, middle);
+}
+
+inline __m128i Sum565W(const __m128i src) {
+  const auto left = _mm_srli_si128(src, 2);
+  const auto middle = _mm_srli_si128(src, 4);
+  const auto right = _mm_srli_si128(src, 6);
+  const auto sum = Sum3WLo_32(left, middle, right);
+  const auto sum4 = _mm_slli_epi32(sum, 2);
+  const auto sum5 = _mm_add_epi32(sum4, sum);
+  return VaddwLo16(sum5, middle);
+}
+
+template <int shift>
+inline __m128i FilterOutput(const __m128i ma_x_src, const __m128i b) {
+  // ma: 255 * 32 = 8160 (13 bits)
+  // b: 65088 * 32 = 2082816 (21 bits)
+  // v: b - ma * 255 (22 bits)
+  const __m128i v = _mm_sub_epi32(b, ma_x_src);
+  // kSgrProjSgrBits = 8
+  // kSgrProjRestoreBits = 4
+  // shift = 4 or 5
+  // v >> 8 or 9 (13 bits)
+  return VrshrS32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
+}
+
+template <int shift>
+inline __m128i CalculateFilteredOutput(const __m128i src, const __m128i a,
+                                       const __m128i b[2]) {
+  const __m128i src_u16 = _mm_unpacklo_epi8(src, _mm_setzero_si128());
+  const __m128i ma_x_src_lo = VmullLo16(a, src_u16);
+  const __m128i ma_x_src_hi = VmullHi16(a, src_u16);
+  const __m128i dst_lo = FilterOutput<shift>(ma_x_src_lo, b[0]);
+  const __m128i dst_hi = FilterOutput<shift>(ma_x_src_hi, b[1]);
+  return _mm_packs_epi32(dst_lo, dst_hi);  // 13 bits
+}
+
+inline __m128i BoxFilterPass1(const __m128i src_u8, const __m128i ma,
+                              const __m128i b[2], __m128i ma565[2],
+                              __m128i b565[2][2]) {
+  __m128i b_sum[2];
+  ma565[1] = Sum565(ma);
+  b565[1][0] = Sum565W(_mm_alignr_epi8(b[1], b[0], 8));
+  b565[1][1] = Sum565W(b[1]);
+  __m128i ma_sum = _mm_add_epi16(ma565[0], ma565[1]);
+  b_sum[0] = _mm_add_epi32(b565[0][0], b565[1][0]);
+  b_sum[1] = _mm_add_epi32(b565[0][1], b565[1][1]);
+  return CalculateFilteredOutput<5>(src_u8, ma_sum, b_sum);  // 13 bits
+}
+
+inline __m128i BoxFilterPass2(const __m128i src_u8, const __m128i ma,
+                              const __m128i b[2], __m128i ma343[4],
+                              __m128i ma444[3], __m128i b343[4][2],
+                              __m128i b444[3][2]) {
+  __m128i b_sum[2];
+  Sum343_444(ma, &ma343[2], &ma444[1]);
+  __m128i ma_sum = Sum3_16(ma343[0], ma444[0], ma343[2]);
+  Sum343_444W(b, b343[2], b444[1]);
+  b_sum[0] = Sum3_32(b343[0][0], b444[0][0], b343[2][0]);
+  b_sum[1] = Sum3_32(b343[0][1], b444[0][1], b343[2][1]);
+  return CalculateFilteredOutput<5>(src_u8, ma_sum, b_sum);  // 13 bits
+}
+
+inline void SelfGuidedFinal(const __m128i src, const __m128i v[2],
+                            uint8_t* const dst) {
+  const __m128i v_lo =
+      VrshrS32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
+  const __m128i v_hi =
+      VrshrS32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
+  const __m128i vv = _mm_packs_epi32(v_lo, v_hi);
+  const __m128i s = _mm_unpacklo_epi8(src, _mm_setzero_si128());
+  const __m128i d = _mm_add_epi16(s, vv);
+  StoreLo8(dst, _mm_packus_epi16(d, d));
+}
+
+inline void SelfGuidedDoubleMultiplier(const __m128i src,
+                                       const __m128i filter[2], const int w0,
+                                       const int w2, uint8_t* const dst) {
+  __m128i v[2];
+  const __m128i w0_w2 = _mm_set1_epi32((w2 << 16) | static_cast<uint16_t>(w0));
+  const __m128i f_lo = _mm_unpacklo_epi16(filter[0], filter[1]);
+  const __m128i f_hi = _mm_unpackhi_epi16(filter[0], filter[1]);
+  v[0] = _mm_madd_epi16(w0_w2, f_lo);
+  v[1] = _mm_madd_epi16(w0_w2, f_hi);
+  SelfGuidedFinal(src, v, dst);
+}
+
+inline void SelfGuidedSingleMultiplier(const __m128i src, const __m128i filter,
+                                       const int w0, uint8_t* const dst) {
+  // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
+  __m128i v[2];
+  v[0] = VmullNLo8(filter, w0);
+  v[1] = VmullNHi8(filter, w0);
+  SelfGuidedFinal(src, v, dst);
+}
+
+inline void BoxFilterProcess(const uint8_t* const src,
+                             const ptrdiff_t src_stride,
                              const RestorationUnitInfo& restoration_info,
-                             ptrdiff_t source_stride, ptrdiff_t dest_stride,
-                             int width, int height,
-                             RestorationBuffer* const buffer) {
-  const auto* src = static_cast<const uint8_t*>(source);
-  auto* dst = static_cast<uint8_t*>(dest);
+                             const int width, const int height,
+                             const uint16_t scale[2], uint16_t* const temp,
+                             uint8_t* const dst, const ptrdiff_t dst_stride) {
+  // We have combined PreProcess and Process for the first pass by storing
+  // intermediate values in the |ma| region. The values stored are one
+  // vertical column of interleaved |ma| and |b| values and consume 8 *
+  // |height| values. This is |height| and not |height| * 2 because PreProcess
+  // only generates output for every other row. When processing the next column
+  // we write the new scratch values right after reading the previously saved
+  // ones.
+
+  // The PreProcess phase calculates a 5x5 box sum for every other row
+  //
+  // PreProcess and Process have been combined into the same step. We need 12
+  // input values to generate 8 output values for PreProcess:
+  // 0 1 2 3 4 5 6 7 8 9 10 11
+  // 2 = 0 + 1 + 2 +  3 +  4
+  // 3 = 1 + 2 + 3 +  4 +  5
+  // 4 = 2 + 3 + 4 +  5 +  6
+  // 5 = 3 + 4 + 5 +  6 +  7
+  // 6 = 4 + 5 + 6 +  7 +  8
+  // 7 = 5 + 6 + 7 +  8 +  9
+  // 8 = 6 + 7 + 8 +  9 + 10
+  // 9 = 7 + 8 + 9 + 10 + 11
+  //
+  // and then we need 10 input values to generate 8 output values for Process:
+  // 0 1 2 3 4 5 6 7 8 9
+  // 1 = 0 + 1 + 2
+  // 2 = 1 + 2 + 3
+  // 3 = 2 + 3 + 4
+  // 4 = 3 + 4 + 5
+  // 5 = 4 + 5 + 6
+  // 6 = 5 + 6 + 7
+  // 7 = 6 + 7 + 8
+  // 8 = 7 + 8 + 9
+  //
+  // To avoid re-calculating PreProcess values over and over again we will do a
+  // single column of 8 output values and store the second half of them
+  // interleaved in |temp|. The first half is not stored, since it is used
+  // immediately and becomes useless for the next column. Next we will start the
+  // second column. When 2 rows have been calculated we can calculate Process
+  // and output the results.
+
+  // Calculate and store a single column. Scope so we can re-use the variable
+  // names for the next step.
+  uint16_t* ab_ptr = temp;
+  const uint8_t* const src_pre_process = src - 2 * src_stride;
+  // Calculate intermediate results, including two-pixel border, for example, if
+  // unit size is 64x64, we calculate 68x68 pixels.
+  {
+    const uint8_t* column = src_pre_process - 4;
+    __m128i row[5], row_sq[5];
+    row[0] = row[1] = LoadLo8(column);
+    column += src_stride;
+    row[2] = LoadLo8(column);
+    row_sq[0] = row_sq[1] = VmullLo8(row[1], row[1]);
+    row_sq[2] = VmullLo8(row[2], row[2]);
+
+    int y = (height + 2) >> 1;
+    do {
+      column += src_stride;
+      row[3] = LoadLo8(column);
+      column += src_stride;
+      row[4] = LoadLo8(column);
+      row_sq[3] = VmullLo8(row[3], row[3]);
+      row_sq[4] = VmullLo8(row[4], row[4]);
+      BoxFilterPreProcess4<5, 1>(row + 0, row_sq + 0, scale[0], ab_ptr + 0);
+      BoxFilterPreProcess4<3, 1>(row + 1, row_sq + 1, scale[1], ab_ptr + 8);
+      BoxFilterPreProcess4<3, 1>(row + 2, row_sq + 2, scale[1], ab_ptr + 16);
+      row[0] = row[2];
+      row[1] = row[3];
+      row[2] = row[4];
+      row_sq[0] = row_sq[2];
+      row_sq[1] = row_sq[3];
+      row_sq[2] = row_sq[4];
+      ab_ptr += 24;
+    } while (--y != 0);
+
+    if ((height & 1) != 0) {
+      column += src_stride;
+      row[3] = row[4] = LoadLo8(column);
+      row_sq[3] = row_sq[4] = VmullLo8(row[3], row[3]);
+      BoxFilterPreProcess4<5, 1>(row + 0, row_sq + 0, scale[0], ab_ptr + 0);
+      BoxFilterPreProcess4<3, 1>(row + 1, row_sq + 1, scale[1], ab_ptr + 8);
+    }
+  }
+
   const int w0 = restoration_info.sgr_proj_info.multiplier[0];
   const int w1 = restoration_info.sgr_proj_info.multiplier[1];
   const int w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
-  const int index = restoration_info.sgr_proj_info.index;
-  const uint8_t r0 = kSgrProjParams[index][0];
-  const uint8_t r1 = kSgrProjParams[index][2];
-  const ptrdiff_t array_stride = buffer->box_filter_process_output_stride;
-  int* box_filter_process_output[2] = {buffer->box_filter_process_output[0],
-                                       buffer->box_filter_process_output[1]};
-
-  BoxFilterProcess_SSE4_1(restoration_info, src, source_stride, width, height,
-                          buffer);
-
-  const __m128i v_w0 = _mm_shuffle_epi32(_mm_cvtsi32_si128(w0), 0);
-  const __m128i v_w1 = _mm_shuffle_epi32(_mm_cvtsi32_si128(w1), 0);
-  const __m128i v_w2 = _mm_shuffle_epi32(_mm_cvtsi32_si128(w2), 0);
-  const __m128i v_r0 = _mm_shuffle_epi32(_mm_cvtsi32_si128(r0), 0);
-  const __m128i v_r1 = _mm_shuffle_epi32(_mm_cvtsi32_si128(r1), 0);
-  const __m128i zero = _mm_setzero_si128();
-  // Create masks used to select between src and box_filter_process_output.
-  const __m128i v_r0_mask = _mm_cmpeq_epi32(v_r0, zero);
-  const __m128i v_r1_mask = _mm_cmpeq_epi32(v_r1, zero);
-
-  int y = 0;
+  int x = 0;
   do {
-    int x = 0;
-    do {
-      const __m128i v_src = _mm_cvtepu8_epi32(Load4(src + x));
-      const __m128i v_u = _mm_slli_epi32(v_src, kSgrProjRestoreBits);
-      const __m128i v_v_a = _mm_mullo_epi32(v_w1, v_u);
-      const __m128i v_bfp_out0 =
-          LoadUnaligned16(&box_filter_process_output[0][x]);
-      // Select u or box_filter_process_output[0][x].
-      const __m128i v_r0_mult = _mm_blendv_epi8(v_bfp_out0, v_u, v_r0_mask);
-      const __m128i v_v_b = _mm_mullo_epi32(v_w0, v_r0_mult);
-      const __m128i v_v_c = _mm_add_epi32(v_v_a, v_v_b);
-      const __m128i v_bfp_out1 =
-          LoadUnaligned16(&box_filter_process_output[1][x]);
-      // Select u or box_filter_process_output[1][x].
-      const __m128i v_r1_mult = _mm_blendv_epi8(v_bfp_out1, v_u, v_r1_mask);
-      const __m128i v_v_d = _mm_mullo_epi32(v_w2, v_r1_mult);
-      const __m128i v_v_e = _mm_add_epi32(v_v_c, v_v_d);
-      __m128i v_s = RightShiftWithRounding_S32(
-          v_v_e, kSgrProjRestoreBits + kSgrProjPrecisionBits);
-      v_s = _mm_packs_epi32(v_s, v_s);
-      v_s = _mm_packus_epi16(v_s, v_s);
-      Store4(&dst[x], v_s);
-      x += 4;
-    } while (x < width);
+    // |src_pre_process| is X but we already processed the first column of 4
+    // values so we want to start at Y and increment from there.
+    // X s s s Y s s
+    // s s s s s s s
+    // s s i i i i i
+    // s s i o o o o
+    // s s i o o o o
 
-    src += source_stride;
-    dst += dest_stride;
-    box_filter_process_output[0] += array_stride;
-    box_filter_process_output[1] += array_stride;
-  } while (++y < height);
+    // Seed the loop with one line of output. Then, inside the loop, for each
+    // iteration we can output one even row and one odd row and carry the new
+    // line to the next iteration. In the diagram below 'i' values are
+    // intermediary values from the first step and '-' values are empty.
+    // iiii
+    // ---- > even row
+    // iiii - odd row
+    // ---- > even row
+    // iiii
+    __m128i ma[2], b[2][2], ma565[2], ma343[4], ma444[3];
+    __m128i b565[2][2], b343[4][2], b444[3][2];
+    ab_ptr = temp;
+    ma[0] = b[0][0] = LoadAligned16(ab_ptr);
+    ma[1] = b[1][0] = LoadAligned16(ab_ptr + 8);
+    const uint8_t* column = src_pre_process + x;
+    __m128i row[5], row_sq[5][2];
+    // Need |width| + 3 pixels, but we read max(|x|) + 16 pixels.
+    // Mask max(|x|) + 13 - |width| extra pixels.
+    row[0] = row[1] = LoadUnaligned16Msan(column, x + 13 - width);
+    column += src_stride;
+    row[2] = LoadUnaligned16Msan(column, x + 13 - width);
+    column += src_stride;
+    row[3] = LoadUnaligned16Msan(column, x + 13 - width);
+    column += src_stride;
+    row[4] = LoadUnaligned16Msan(column, x + 13 - width);
+    row_sq[0][0] = row_sq[1][0] = VmullLo8(row[1], row[1]);
+    row_sq[0][1] = row_sq[1][1] = VmullHi8(row[1], row[1]);
+    row_sq[2][0] = VmullLo8(row[2], row[2]);
+    row_sq[2][1] = VmullHi8(row[2], row[2]);
+    row_sq[3][0] = VmullLo8(row[3], row[3]);
+    row_sq[3][1] = VmullHi8(row[3], row[3]);
+    row_sq[4][0] = VmullLo8(row[4], row[4]);
+    row_sq[4][1] = VmullHi8(row[4], row[4]);
+    BoxFilterPreProcess8<5>(row, row_sq, scale[0], &ma[0], &b[0][1], ab_ptr);
+    BoxFilterPreProcess8<3>(row + 1, row_sq + 1, scale[1], &ma[1], &b[1][1],
+                            ab_ptr + 8);
+
+    // Pass 1 Process. These are the only values we need to propagate between
+    // rows.
+    ma565[0] = Sum565(ma[0]);
+    b565[0][0] = Sum565W(_mm_alignr_epi8(b[0][1], b[0][0], 8));
+    b565[0][1] = Sum565W(b[0][1]);
+    ma343[0] = Sum343(ma[1]);
+    Sum343W(b[1], b343[0]);
+    ma[1] = b[1][0] = LoadAligned16(ab_ptr + 16);
+    BoxFilterPreProcess8<3>(row + 2, row_sq + 2, scale[1], &ma[1], &b[1][1],
+                            ab_ptr + 16);
+    Sum343_444(ma[1], &ma343[1], &ma444[0]);
+    Sum343_444W(b[1], b343[1], b444[0]);
+
+    uint8_t* dst_ptr = dst + x;
+    // Calculate one output line. Add in the line from the previous pass and
+    // output one even row. Sum the new line and output the odd row. Carry the
+    // new row into the next pass.
+    for (int y = height >> 1; y != 0; --y) {
+      ab_ptr += 24;
+      ma[0] = b[0][0] = LoadAligned16(ab_ptr);
+      ma[1] = b[1][0] = LoadAligned16(ab_ptr + 8);
+      row[0] = row[2];
+      row[1] = row[3];
+      row[2] = row[4];
+      row_sq[0][0] = row_sq[2][0], row_sq[0][1] = row_sq[2][1];
+      row_sq[1][0] = row_sq[3][0], row_sq[1][1] = row_sq[3][1];
+      row_sq[2][0] = row_sq[4][0], row_sq[2][1] = row_sq[4][1];
+      column += src_stride;
+      row[3] = LoadUnaligned16Msan(column, x + 13 - width);
+      column += src_stride;
+      row[4] = LoadUnaligned16Msan(column, x + 13 - width);
+      row_sq[3][0] = VmullLo8(row[3], row[3]);
+      row_sq[3][1] = VmullHi8(row[3], row[3]);
+      row_sq[4][0] = VmullLo8(row[4], row[4]);
+      row_sq[4][1] = VmullHi8(row[4], row[4]);
+      BoxFilterPreProcess8<5>(row, row_sq, scale[0], &ma[0], &b[0][1], ab_ptr);
+      BoxFilterPreProcess8<3>(row + 1, row_sq + 1, scale[1], &ma[1], &b[1][1],
+                              ab_ptr + 8);
+      __m128i p[2];
+      p[0] = BoxFilterPass1(row[1], ma[0], b[0], ma565, b565);
+      p[1] = BoxFilterPass2(row[1], ma[1], b[1], ma343, ma444, b343, b444);
+      SelfGuidedDoubleMultiplier(row[1], p, w0, w2, dst_ptr);
+      dst_ptr += dst_stride;
+      p[0] = CalculateFilteredOutput<4>(row[2], ma565[1], b565[1]);
+      ma[1] = b[1][0] = LoadAligned16(ab_ptr + 16);
+      BoxFilterPreProcess8<3>(row + 2, row_sq + 2, scale[1], &ma[1], &b[1][1],
+                              ab_ptr + 16);
+      p[1] = BoxFilterPass2(row[2], ma[1], b[1], ma343 + 1, ma444 + 1, b343 + 1,
+                            b444 + 1);
+      SelfGuidedDoubleMultiplier(row[2], p, w0, w2, dst_ptr);
+      dst_ptr += dst_stride;
+      ma565[0] = ma565[1];
+      b565[0][0] = b565[1][0], b565[0][1] = b565[1][1];
+      ma343[0] = ma343[2];
+      ma343[1] = ma343[3];
+      ma444[0] = ma444[2];
+      b343[0][0] = b343[2][0], b343[0][1] = b343[2][1];
+      b343[1][0] = b343[3][0], b343[1][1] = b343[3][1];
+      b444[0][0] = b444[2][0], b444[0][1] = b444[2][1];
+    }
+
+    if ((height & 1) != 0) {
+      ab_ptr += 24;
+      ma[0] = b[0][0] = LoadAligned16(ab_ptr);
+      ma[1] = b[1][0] = LoadAligned16(ab_ptr + 8);
+      row[0] = row[2];
+      row[1] = row[3];
+      row[2] = row[4];
+      row_sq[0][0] = row_sq[2][0], row_sq[0][1] = row_sq[2][1];
+      row_sq[1][0] = row_sq[3][0], row_sq[1][1] = row_sq[3][1];
+      row_sq[2][0] = row_sq[4][0], row_sq[2][1] = row_sq[4][1];
+      column += src_stride;
+      row[3] = row[4] = LoadUnaligned16Msan(column, x + 13 - width);
+      row_sq[3][0] = row_sq[4][0] = VmullLo8(row[3], row[3]);
+      row_sq[3][1] = row_sq[4][1] = VmullHi8(row[3], row[3]);
+      BoxFilterPreProcess8<5>(row, row_sq, scale[0], &ma[0], &b[0][1], ab_ptr);
+      BoxFilterPreProcess8<3>(row + 1, row_sq + 1, scale[1], &ma[1], &b[1][1],
+                              ab_ptr + 8);
+      __m128i p[2];
+      p[0] = BoxFilterPass1(row[1], ma[0], b[0], ma565, b565);
+      p[1] = BoxFilterPass2(row[1], ma[1], b[1], ma343, ma444, b343, b444);
+      SelfGuidedDoubleMultiplier(row[1], p, w0, w2, dst_ptr);
+    }
+    x += 8;
+  } while (x < width);
+}
+
+inline void BoxFilterProcessPass1(const uint8_t* const src,
+                                  const ptrdiff_t src_stride,
+                                  const RestorationUnitInfo& restoration_info,
+                                  const int width, const int height,
+                                  const uint32_t scale, uint16_t* const temp,
+                                  uint8_t* const dst,
+                                  const ptrdiff_t dst_stride) {
+  // We have combined PreProcess and Process for the first pass by storing
+  // intermediate values in the |ma| region. The values stored are one
+  // vertical column of interleaved |ma| and |b| values and consume 8 *
+  // |height| values. This is |height| and not |height| * 2 because PreProcess
+  // only generates output for every other row. When processing the next column
+  // we write the new scratch values right after reading the previously saved
+  // ones.
+
+  // The PreProcess phase calculates a 5x5 box sum for every other row
+  //
+  // PreProcess and Process have been combined into the same step. We need 12
+  // input values to generate 8 output values for PreProcess:
+  // 0 1 2 3 4 5 6 7 8 9 10 11
+  // 2 = 0 + 1 + 2 +  3 +  4
+  // 3 = 1 + 2 + 3 +  4 +  5
+  // 4 = 2 + 3 + 4 +  5 +  6
+  // 5 = 3 + 4 + 5 +  6 +  7
+  // 6 = 4 + 5 + 6 +  7 +  8
+  // 7 = 5 + 6 + 7 +  8 +  9
+  // 8 = 6 + 7 + 8 +  9 + 10
+  // 9 = 7 + 8 + 9 + 10 + 11
+  //
+  // and then we need 10 input values to generate 8 output values for Process:
+  // 0 1 2 3 4 5 6 7 8 9
+  // 1 = 0 + 1 + 2
+  // 2 = 1 + 2 + 3
+  // 3 = 2 + 3 + 4
+  // 4 = 3 + 4 + 5
+  // 5 = 4 + 5 + 6
+  // 6 = 5 + 6 + 7
+  // 7 = 6 + 7 + 8
+  // 8 = 7 + 8 + 9
+  //
+  // To avoid re-calculating PreProcess values over and over again we will do a
+  // single column of 8 output values and store the second half of them
+  // interleaved in |temp|. The first half is not stored, since it is used
+  // immediately and becomes useless for the next column. Next we will start the
+  // second column. When 2 rows have been calculated we can calculate Process
+  // and output the results.
+
+  // Calculate and store a single column. Scope so we can re-use the variable
+  // names for the next step.
+  uint16_t* ab_ptr = temp;
+  const uint8_t* const src_pre_process = src - 2 * src_stride;
+  // Calculate intermediate results, including two-pixel border, for example, if
+  // unit size is 64x64, we calculate 68x68 pixels.
+  {
+    const uint8_t* column = src_pre_process - 4;
+    __m128i row[5], row_sq[5];
+    row[0] = row[1] = LoadLo8(column);
+    column += src_stride;
+    row[2] = LoadLo8(column);
+    row_sq[0] = row_sq[1] = VmullLo8(row[1], row[1]);
+    row_sq[2] = VmullLo8(row[2], row[2]);
+
+    int y = (height + 2) >> 1;
+    do {
+      column += src_stride;
+      row[3] = LoadLo8(column);
+      column += src_stride;
+      row[4] = LoadLo8(column);
+      row_sq[3] = VmullLo8(row[3], row[3]);
+      row_sq[4] = VmullLo8(row[4], row[4]);
+      BoxFilterPreProcess4<5, 1>(row, row_sq, scale, ab_ptr);
+      row[0] = row[2];
+      row[1] = row[3];
+      row[2] = row[4];
+      row_sq[0] = row_sq[2];
+      row_sq[1] = row_sq[3];
+      row_sq[2] = row_sq[4];
+      ab_ptr += 8;
+    } while (--y != 0);
+
+    if ((height & 1) != 0) {
+      column += src_stride;
+      row[3] = row[4] = LoadLo8(column);
+      row_sq[3] = row_sq[4] = VmullLo8(row[3], row[3]);
+      BoxFilterPreProcess4<5, 1>(row, row_sq, scale, ab_ptr);
+    }
+  }
+
+  const int w0 = restoration_info.sgr_proj_info.multiplier[0];
+  int x = 0;
+  do {
+    // |src_pre_process| is X but we already processed the first column of 4
+    // values so we want to start at Y and increment from there.
+    // X s s s Y s s
+    // s s s s s s s
+    // s s i i i i i
+    // s s i o o o o
+    // s s i o o o o
+
+    // Seed the loop with one line of output. Then, inside the loop, for each
+    // iteration we can output one even row and one odd row and carry the new
+    // line to the next iteration. In the diagram below 'i' values are
+    // intermediary values from the first step and '-' values are empty.
+    // iiii
+    // ---- > even row
+    // iiii - odd row
+    // ---- > even row
+    // iiii
+    __m128i ma[2], b[2], ma565[2], b565[2][2];
+    ab_ptr = temp;
+    ma[0] = b[0] = LoadAligned16(ab_ptr);
+    const uint8_t* column = src_pre_process + x;
+    __m128i row[5], row_sq[5][2];
+    // Need |width| + 3 pixels, but we read max(|x|) + 16 pixels.
+    // Mask max(|x|) + 13 - |width| extra pixels.
+    row[0] = row[1] = LoadUnaligned16Msan(column, x + 13 - width);
+    column += src_stride;
+    row[2] = LoadUnaligned16Msan(column, x + 13 - width);
+    column += src_stride;
+    row[3] = LoadUnaligned16Msan(column, x + 13 - width);
+    column += src_stride;
+    row[4] = LoadUnaligned16Msan(column, x + 13 - width);
+    row_sq[0][0] = row_sq[1][0] = VmullLo8(row[1], row[1]);
+    row_sq[0][1] = row_sq[1][1] = VmullHi8(row[1], row[1]);
+    row_sq[2][0] = VmullLo8(row[2], row[2]);
+    row_sq[2][1] = VmullHi8(row[2], row[2]);
+    row_sq[3][0] = VmullLo8(row[3], row[3]);
+    row_sq[3][1] = VmullHi8(row[3], row[3]);
+    row_sq[4][0] = VmullLo8(row[4], row[4]);
+    row_sq[4][1] = VmullHi8(row[4], row[4]);
+    BoxFilterPreProcess8<5>(row, row_sq, scale, &ma[0], &b[1], ab_ptr);
+
+    // Pass 1 Process. These are the only values we need to propagate between
+    // rows.
+    ma565[0] = Sum565(ma[0]);
+    b565[0][0] = Sum565W(_mm_alignr_epi8(b[1], b[0], 8));
+    b565[0][1] = Sum565W(b[1]);
+    uint8_t* dst_ptr = dst + x;
+    // Calculate one output line. Add in the line from the previous pass and
+    // output one even row. Sum the new line and output the odd row. Carry the
+    // new row into the next pass.
+    for (int y = height >> 1; y != 0; --y) {
+      ab_ptr += 8;
+      ma[0] = b[0] = LoadAligned16(ab_ptr);
+      row[0] = row[2];
+      row[1] = row[3];
+      row[2] = row[4];
+      row_sq[0][0] = row_sq[2][0], row_sq[0][1] = row_sq[2][1];
+      row_sq[1][0] = row_sq[3][0], row_sq[1][1] = row_sq[3][1];
+      row_sq[2][0] = row_sq[4][0], row_sq[2][1] = row_sq[4][1];
+      column += src_stride;
+      row[3] = LoadUnaligned16Msan(column, x + 13 - width);
+      column += src_stride;
+      row[4] = LoadUnaligned16Msan(column, x + 13 - width);
+      row_sq[3][0] = VmullLo8(row[3], row[3]);
+      row_sq[3][1] = VmullHi8(row[3], row[3]);
+      row_sq[4][0] = VmullLo8(row[4], row[4]);
+      row_sq[4][1] = VmullHi8(row[4], row[4]);
+      BoxFilterPreProcess8<5>(row, row_sq, scale, &ma[0], &b[1], ab_ptr);
+      const __m128i p0 = BoxFilterPass1(row[1], ma[0], b, ma565, b565);
+      SelfGuidedSingleMultiplier(row[1], p0, w0, dst_ptr);
+      dst_ptr += dst_stride;
+      const __m128i p1 = CalculateFilteredOutput<4>(row[2], ma565[1], b565[1]);
+      SelfGuidedSingleMultiplier(row[2], p1, w0, dst_ptr);
+      dst_ptr += dst_stride;
+      ma565[0] = ma565[1];
+      b565[0][0] = b565[1][0], b565[0][1] = b565[1][1];
+    }
+
+    if ((height & 1) != 0) {
+      ab_ptr += 8;
+      ma[0] = b[0] = LoadAligned16(ab_ptr);
+      row[0] = row[2];
+      row[1] = row[3];
+      row[2] = row[4];
+      row_sq[0][0] = row_sq[2][0], row_sq[0][1] = row_sq[2][1];
+      row_sq[1][0] = row_sq[3][0], row_sq[1][1] = row_sq[3][1];
+      row_sq[2][0] = row_sq[4][0], row_sq[2][1] = row_sq[4][1];
+      column += src_stride;
+      row[3] = row[4] = LoadUnaligned16Msan(column, x + 13 - width);
+      row_sq[3][0] = row_sq[4][0] = VmullLo8(row[3], row[3]);
+      row_sq[3][1] = row_sq[4][1] = VmullHi8(row[3], row[3]);
+      BoxFilterPreProcess8<5>(row, row_sq, scale, &ma[0], &b[1], ab_ptr);
+      const __m128i p0 = BoxFilterPass1(row[1], ma[0], b, ma565, b565);
+      SelfGuidedSingleMultiplier(row[1], p0, w0, dst_ptr);
+    }
+    x += 8;
+  } while (x < width);
+}
+
+inline void BoxFilterProcessPass2(const uint8_t* src,
+                                  const ptrdiff_t src_stride,
+                                  const RestorationUnitInfo& restoration_info,
+                                  const int width, const int height,
+                                  const uint32_t scale, uint16_t* const temp,
+                                  uint8_t* const dst,
+                                  const ptrdiff_t dst_stride) {
+  // Calculate intermediate results, including one-pixel border, for example, if
+  // unit size is 64x64, we calculate 66x66 pixels.
+  // Because of the vectors this calculates start in blocks of 4 so we actually
+  // get 68 values.
+  uint16_t* ab_ptr = temp;
+  const uint8_t* const src_pre_process = src - 2 * src_stride;
+  {
+    const uint8_t* column = src_pre_process - 3;
+    __m128i row[3], row_sq[3];
+    row[0] = LoadLo8(column);
+    column += src_stride;
+    row[1] = LoadLo8(column);
+    row_sq[0] = VmullLo8(row[0], row[0]);
+    row_sq[1] = VmullLo8(row[1], row[1]);
+    int y = height + 2;
+    do {
+      column += src_stride;
+      row[2] = LoadLo8(column);
+      row_sq[2] = VmullLo8(row[2], row[2]);
+      BoxFilterPreProcess4<3, 0>(row, row_sq, scale, ab_ptr);
+      row[0] = row[1];
+      row[1] = row[2];
+      row_sq[0] = row_sq[1];
+      row_sq[1] = row_sq[2];
+      ab_ptr += 8;
+    } while (--y != 0);
+  }
+
+  assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
+  const int w1 = restoration_info.sgr_proj_info.multiplier[1];
+  const int w0 = (1 << kSgrProjPrecisionBits) - w1;
+  int x = 0;
+  do {
+    ab_ptr = temp;
+    __m128i ma, b[2], ma343[3], ma444[2], b343[3][2], b444[2][2];
+    ma = b[0] = LoadAligned16(ab_ptr);
+    const uint8_t* column = src_pre_process + x;
+    __m128i row[3], row_sq[3][2];
+    // Need |width| + 2 pixels, but we read max(|x|) + 16 pixels.
+    // Mask max(|x|) + 14 - |width| extra pixels.
+    row[0] = LoadUnaligned16Msan(column, x + 14 - width);
+    column += src_stride;
+    row[1] = LoadUnaligned16Msan(column, x + 14 - width);
+    column += src_stride;
+    row[2] = LoadUnaligned16Msan(column, x + 14 - width);
+    row_sq[0][0] = VmullLo8(row[0], row[0]);
+    row_sq[0][1] = VmullHi8(row[0], row[0]);
+    row_sq[1][0] = VmullLo8(row[1], row[1]);
+    row_sq[1][1] = VmullHi8(row[1], row[1]);
+    row_sq[2][0] = VmullLo8(row[2], row[2]);
+    row_sq[2][1] = VmullHi8(row[2], row[2]);
+    BoxFilterPreProcess8<3>(row, row_sq, scale, &ma, &b[1], ab_ptr);
+    ma343[0] = Sum343(ma);
+    Sum343W(b, b343[0]);
+    ab_ptr += 8;
+    ma = b[0] = LoadAligned16(ab_ptr);
+    row[0] = row[1];
+    row[1] = row[2];
+    row_sq[0][0] = row_sq[1][0], row_sq[0][1] = row_sq[1][1];
+    row_sq[1][0] = row_sq[2][0], row_sq[1][1] = row_sq[2][1];
+    column += src_stride;
+    row[2] = LoadUnaligned16Msan(column, x + 14 - width);
+    row_sq[2][0] = VmullLo8(row[2], row[2]);
+    row_sq[2][1] = VmullHi8(row[2], row[2]);
+    BoxFilterPreProcess8<3>(row, row_sq, scale, &ma, &b[1], ab_ptr);
+    Sum343_444(ma, &ma343[1], &ma444[0]);
+    Sum343_444W(b, b343[1], b444[0]);
+
+    uint8_t* dst_ptr = dst + x;
+    int y = height;
+    do {
+      ab_ptr += 8;
+      ma = b[0] = LoadAligned16(ab_ptr);
+      row[0] = row[1];
+      row[1] = row[2];
+      row_sq[0][0] = row_sq[1][0], row_sq[0][1] = row_sq[1][1];
+      row_sq[1][0] = row_sq[2][0], row_sq[1][1] = row_sq[2][1];
+      column += src_stride;
+      row[2] = LoadUnaligned16Msan(column, x + 14 - width);
+      row_sq[2][0] = VmullLo8(row[2], row[2]);
+      row_sq[2][1] = VmullHi8(row[2], row[2]);
+      BoxFilterPreProcess8<3>(row, row_sq, scale, &ma, &b[1], ab_ptr);
+      const __m128i p = BoxFilterPass2(row[0], ma, b, ma343, ma444, b343, b444);
+      SelfGuidedSingleMultiplier(row[0], p, w0, dst_ptr);
+      ma343[0] = ma343[1];
+      ma343[1] = ma343[2];
+      ma444[0] = ma444[1];
+      b343[0][0] = b343[1][0], b343[0][1] = b343[1][1];
+      b343[1][0] = b343[2][0], b343[1][1] = b343[2][1];
+      b444[0][0] = b444[1][0], b444[0][1] = b444[1][1];
+      dst_ptr += dst_stride;
+    } while (--y != 0);
+    x += 8;
+  } while (x < width);
+}
+
+// If |width| is non-multiple of 8, up to 7 more pixels are written to |dest| in
+// the end of each row. It is safe to overwrite the output as it will not be
+// part of the visible frame.
+void SelfGuidedFilter_SSE4_1(const void* const source, void* const dest,
+                             const RestorationUnitInfo& restoration_info,
+                             const ptrdiff_t source_stride,
+                             const ptrdiff_t dest_stride, const int width,
+                             const int height,
+                             RestorationBuffer* const buffer) {
+  const int index = restoration_info.sgr_proj_info.index;
+  const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
+  const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
+  const auto* src = static_cast<const uint8_t*>(source);
+  auto* dst = static_cast<uint8_t*>(dest);
+  if (radius_pass_1 == 0) {
+    // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
+    // following assertion.
+    assert(radius_pass_0 != 0);
+    BoxFilterProcessPass1(src, source_stride, restoration_info, width, height,
+                          kSgrScaleParameter[index][0],
+                          buffer->sgr_buffer.temp_buffer, dst, dest_stride);
+  } else if (radius_pass_0 == 0) {
+    BoxFilterProcessPass2(src, source_stride, restoration_info, width, height,
+                          kSgrScaleParameter[index][1],
+                          buffer->sgr_buffer.temp_buffer, dst, dest_stride);
+  } else {
+    BoxFilterProcess(src, source_stride, restoration_info, width, height,
+                     kSgrScaleParameter[index], buffer->sgr_buffer.temp_buffer,
+                     dst, dest_stride);
+  }
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
 #if DSP_ENABLED_8BPP_SSE4_1(WienerFilter)
   dsp->loop_restorations[0] = WienerFilter_SSE4_1;
@@ -746,7 +1789,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_SSE4_1
+#else  // !LIBGAV1_ENABLE_SSE4_1
 namespace libgav1 {
 namespace dsp {
 
diff --git a/libgav1/src/dsp/x86/loop_restoration_sse4.h b/libgav1/src/dsp/x86/loop_restoration_sse4.h
index 38094c7..e11f35a 100644
--- a/libgav1/src/dsp/x86/loop_restoration_sse4.h
+++ b/libgav1/src/dsp/x86/loop_restoration_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_LOOP_RESTORATION_SSE4_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -35,11 +35,11 @@
 #if LIBGAV1_ENABLE_SSE4_1
 
 #ifndef LIBGAV1_Dsp8bpp_WienerFilter
-#define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_CPU_SSE4_1
 #endif
 
 #ifndef LIBGAV1_Dsp8bpp_SelfGuidedFilter
-#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_CPU_SSE4_1
 #endif
 
 #endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/mask_blend_sse4.cc b/libgav1/src/dsp/x86/mask_blend_sse4.cc
new file mode 100644
index 0000000..76d3811
--- /dev/null
+++ b/libgav1/src/dsp/x86/mask_blend_sse4.cc
@@ -0,0 +1,450 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/mask_blend.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <smmintrin.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// Width can only be 4 when it is subsampled from a block of width 8, hence
+// subsampling_x is always 1 when this function is called.
+template <int subsampling_x, int subsampling_y>
+inline __m128i GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    const __m128i mask_val_0 = _mm_cvtepu8_epi16(LoadLo8(mask));
+    const __m128i mask_val_1 =
+        _mm_cvtepu8_epi16(LoadLo8(mask + (mask_stride << subsampling_y)));
+    __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1);
+    if (subsampling_y == 1) {
+      const __m128i next_mask_val_0 =
+          _mm_cvtepu8_epi16(LoadLo8(mask + mask_stride));
+      const __m128i next_mask_val_1 =
+          _mm_cvtepu8_epi16(LoadLo8(mask + mask_stride * 3));
+      subsampled_mask = _mm_add_epi16(
+          subsampled_mask, _mm_hadd_epi16(next_mask_val_0, next_mask_val_1));
+    }
+    return RightShiftWithRounding_U16(subsampled_mask, 1 + subsampling_y);
+  }
+  const __m128i mask_val_0 = Load4(mask);
+  const __m128i mask_val_1 = Load4(mask + mask_stride);
+  return _mm_cvtepu8_epi16(
+      _mm_or_si128(mask_val_0, _mm_slli_si128(mask_val_1, 4)));
+}
+
+// This function returns a 16-bit packed mask to fit in _mm_madd_epi16.
+// 16-bit is also the lowest packing for hadd, but without subsampling there is
+// an unfortunate conversion required.
+template <int subsampling_x, int subsampling_y>
+inline __m128i GetMask8(const uint8_t* mask, ptrdiff_t stride) {
+  if (subsampling_x == 1) {
+    const __m128i row_vals = LoadUnaligned16(mask);
+
+    const __m128i mask_val_0 = _mm_cvtepu8_epi16(row_vals);
+    const __m128i mask_val_1 = _mm_cvtepu8_epi16(_mm_srli_si128(row_vals, 8));
+    __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1);
+
+    if (subsampling_y == 1) {
+      const __m128i next_row_vals = LoadUnaligned16(mask + stride);
+      const __m128i next_mask_val_0 = _mm_cvtepu8_epi16(next_row_vals);
+      const __m128i next_mask_val_1 =
+          _mm_cvtepu8_epi16(_mm_srli_si128(next_row_vals, 8));
+      subsampled_mask = _mm_add_epi16(
+          subsampled_mask, _mm_hadd_epi16(next_mask_val_0, next_mask_val_1));
+    }
+    return RightShiftWithRounding_U16(subsampled_mask, 1 + subsampling_y);
+  }
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  const __m128i mask_val = LoadLo8(mask);
+  return _mm_cvtepu8_epi16(mask_val);
+}
+
+// This version returns 8-bit packed values to fit in _mm_maddubs_epi16 because,
+// when is_inter_intra is true, the prediction values are brought to 8-bit
+// packing as well.
+template <int subsampling_x, int subsampling_y>
+inline __m128i GetInterIntraMask8(const uint8_t* mask, ptrdiff_t stride) {
+  if (subsampling_x == 1) {
+    const __m128i row_vals = LoadUnaligned16(mask);
+
+    const __m128i mask_val_0 = _mm_cvtepu8_epi16(row_vals);
+    const __m128i mask_val_1 = _mm_cvtepu8_epi16(_mm_srli_si128(row_vals, 8));
+    __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1);
+
+    if (subsampling_y == 1) {
+      const __m128i next_row_vals = LoadUnaligned16(mask + stride);
+      const __m128i next_mask_val_0 = _mm_cvtepu8_epi16(next_row_vals);
+      const __m128i next_mask_val_1 =
+          _mm_cvtepu8_epi16(_mm_srli_si128(next_row_vals, 8));
+      subsampled_mask = _mm_add_epi16(
+          subsampled_mask, _mm_hadd_epi16(next_mask_val_0, next_mask_val_1));
+    }
+    const __m128i ret =
+        RightShiftWithRounding_U16(subsampled_mask, 1 + subsampling_y);
+    return _mm_packus_epi16(ret, ret);
+  }
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  // Unfortunately there is no shift operation for 8-bit packing, or else we
+  // could return everything with 8-bit packing.
+  const __m128i mask_val = LoadLo8(mask);
+  return mask_val;
+}
+
+inline void WriteMaskBlendLine4x2(const int16_t* const pred_0,
+                                  const int16_t* const pred_1,
+                                  const __m128i pred_mask_0,
+                                  const __m128i pred_mask_1, uint8_t* dst,
+                                  const ptrdiff_t dst_stride) {
+  const __m128i pred_val_0_lo = LoadLo8(pred_0);
+  const __m128i pred_val_0 = LoadHi8(pred_val_0_lo, pred_0 + 4);
+  const __m128i pred_val_1_lo = LoadLo8(pred_1);
+  const __m128i pred_val_1 = LoadHi8(pred_val_1_lo, pred_1 + 4);
+  const __m128i mask_lo = _mm_unpacklo_epi16(pred_mask_0, pred_mask_1);
+  const __m128i mask_hi = _mm_unpackhi_epi16(pred_mask_0, pred_mask_1);
+  const __m128i pred_lo = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
+  const __m128i pred_hi = _mm_unpackhi_epi16(pred_val_0, pred_val_1);
+
+  // int res = (mask_value * prediction_0[x] +
+  //      (64 - mask_value) * prediction_1[x]) >> 6;
+  const __m128i compound_pred_lo = _mm_madd_epi16(pred_lo, mask_lo);
+  const __m128i compound_pred_hi = _mm_madd_epi16(pred_hi, mask_hi);
+  const __m128i compound_pred = _mm_packus_epi32(
+      _mm_srli_epi32(compound_pred_lo, 6), _mm_srli_epi32(compound_pred_hi, 6));
+
+  // dst[x] = static_cast<Pixel>(
+  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+  //           (1 << kBitdepth8) - 1));
+  const __m128i result = RightShiftWithRounding_S16(compound_pred, 4);
+  const __m128i res = _mm_packus_epi16(result, result);
+  Store4(dst, res);
+  Store4(dst + dst_stride, _mm_srli_si128(res, 4));
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending4x4_SSE4(const int16_t* pred_0, const int16_t* pred_1,
+                                 const uint8_t* mask,
+                                 const ptrdiff_t mask_stride, uint8_t* dst,
+                                 const ptrdiff_t dst_stride) {
+  const __m128i mask_inverter = _mm_set1_epi16(64);
+  __m128i pred_mask_0 =
+      GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
+  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                        dst_stride);
+  pred_0 += 4 << 1;
+  pred_1 += 4 << 1;
+  mask += mask_stride << (1 + subsampling_y);
+  dst += dst_stride << 1;
+
+  pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
+  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                        dst_stride);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlending4xH_SSE4(const int16_t* pred_0, const int16_t* pred_1,
+                                 const uint8_t* const mask_ptr,
+                                 const ptrdiff_t mask_stride, const int height,
+                                 uint8_t* dst, const ptrdiff_t dst_stride) {
+  const uint8_t* mask = mask_ptr;
+  if (height == 4) {
+    MaskBlending4x4_SSE4<subsampling_x, subsampling_y>(
+        pred_0, pred_1, mask, mask_stride, dst, dst_stride);
+    return;
+  }
+  const __m128i mask_inverter = _mm_set1_epi16(64);
+  int y = 0;
+  do {
+    __m128i pred_mask_0 =
+        GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
+
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+
+    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+
+    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+
+    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
+    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
+                          dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += 4 << 1;
+    mask += mask_stride << (1 + subsampling_y);
+    dst += dst_stride << 1;
+    y += 8;
+  } while (y < height);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void MaskBlend_SSE4(const void* prediction_0, const void* prediction_1,
+                           const ptrdiff_t /*prediction_stride_1*/,
+                           const uint8_t* const mask_ptr,
+                           const ptrdiff_t mask_stride, const int width,
+                           const int height, void* dest,
+                           const ptrdiff_t dst_stride) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  const ptrdiff_t pred_stride_0 = width;
+  const ptrdiff_t pred_stride_1 = width;
+  if (width == 4) {
+    MaskBlending4xH_SSE4<subsampling_x, subsampling_y>(
+        pred_0, pred_1, mask_ptr, mask_stride, height, dst, dst_stride);
+    return;
+  }
+  const uint8_t* mask = mask_ptr;
+  const __m128i mask_inverter = _mm_set1_epi16(64);
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
+          mask + (x << subsampling_x), mask_stride);
+      // 64 - mask
+      const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
+      const __m128i mask_lo = _mm_unpacklo_epi16(pred_mask_0, pred_mask_1);
+      const __m128i mask_hi = _mm_unpackhi_epi16(pred_mask_0, pred_mask_1);
+
+      const __m128i pred_val_0 = LoadAligned16(pred_0 + x);
+      const __m128i pred_val_1 = LoadAligned16(pred_1 + x);
+      const __m128i pred_lo = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
+      const __m128i pred_hi = _mm_unpackhi_epi16(pred_val_0, pred_val_1);
+      // int res = (mask_value * prediction_0[x] +
+      //      (64 - mask_value) * prediction_1[x]) >> 6;
+      const __m128i compound_pred_lo = _mm_madd_epi16(pred_lo, mask_lo);
+      const __m128i compound_pred_hi = _mm_madd_epi16(pred_hi, mask_hi);
+
+      const __m128i res = _mm_packus_epi32(_mm_srli_epi32(compound_pred_lo, 6),
+                                           _mm_srli_epi32(compound_pred_hi, 6));
+      // dst[x] = static_cast<Pixel>(
+      //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+      //           (1 << kBitdepth8) - 1));
+      const __m128i result = RightShiftWithRounding_S16(res, 4);
+      StoreLo8(dst + x, _mm_packus_epi16(result, result));
+
+      x += 8;
+    } while (x < width);
+    dst += dst_stride;
+    pred_0 += pred_stride_0;
+    pred_1 += pred_stride_1;
+    mask += mask_stride << subsampling_y;
+  } while (++y < height);
+}
+
+inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0,
+                                                uint8_t* const pred_1,
+                                                const ptrdiff_t pred_stride_1,
+                                                const __m128i pred_mask_0,
+                                                const __m128i pred_mask_1) {
+  const __m128i pred_mask = _mm_unpacklo_epi8(pred_mask_0, pred_mask_1);
+
+  __m128i pred_val_0 = Load4(pred_0);
+  pred_val_0 = _mm_or_si128(_mm_slli_si128(Load4(pred_0 + 4), 4), pred_val_0);
+  // TODO(b/150326556): One load.
+  __m128i pred_val_1 = Load4(pred_1);
+  pred_val_1 = _mm_or_si128(_mm_slli_si128(Load4(pred_1 + pred_stride_1), 4),
+                            pred_val_1);
+  const __m128i pred = _mm_unpacklo_epi8(pred_val_0, pred_val_1);
+  // int res = (mask_value * prediction_1[x] +
+  //      (64 - mask_value) * prediction_0[x]) >> 6;
+  const __m128i compound_pred = _mm_maddubs_epi16(pred, pred_mask);
+  const __m128i result = RightShiftWithRounding_U16(compound_pred, 6);
+  const __m128i res = _mm_packus_epi16(result, result);
+
+  Store4(pred_1, res);
+  Store4(pred_1 + pred_stride_1, _mm_srli_si128(res, 4));
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp4x4_SSE4(const uint8_t* pred_0,
+                                               uint8_t* pred_1,
+                                               const ptrdiff_t pred_stride_1,
+                                               const uint8_t* mask,
+                                               const ptrdiff_t mask_stride) {
+  const __m128i mask_inverter = _mm_set1_epi8(64);
+  const __m128i pred_mask_u16_first =
+      GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  mask += mask_stride << (1 + subsampling_y);
+  const __m128i pred_mask_u16_second =
+      GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  mask += mask_stride << (1 + subsampling_y);
+  __m128i pred_mask_1 =
+      _mm_packus_epi16(pred_mask_u16_first, pred_mask_u16_second);
+  __m128i pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1);
+  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
+                                      pred_mask_0, pred_mask_1);
+  pred_0 += 4 << 1;
+  pred_1 += pred_stride_1 << 1;
+
+  pred_mask_1 = _mm_srli_si128(pred_mask_1, 8);
+  pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1);
+  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
+                                      pred_mask_0, pred_mask_1);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline void InterIntraMaskBlending8bpp4xH_SSE4(const uint8_t* pred_0,
+                                               uint8_t* pred_1,
+                                               const ptrdiff_t pred_stride_1,
+                                               const uint8_t* const mask_ptr,
+                                               const ptrdiff_t mask_stride,
+                                               const int height) {
+  const uint8_t* mask = mask_ptr;
+  if (height == 4) {
+    InterIntraMaskBlending8bpp4x4_SSE4<subsampling_x, subsampling_y>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride);
+    return;
+  }
+  int y = 0;
+  do {
+    InterIntraMaskBlending8bpp4x4_SSE4<subsampling_x, subsampling_y>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride);
+    pred_0 += 4 << 2;
+    pred_1 += pred_stride_1 << 2;
+    mask += mask_stride << (2 + subsampling_y);
+
+    InterIntraMaskBlending8bpp4x4_SSE4<subsampling_x, subsampling_y>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride);
+    pred_0 += 4 << 2;
+    pred_1 += pred_stride_1 << 2;
+    mask += mask_stride << (2 + subsampling_y);
+    y += 8;
+  } while (y < height);
+}
+
+template <int subsampling_x, int subsampling_y>
+void InterIntraMaskBlend8bpp_SSE4(const uint8_t* prediction_0,
+                                  uint8_t* prediction_1,
+                                  const ptrdiff_t prediction_stride_1,
+                                  const uint8_t* const mask_ptr,
+                                  const ptrdiff_t mask_stride, const int width,
+                                  const int height) {
+  if (width == 4) {
+    InterIntraMaskBlending8bpp4xH_SSE4<subsampling_x, subsampling_y>(
+        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
+        height);
+    return;
+  }
+  const uint8_t* mask = mask_ptr;
+  const __m128i mask_inverter = _mm_set1_epi8(64);
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      const __m128i pred_mask_1 =
+          GetInterIntraMask8<subsampling_x, subsampling_y>(
+              mask + (x << subsampling_x), mask_stride);
+      // 64 - mask
+      const __m128i pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1);
+      const __m128i pred_mask = _mm_unpacklo_epi8(pred_mask_0, pred_mask_1);
+
+      const __m128i pred_val_0 = LoadLo8(prediction_0 + x);
+      const __m128i pred_val_1 = LoadLo8(prediction_1 + x);
+      const __m128i pred = _mm_unpacklo_epi8(pred_val_0, pred_val_1);
+      // int res = (mask_value * prediction_1[x] +
+      //      (64 - mask_value) * prediction_0[x]) >> 6;
+      const __m128i compound_pred = _mm_maddubs_epi16(pred, pred_mask);
+      const __m128i result = RightShiftWithRounding_U16(compound_pred, 6);
+      const __m128i res = _mm_packus_epi16(result, result);
+
+      StoreLo8(prediction_1 + x, res);
+
+      x += 8;
+    } while (x < width);
+    prediction_0 += width;
+    prediction_1 += prediction_stride_1;
+    mask += mask_stride << subsampling_y;
+  } while (++y < height);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend444)
+  dsp->mask_blend[0][0] = MaskBlend_SSE4<0, 0>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend422)
+  dsp->mask_blend[1][0] = MaskBlend_SSE4<1, 0>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend420)
+  dsp->mask_blend[2][0] = MaskBlend_SSE4<1, 1>;
+#endif
+  // The is_inter_intra index of mask_blend[][] is replaced by
+  // inter_intra_mask_blend_8bpp[] in 8-bit.
+#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp444)
+  dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_SSE4<0, 0>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp422)
+  dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_SSE4<1, 0>;
+#endif
+#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp420)
+  dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_SSE4<1, 1>;
+#endif
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void MaskBlendInit_SSE4_1() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_SSE4_1
+
+namespace libgav1 {
+namespace dsp {
+
+void MaskBlendInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/mask_blend_sse4.h b/libgav1/src/dsp/x86/mask_blend_sse4.h
new file mode 100644
index 0000000..cfd5e9a
--- /dev/null
+++ b/libgav1/src/dsp/x86/mask_blend_sse4.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_X86_MASK_BLEND_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_MASK_BLEND_SSE4_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::mask_blend. This function is not thread-safe.
+void MaskBlendInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_SSE4_1
+#define LIBGAV1_Dsp8bpp_MaskBlend444 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_MaskBlend422 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_MaskBlend420 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 LIBGAV1_CPU_SSE4_1
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_MASK_BLEND_SSE4_H_
diff --git a/libgav1/src/dsp/x86/motion_field_projection_sse4.cc b/libgav1/src/dsp/x86/motion_field_projection_sse4.cc
new file mode 100644
index 0000000..1875198
--- /dev/null
+++ b/libgav1/src/dsp/x86/motion_field_projection_sse4.cc
@@ -0,0 +1,397 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/motion_field_projection.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <smmintrin.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+inline __m128i LoadDivision(const __m128i division_table,
+                            const __m128i reference_offset) {
+  const __m128i kOne = _mm_set1_epi16(0x0100);
+  const __m128i t = _mm_add_epi8(reference_offset, reference_offset);
+  const __m128i tt = _mm_unpacklo_epi8(t, t);
+  const __m128i idx = _mm_add_epi8(tt, kOne);
+  return _mm_shuffle_epi8(division_table, idx);
+}
+
+inline __m128i MvProjection(const __m128i mv, const __m128i denominator,
+                            const int numerator) {
+  const __m128i m0 = _mm_madd_epi16(mv, denominator);
+  const __m128i m = _mm_mullo_epi32(m0, _mm_set1_epi32(numerator));
+  // Add the sign (0 or -1) to round towards zero.
+  const __m128i sign = _mm_srai_epi32(m, 31);
+  const __m128i add_sign = _mm_add_epi32(m, sign);
+  const __m128i sum = _mm_add_epi32(add_sign, _mm_set1_epi32(1 << 13));
+  return _mm_srai_epi32(sum, 14);
+}
+
+inline __m128i MvProjectionClip(const __m128i mv, const __m128i denominator,
+                                const int numerator) {
+  const __m128i mv0 = _mm_unpacklo_epi16(mv, _mm_setzero_si128());
+  const __m128i mv1 = _mm_unpackhi_epi16(mv, _mm_setzero_si128());
+  const __m128i denorm0 = _mm_unpacklo_epi16(denominator, _mm_setzero_si128());
+  const __m128i denorm1 = _mm_unpackhi_epi16(denominator, _mm_setzero_si128());
+  const __m128i s0 = MvProjection(mv0, denorm0, numerator);
+  const __m128i s1 = MvProjection(mv1, denorm1, numerator);
+  const __m128i projection = _mm_packs_epi32(s0, s1);
+  const __m128i projection_mv_clamp = _mm_set1_epi16(kProjectionMvClamp);
+  const __m128i projection_mv_clamp_negative =
+      _mm_set1_epi16(-kProjectionMvClamp);
+  const __m128i clamp = _mm_min_epi16(projection, projection_mv_clamp);
+  return _mm_max_epi16(clamp, projection_mv_clamp_negative);
+}
+
+inline __m128i Project_SSE4_1(const __m128i delta, const __m128i dst_sign) {
+  // Add 63 to negative delta so that it shifts towards zero.
+  const __m128i delta_sign = _mm_srai_epi16(delta, 15);
+  const __m128i delta_sign_63 = _mm_srli_epi16(delta_sign, 10);
+  const __m128i delta_adjust = _mm_add_epi16(delta, delta_sign_63);
+  const __m128i offset0 = _mm_srai_epi16(delta_adjust, 6);
+  const __m128i offset1 = _mm_xor_si128(offset0, dst_sign);
+  return _mm_sub_epi16(offset1, dst_sign);
+}
+
+inline void GetPosition(
+    const __m128i division_table, const MotionVector* const mv,
+    const int numerator, const int x8_start, const int x8_end, const int x8,
+    const __m128i& r_offsets, const __m128i& source_reference_type8,
+    const __m128i& skip_r, const __m128i& y8_floor8, const __m128i& y8_ceiling8,
+    const __m128i& d_sign, const int delta, __m128i* const r,
+    __m128i* const position_xy, int64_t* const skip_64, __m128i mvs[2]) {
+  const auto* const mv_int = reinterpret_cast<const int32_t*>(mv + x8);
+  *r = _mm_shuffle_epi8(r_offsets, source_reference_type8);
+  const __m128i denorm = LoadDivision(division_table, source_reference_type8);
+  __m128i projection_mv[2];
+  mvs[0] = LoadUnaligned16(mv_int + 0);
+  mvs[1] = LoadUnaligned16(mv_int + 4);
+  // Deinterlace x and y components
+  const __m128i kShuffle =
+      _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15);
+  const __m128i mv0 = _mm_shuffle_epi8(mvs[0], kShuffle);
+  const __m128i mv1 = _mm_shuffle_epi8(mvs[1], kShuffle);
+  const __m128i mv_y = _mm_unpacklo_epi64(mv0, mv1);
+  const __m128i mv_x = _mm_unpackhi_epi64(mv0, mv1);
+  // numerator could be 0.
+  projection_mv[0] = MvProjectionClip(mv_y, denorm, numerator);
+  projection_mv[1] = MvProjectionClip(mv_x, denorm, numerator);
+  // Do not update the motion vector if the block position is not valid or
+  // if position_x8 is outside the current range of x8_start and x8_end.
+  // Note that position_y8 will always be within the range of y8_start and
+  // y8_end.
+  // After subtracting the base, valid projections are within 8-bit.
+  const __m128i position_y = Project_SSE4_1(projection_mv[0], d_sign);
+  const __m128i position_x = Project_SSE4_1(projection_mv[1], d_sign);
+  const __m128i positions = _mm_packs_epi16(position_x, position_y);
+  const __m128i k01234567 =
+      _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0);
+  *position_xy = _mm_add_epi8(positions, k01234567);
+  const int x8_floor = std::max(
+      x8_start - x8, delta - kProjectionMvMaxHorizontalOffset);  // [-8, 8]
+  const int x8_ceiling =
+      std::min(x8_end - x8, delta + 8 + kProjectionMvMaxHorizontalOffset) -
+      1;  // [-1, 15]
+  const __m128i x8_floor8 = _mm_set1_epi8(x8_floor);
+  const __m128i x8_ceiling8 = _mm_set1_epi8(x8_ceiling);
+  const __m128i floor_xy = _mm_unpacklo_epi64(x8_floor8, y8_floor8);
+  const __m128i ceiling_xy = _mm_unpacklo_epi64(x8_ceiling8, y8_ceiling8);
+  const __m128i underflow = _mm_cmplt_epi8(*position_xy, floor_xy);
+  const __m128i overflow = _mm_cmpgt_epi8(*position_xy, ceiling_xy);
+  const __m128i out = _mm_or_si128(underflow, overflow);
+  const __m128i skip_low = _mm_or_si128(skip_r, out);
+  const __m128i skip = _mm_or_si128(skip_low, _mm_srli_si128(out, 8));
+  StoreLo8(skip_64, skip);
+}
+
+template <int idx>
+inline void Store(const __m128i position, const __m128i reference_offset,
+                  const __m128i mv, int8_t* dst_reference_offset,
+                  MotionVector* dst_mv) {
+  const ptrdiff_t offset =
+      static_cast<int16_t>(_mm_extract_epi16(position, idx));
+  if ((idx & 3) == 0) {
+    dst_mv[offset].mv32 = _mm_cvtsi128_si32(mv);
+  } else {
+    dst_mv[offset].mv32 = _mm_extract_epi32(mv, idx & 3);
+  }
+  dst_reference_offset[offset] = _mm_extract_epi8(reference_offset, idx);
+}
+
+template <int idx>
+inline void CheckStore(const int8_t* skips, const __m128i position,
+                       const __m128i reference_offset, const __m128i mv,
+                       int8_t* dst_reference_offset, MotionVector* dst_mv) {
+  if (skips[idx] == 0) {
+    Store<idx>(position, reference_offset, mv, dst_reference_offset, dst_mv);
+  }
+}
+
+// 7.9.2.
+void MotionFieldProjectionKernel_SSE4_1(
+    const ReferenceInfo& reference_info,
+    const int reference_to_current_with_sign, const int dst_sign,
+    const int y8_start, const int y8_end, const int x8_start, const int x8_end,
+    TemporalMotionField* const motion_field) {
+  const ptrdiff_t stride = motion_field->mv.columns();
+  // The column range has to be offset by kProjectionMvMaxHorizontalOffset since
+  // coordinates in that range could end up being position_x8 because of
+  // projection.
+  const int adjusted_x8_start =
+      std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0);
+  const int adjusted_x8_end = std::min(
+      x8_end + kProjectionMvMaxHorizontalOffset, static_cast<int>(stride));
+  const int adjusted_x8_end8 = adjusted_x8_end & ~7;
+  const int leftover = adjusted_x8_end - adjusted_x8_end8;
+  const int8_t* const reference_offsets =
+      reference_info.relative_distance_to.data();
+  const bool* const skip_references = reference_info.skip_references.data();
+  const int16_t* const projection_divisions =
+      reference_info.projection_divisions.data();
+  const ReferenceFrameType* source_reference_types =
+      &reference_info.motion_field_reference_frame[y8_start][0];
+  const MotionVector* mv = &reference_info.motion_field_mv[y8_start][0];
+  int8_t* dst_reference_offset = motion_field->reference_offset[y8_start];
+  MotionVector* dst_mv = motion_field->mv[y8_start];
+  const __m128i d_sign = _mm_set1_epi16(dst_sign);
+
+  static_assert(sizeof(int8_t) == sizeof(bool), "");
+  static_assert(sizeof(int8_t) == sizeof(ReferenceFrameType), "");
+  static_assert(sizeof(int32_t) == sizeof(MotionVector), "");
+  assert(dst_sign == 0 || dst_sign == -1);
+  assert(stride == motion_field->reference_offset.columns());
+  assert((y8_start & 7) == 0);
+  assert((adjusted_x8_start & 7) == 0);
+  // The final position calculation is represented with int16_t. Valid
+  // position_y8 from its base is at most 7. After considering the horizontal
+  // offset which is at most |stride - 1|, we have the following assertion,
+  // which means this optimization works for frame width up to 32K (each
+  // position is a 8x8 block).
+  assert(8 * stride <= 32768);
+  const __m128i skip_reference = LoadLo8(skip_references);
+  const __m128i r_offsets = LoadLo8(reference_offsets);
+  const __m128i division_table = LoadUnaligned16(projection_divisions);
+
+  int y8 = y8_start;
+  do {
+    const int y8_floor = (y8 & ~7) - y8;                             // [-7, 0]
+    const int y8_ceiling = std::min(y8_end - y8, y8_floor + 8) - 1;  // [0, 7]
+    const __m128i y8_floor8 = _mm_set1_epi8(y8_floor);
+    const __m128i y8_ceiling8 = _mm_set1_epi8(y8_ceiling);
+    int x8;
+
+    for (x8 = adjusted_x8_start; x8 < adjusted_x8_end8; x8 += 8) {
+      const __m128i source_reference_type8 =
+          LoadLo8(source_reference_types + x8);
+      const __m128i skip_r =
+          _mm_shuffle_epi8(skip_reference, source_reference_type8);
+      int64_t early_skip;
+      StoreLo8(&early_skip, skip_r);
+      // Early termination #1 if all are skips. Chance is typically ~30-40%.
+      if (early_skip == -1) continue;
+      int64_t skip_64;
+      __m128i r, position_xy, mvs[2];
+      GetPosition(division_table, mv, reference_to_current_with_sign, x8_start,
+                  x8_end, x8, r_offsets, source_reference_type8, skip_r,
+                  y8_floor8, y8_ceiling8, d_sign, 0, &r, &position_xy, &skip_64,
+                  mvs);
+      // Early termination #2 if all are skips.
+      // Chance is typically ~15-25% after Early termination #1.
+      if (skip_64 == -1) continue;
+      const __m128i p_y = _mm_cvtepi8_epi16(_mm_srli_si128(position_xy, 8));
+      const __m128i p_x = _mm_cvtepi8_epi16(position_xy);
+      const __m128i p_y_offset = _mm_mullo_epi16(p_y, _mm_set1_epi16(stride));
+      const __m128i pos = _mm_add_epi16(p_y_offset, p_x);
+      const __m128i position = _mm_add_epi16(pos, _mm_set1_epi16(x8));
+      if (skip_64 == 0) {
+        // Store all. Chance is typically ~70-85% after Early termination #2.
+        Store<0>(position, r, mvs[0], dst_reference_offset, dst_mv);
+        Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv);
+        Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv);
+        Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv);
+        Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv);
+        Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv);
+        Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv);
+        Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv);
+      } else {
+        // Check and store each.
+        // Chance is typically ~15-30% after Early termination #2.
+        // The compiler is smart enough to not create the local buffer skips[].
+        int8_t skips[8];
+        memcpy(skips, &skip_64, sizeof(skips));
+        CheckStore<0>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+        CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+        CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+        CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
+        CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+        CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+        CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+        CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
+      }
+    }
+
+    // The following leftover processing cannot be moved out of the do...while
+    // loop. Doing so may change the result storing orders of the same position.
+    if (leftover > 0) {
+      // Use SIMD only when leftover is at least 4, and there are at least 8
+      // elements in a row.
+      if (leftover >= 4 && adjusted_x8_start < adjusted_x8_end8) {
+        // Process the last 8 elements to avoid loading invalid memory. Some
+        // elements may have been processed in the above loop, which is OK.
+        const int delta = 8 - leftover;
+        x8 = adjusted_x8_end - 8;
+        const __m128i source_reference_type8 =
+            LoadLo8(source_reference_types + x8);
+        const __m128i skip_r =
+            _mm_shuffle_epi8(skip_reference, source_reference_type8);
+        int64_t early_skip;
+        StoreLo8(&early_skip, skip_r);
+        // Early termination #1 if all are skips.
+        if (early_skip != -1) {
+          int64_t skip_64;
+          __m128i r, position_xy, mvs[2];
+          GetPosition(division_table, mv, reference_to_current_with_sign,
+                      x8_start, x8_end, x8, r_offsets, source_reference_type8,
+                      skip_r, y8_floor8, y8_ceiling8, d_sign, delta, &r,
+                      &position_xy, &skip_64, mvs);
+          // Early termination #2 if all are skips.
+          if (skip_64 != -1) {
+            const __m128i p_y =
+                _mm_cvtepi8_epi16(_mm_srli_si128(position_xy, 8));
+            const __m128i p_x = _mm_cvtepi8_epi16(position_xy);
+            const __m128i p_y_offset =
+                _mm_mullo_epi16(p_y, _mm_set1_epi16(stride));
+            const __m128i pos = _mm_add_epi16(p_y_offset, p_x);
+            const __m128i position = _mm_add_epi16(pos, _mm_set1_epi16(x8));
+            // Store up to 7 elements since leftover is at most 7.
+            if (skip_64 == 0) {
+              // Store all.
+              Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv);
+              Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv);
+              Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv);
+              Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv);
+              Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv);
+              Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv);
+              Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv);
+            } else {
+              // Check and store each.
+              // The compiler is smart enough to not create the local buffer
+              // skips[].
+              int8_t skips[8];
+              memcpy(skips, &skip_64, sizeof(skips));
+              CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset,
+                            dst_mv);
+              CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset,
+                            dst_mv);
+              CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset,
+                            dst_mv);
+              CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset,
+                            dst_mv);
+              CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset,
+                            dst_mv);
+              CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset,
+                            dst_mv);
+              CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset,
+                            dst_mv);
+            }
+          }
+        }
+      } else {
+        for (; x8 < adjusted_x8_end; ++x8) {
+          const int source_reference_type = source_reference_types[x8];
+          if (skip_references[source_reference_type]) continue;
+          MotionVector projection_mv;
+          // reference_to_current_with_sign could be 0.
+          GetMvProjection(mv[x8], reference_to_current_with_sign,
+                          projection_divisions[source_reference_type],
+                          &projection_mv);
+          // Do not update the motion vector if the block position is not valid
+          // or if position_x8 is outside the current range of x8_start and
+          // x8_end. Note that position_y8 will always be within the range of
+          // y8_start and y8_end.
+          const int position_y8 = Project(0, projection_mv.mv[0], dst_sign);
+          if (position_y8 < y8_floor || position_y8 > y8_ceiling) continue;
+          const int x8_base = x8 & ~7;
+          const int x8_floor =
+              std::max(x8_start, x8_base - kProjectionMvMaxHorizontalOffset);
+          const int x8_ceiling =
+              std::min(x8_end, x8_base + 8 + kProjectionMvMaxHorizontalOffset);
+          const int position_x8 = Project(x8, projection_mv.mv[1], dst_sign);
+          if (position_x8 < x8_floor || position_x8 >= x8_ceiling) continue;
+          dst_mv[position_y8 * stride + position_x8] = mv[x8];
+          dst_reference_offset[position_y8 * stride + position_x8] =
+              reference_offsets[source_reference_type];
+        }
+      }
+    }
+
+    source_reference_types += stride;
+    mv += stride;
+    dst_reference_offset += stride;
+    dst_mv += stride;
+  } while (++y8 < y8_end);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_SSE4_1;
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_SSE4_1;
+}
+#endif
+
+}  // namespace
+
+void MotionFieldProjectionInit_SSE4_1() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  Init10bpp();
+#endif
+}
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_SSE4_1
+namespace libgav1 {
+namespace dsp {
+
+void MotionFieldProjectionInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/motion_field_projection_sse4.h b/libgav1/src/dsp/x86/motion_field_projection_sse4.h
new file mode 100644
index 0000000..7828de5
--- /dev/null
+++ b/libgav1/src/dsp/x86/motion_field_projection_sse4.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_X86_MOTION_FIELD_PROJECTION_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_MOTION_FIELD_PROJECTION_SSE4_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::motion_field_projection_kernel. This function is not
+// thread-safe.
+void MotionFieldProjectionInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_SSE4_1
+#define LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel LIBGAV1_CPU_SSE4_1
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_MOTION_FIELD_PROJECTION_SSE4_H_
diff --git a/libgav1/src/dsp/x86/motion_vector_search_sse4.cc b/libgav1/src/dsp/x86/motion_vector_search_sse4.cc
new file mode 100644
index 0000000..e49be12
--- /dev/null
+++ b/libgav1/src/dsp/x86/motion_vector_search_sse4.cc
@@ -0,0 +1,262 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/motion_vector_search.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <smmintrin.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+constexpr int kProjectionMvDivisionLookup_32bit[kMaxFrameDistance + 1] = {
+    0,    16384, 8192, 5461, 4096, 3276, 2730, 2340, 2048, 1820, 1638,
+    1489, 1365,  1260, 1170, 1092, 1024, 963,  910,  862,  819,  780,
+    744,  712,   682,  655,  630,  606,  585,  564,  546,  528};
+
+inline __m128i MvProjection(const __m128i mv, const __m128i denominator,
+                            const __m128i numerator) {
+  const __m128i m0 = _mm_madd_epi16(mv, denominator);
+  const __m128i m = _mm_mullo_epi32(m0, numerator);
+  // Add the sign (0 or -1) to round towards zero.
+  const __m128i sign = _mm_srai_epi32(m, 31);
+  const __m128i add_sign = _mm_add_epi32(m, sign);
+  const __m128i sum = _mm_add_epi32(add_sign, _mm_set1_epi32(1 << 13));
+  return _mm_srai_epi32(sum, 14);
+}
+
+inline __m128i MvProjectionClip(const __m128i mvs[2],
+                                const __m128i denominators[2],
+                                const __m128i numerator) {
+  const __m128i s0 = MvProjection(mvs[0], denominators[0], numerator);
+  const __m128i s1 = MvProjection(mvs[1], denominators[1], numerator);
+  const __m128i mv = _mm_packs_epi32(s0, s1);
+  const __m128i projection_mv_clamp = _mm_set1_epi16(kProjectionMvClamp);
+  const __m128i projection_mv_clamp_negative =
+      _mm_set1_epi16(-kProjectionMvClamp);
+  const __m128i clamp = _mm_min_epi16(mv, projection_mv_clamp);
+  return _mm_max_epi16(clamp, projection_mv_clamp_negative);
+}
+
+inline __m128i MvProjectionCompoundClip(
+    const MotionVector* const temporal_mvs,
+    const int8_t temporal_reference_offsets[2],
+    const int reference_offsets[2]) {
+  const auto* const tmvs = reinterpret_cast<const int32_t*>(temporal_mvs);
+  const __m128i temporal_mv = LoadLo8(tmvs);
+  const __m128i temporal_mv_0 = _mm_cvtepu16_epi32(temporal_mv);
+  __m128i mvs[2], denominators[2];
+  mvs[0] = _mm_unpacklo_epi64(temporal_mv_0, temporal_mv_0);
+  mvs[1] = _mm_unpackhi_epi64(temporal_mv_0, temporal_mv_0);
+  denominators[0] = _mm_set1_epi32(
+      kProjectionMvDivisionLookup[temporal_reference_offsets[0]]);
+  denominators[1] = _mm_set1_epi32(
+      kProjectionMvDivisionLookup[temporal_reference_offsets[1]]);
+  const __m128i offsets = LoadLo8(reference_offsets);
+  const __m128i numerator = _mm_unpacklo_epi32(offsets, offsets);
+  return MvProjectionClip(mvs, denominators, numerator);
+}
+
+inline __m128i MvProjectionSingleClip(
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets,
+    const int reference_offset) {
+  const auto* const tmvs = reinterpret_cast<const int16_t*>(temporal_mvs);
+  const __m128i temporal_mv = LoadAligned16(tmvs);
+  __m128i lookup = _mm_cvtsi32_si128(
+      kProjectionMvDivisionLookup_32bit[temporal_reference_offsets[0]]);
+  lookup = _mm_insert_epi32(
+      lookup, kProjectionMvDivisionLookup_32bit[temporal_reference_offsets[1]],
+      1);
+  lookup = _mm_insert_epi32(
+      lookup, kProjectionMvDivisionLookup_32bit[temporal_reference_offsets[2]],
+      2);
+  lookup = _mm_insert_epi32(
+      lookup, kProjectionMvDivisionLookup_32bit[temporal_reference_offsets[3]],
+      3);
+  __m128i mvs[2], denominators[2];
+  mvs[0] = _mm_unpacklo_epi16(temporal_mv, _mm_setzero_si128());
+  mvs[1] = _mm_unpackhi_epi16(temporal_mv, _mm_setzero_si128());
+  denominators[0] = _mm_unpacklo_epi32(lookup, lookup);
+  denominators[1] = _mm_unpackhi_epi32(lookup, lookup);
+  const __m128i numerator = _mm_set1_epi32(reference_offset);
+  return MvProjectionClip(mvs, denominators, numerator);
+}
+
+inline void LowPrecision(const __m128i mv, void* const candidate_mvs) {
+  const __m128i kRoundDownMask = _mm_set1_epi16(~1);
+  const __m128i sign = _mm_srai_epi16(mv, 15);
+  const __m128i sub_sign = _mm_sub_epi16(mv, sign);
+  const __m128i d = _mm_and_si128(sub_sign, kRoundDownMask);
+  StoreAligned16(candidate_mvs, d);
+}
+
+inline void ForceInteger(const __m128i mv, void* const candidate_mvs) {
+  const __m128i kRoundDownMask = _mm_set1_epi16(~7);
+  const __m128i sign = _mm_srai_epi16(mv, 15);
+  const __m128i mv1 = _mm_add_epi16(mv, _mm_set1_epi16(3));
+  const __m128i mv2 = _mm_sub_epi16(mv1, sign);
+  const __m128i mv3 = _mm_and_si128(mv2, kRoundDownMask);
+  StoreAligned16(candidate_mvs, mv3);
+}
+
+void MvProjectionCompoundLowPrecision_SSE4_1(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offsets[2], const int count,
+    CompoundMotionVector* candidate_mvs) {
+  // |reference_offsets| non-zero check usually equals true and is ignored.
+  // To facilitate the compilers, make a local copy of |reference_offsets|.
+  const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+  // One more element could be calculated.
+  int i = 0;
+  do {
+    const __m128i mv = MvProjectionCompoundClip(
+        temporal_mvs + i, temporal_reference_offsets + i, offsets);
+    LowPrecision(mv, candidate_mvs + i);
+    i += 2;
+  } while (i < count);
+}
+
+void MvProjectionCompoundForceInteger_SSE4_1(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offsets[2], const int count,
+    CompoundMotionVector* candidate_mvs) {
+  // |reference_offsets| non-zero check usually equals true and is ignored.
+  // To facilitate the compilers, make a local copy of |reference_offsets|.
+  const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+  // One more element could be calculated.
+  int i = 0;
+  do {
+    const __m128i mv = MvProjectionCompoundClip(
+        temporal_mvs + i, temporal_reference_offsets + i, offsets);
+    ForceInteger(mv, candidate_mvs + i);
+    i += 2;
+  } while (i < count);
+}
+
+void MvProjectionCompoundHighPrecision_SSE4_1(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offsets[2], const int count,
+    CompoundMotionVector* candidate_mvs) {
+  // |reference_offsets| non-zero check usually equals true and is ignored.
+  // To facilitate the compilers, make a local copy of |reference_offsets|.
+  const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
+  // One more element could be calculated.
+  int i = 0;
+  do {
+    const __m128i mv = MvProjectionCompoundClip(
+        temporal_mvs + i, temporal_reference_offsets + i, offsets);
+    StoreAligned16(candidate_mvs + i, mv);
+    i += 2;
+  } while (i < count);
+}
+
+void MvProjectionSingleLowPrecision_SSE4_1(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+  // Up to three more elements could be calculated.
+  int i = 0;
+  do {
+    const __m128i mv = MvProjectionSingleClip(
+        temporal_mvs + i, temporal_reference_offsets + i, reference_offset);
+    LowPrecision(mv, candidate_mvs + i);
+    i += 4;
+  } while (i < count);
+}
+
+void MvProjectionSingleForceInteger_SSE4_1(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+  // Up to three more elements could be calculated.
+  int i = 0;
+  do {
+    const __m128i mv = MvProjectionSingleClip(
+        temporal_mvs + i, temporal_reference_offsets + i, reference_offset);
+    ForceInteger(mv, candidate_mvs + i);
+    i += 4;
+  } while (i < count);
+}
+
+void MvProjectionSingleHighPrecision_SSE4_1(
+    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+  // Up to three more elements could be calculated.
+  int i = 0;
+  do {
+    const __m128i mv = MvProjectionSingleClip(
+        temporal_mvs + i, temporal_reference_offsets + i, reference_offset);
+    StoreAligned16(candidate_mvs + i, mv);
+    i += 4;
+  } while (i < count);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_SSE4_1;
+  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_SSE4_1;
+  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_SSE4_1;
+  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_SSE4_1;
+  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_SSE4_1;
+  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_SSE4_1;
+}
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_SSE4_1;
+  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_SSE4_1;
+  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_SSE4_1;
+  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_SSE4_1;
+  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_SSE4_1;
+  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_SSE4_1;
+}
+#endif
+
+}  // namespace
+
+void MotionVectorSearchInit_SSE4_1() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  Init10bpp();
+#endif
+}
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_SSE4_1
+namespace libgav1 {
+namespace dsp {
+
+void MotionVectorSearchInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/motion_vector_search_sse4.h b/libgav1/src/dsp/x86/motion_vector_search_sse4.h
new file mode 100644
index 0000000..b8b0412
--- /dev/null
+++ b/libgav1/src/dsp/x86/motion_vector_search_sse4.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_X86_MOTION_VECTOR_SEARCH_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_MOTION_VECTOR_SEARCH_SSE4_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::mv_projection_compound and Dsp::mv_projection_single. This
+// function is not thread-safe.
+void MotionVectorSearchInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_SSE4_1
+#define LIBGAV1_Dsp8bpp_MotionVectorSearch LIBGAV1_CPU_SSE4_1
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_MOTION_VECTOR_SEARCH_SSE4_H_
diff --git a/libgav1/src/dsp/x86/obmc_sse4.cc b/libgav1/src/dsp/x86/obmc_sse4.cc
index 5d1ca7a..a1be5ef 100644
--- a/libgav1/src/dsp/x86/obmc_sse4.cc
+++ b/libgav1/src/dsp/x86/obmc_sse4.cc
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/dsp.h"
 #include "src/dsp/obmc.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 
@@ -23,6 +23,8 @@
 #include <cstddef>
 #include <cstdint>
 
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
 #include "src/dsp/x86/common_sse4.h"
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
@@ -298,7 +300,7 @@
 }
 
 void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
 #if DSP_ENABLED_8BPP_SSE4_1(ObmcVertical)
   dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_SSE4_1;
@@ -315,7 +317,7 @@
 }  // namespace dsp
 }  // namespace libgav1
 
-#else   // !LIBGAV1_ENABLE_SSE4_1
+#else  // !LIBGAV1_ENABLE_SSE4_1
 
 namespace libgav1 {
 namespace dsp {
diff --git a/libgav1/src/dsp/x86/obmc_sse4.h b/libgav1/src/dsp/x86/obmc_sse4.h
index bc01cf3..03669ad 100644
--- a/libgav1/src/dsp/x86/obmc_sse4.h
+++ b/libgav1/src/dsp/x86/obmc_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_OBMC_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_OBMC_SSE4_H_
 
-#include "src/dsp/cpu.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -33,10 +33,10 @@
 // optimization being enabled, signal the sse4 implementation should be used.
 #if LIBGAV1_ENABLE_SSE4_1
 #ifndef LIBGAV1_Dsp8bpp_ObmcVertical
-#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_CPU_SSE4_1
 #endif
 #ifndef LIBGAV1_Dsp8bpp_ObmcHorizontal
-#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_DSP_SSE4_1
+#define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_CPU_SSE4_1
 #endif
 #endif  // LIBGAV1_ENABLE_SSE4_1
 
diff --git a/libgav1/src/dsp/x86/super_res_sse4.cc b/libgav1/src/dsp/x86/super_res_sse4.cc
new file mode 100644
index 0000000..050bcc4
--- /dev/null
+++ b/libgav1/src/dsp/x86/super_res_sse4.cc
@@ -0,0 +1,156 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/super_res.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <smmintrin.h>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// Upscale_Filter as defined in AV1 Section 7.16
+alignas(16) const int16_t
+    kUpscaleFilter[kSuperResFilterShifts][kSuperResFilterTaps] = {
+        {-0, 0, -0, 128, 0, -0, 0, -0},    {-0, 0, -1, 128, 2, -1, 0, -0},
+        {-0, 1, -3, 127, 4, -2, 1, -0},    {-0, 1, -4, 127, 6, -3, 1, -0},
+        {-0, 2, -6, 126, 8, -3, 1, -0},    {-0, 2, -7, 125, 11, -4, 1, -0},
+        {-1, 2, -8, 125, 13, -5, 2, -0},   {-1, 3, -9, 124, 15, -6, 2, -0},
+        {-1, 3, -10, 123, 18, -6, 2, -1},  {-1, 3, -11, 122, 20, -7, 3, -1},
+        {-1, 4, -12, 121, 22, -8, 3, -1},  {-1, 4, -13, 120, 25, -9, 3, -1},
+        {-1, 4, -14, 118, 28, -9, 3, -1},  {-1, 4, -15, 117, 30, -10, 4, -1},
+        {-1, 5, -16, 116, 32, -11, 4, -1}, {-1, 5, -16, 114, 35, -12, 4, -1},
+        {-1, 5, -17, 112, 38, -12, 4, -1}, {-1, 5, -18, 111, 40, -13, 5, -1},
+        {-1, 5, -18, 109, 43, -14, 5, -1}, {-1, 6, -19, 107, 45, -14, 5, -1},
+        {-1, 6, -19, 105, 48, -15, 5, -1}, {-1, 6, -19, 103, 51, -16, 5, -1},
+        {-1, 6, -20, 101, 53, -16, 6, -1}, {-1, 6, -20, 99, 56, -17, 6, -1},
+        {-1, 6, -20, 97, 58, -17, 6, -1},  {-1, 6, -20, 95, 61, -18, 6, -1},
+        {-2, 7, -20, 93, 64, -18, 6, -2},  {-2, 7, -20, 91, 66, -19, 6, -1},
+        {-2, 7, -20, 88, 69, -19, 6, -1},  {-2, 7, -20, 86, 71, -19, 6, -1},
+        {-2, 7, -20, 84, 74, -20, 7, -2},  {-2, 7, -20, 81, 76, -20, 7, -1},
+        {-2, 7, -20, 79, 79, -20, 7, -2},  {-1, 7, -20, 76, 81, -20, 7, -2},
+        {-2, 7, -20, 74, 84, -20, 7, -2},  {-1, 6, -19, 71, 86, -20, 7, -2},
+        {-1, 6, -19, 69, 88, -20, 7, -2},  {-1, 6, -19, 66, 91, -20, 7, -2},
+        {-2, 6, -18, 64, 93, -20, 7, -2},  {-1, 6, -18, 61, 95, -20, 6, -1},
+        {-1, 6, -17, 58, 97, -20, 6, -1},  {-1, 6, -17, 56, 99, -20, 6, -1},
+        {-1, 6, -16, 53, 101, -20, 6, -1}, {-1, 5, -16, 51, 103, -19, 6, -1},
+        {-1, 5, -15, 48, 105, -19, 6, -1}, {-1, 5, -14, 45, 107, -19, 6, -1},
+        {-1, 5, -14, 43, 109, -18, 5, -1}, {-1, 5, -13, 40, 111, -18, 5, -1},
+        {-1, 4, -12, 38, 112, -17, 5, -1}, {-1, 4, -12, 35, 114, -16, 5, -1},
+        {-1, 4, -11, 32, 116, -16, 5, -1}, {-1, 4, -10, 30, 117, -15, 4, -1},
+        {-1, 3, -9, 28, 118, -14, 4, -1},  {-1, 3, -9, 25, 120, -13, 4, -1},
+        {-1, 3, -8, 22, 121, -12, 4, -1},  {-1, 3, -7, 20, 122, -11, 3, -1},
+        {-1, 2, -6, 18, 123, -10, 3, -1},  {-0, 2, -6, 15, 124, -9, 3, -1},
+        {-0, 2, -5, 13, 125, -8, 2, -1},   {-0, 1, -4, 11, 125, -7, 2, -0},
+        {-0, 1, -3, 8, 126, -6, 2, -0},    {-0, 1, -3, 6, 127, -4, 1, -0},
+        {-0, 1, -2, 4, 127, -3, 1, -0},    {-0, 0, -1, 2, 128, -1, 0, -0},
+};
+
+inline void ComputeSuperRes4(const uint8_t* src, uint8_t* dst_x, int step,
+                             int* p) {
+  __m128i weighted_src[4];
+  for (int i = 0; i < 4; ++i, *p += step) {
+    const __m128i src_x = LoadLo8(&src[*p >> kSuperResScaleBits]);
+    const int remainder = *p & kSuperResScaleMask;
+    const __m128i filter =
+        LoadUnaligned16(kUpscaleFilter[remainder >> kSuperResExtraBits]);
+    weighted_src[i] = _mm_madd_epi16(_mm_cvtepu8_epi16(src_x), filter);
+  }
+
+  // Pairwise add is chosen in favor of transpose and add because of the
+  // ability to take advantage of madd.
+  const __m128i res0 = _mm_hadd_epi32(weighted_src[0], weighted_src[1]);
+  const __m128i res1 = _mm_hadd_epi32(weighted_src[2], weighted_src[3]);
+  const __m128i result0 = _mm_hadd_epi32(res0, res1);
+  const __m128i result = _mm_packus_epi32(
+      RightShiftWithRounding_S32(result0, kFilterBits), result0);
+  Store4(dst_x, _mm_packus_epi16(result, result));
+}
+
+inline void ComputeSuperRes8(const uint8_t* src, uint8_t* dst_x, int step,
+                             int* p) {
+  __m128i weighted_src[8];
+  for (int i = 0; i < 8; ++i, *p += step) {
+    const __m128i src_x = LoadLo8(&src[*p >> kSuperResScaleBits]);
+    const int remainder = *p & kSuperResScaleMask;
+    const __m128i filter =
+        LoadUnaligned16(kUpscaleFilter[remainder >> kSuperResExtraBits]);
+    weighted_src[i] = _mm_madd_epi16(_mm_cvtepu8_epi16(src_x), filter);
+  }
+
+  // Pairwise add is chosen in favor of transpose and add because of the
+  // ability to take advantage of madd.
+  const __m128i res0 = _mm_hadd_epi32(weighted_src[0], weighted_src[1]);
+  const __m128i res1 = _mm_hadd_epi32(weighted_src[2], weighted_src[3]);
+  const __m128i res2 = _mm_hadd_epi32(weighted_src[4], weighted_src[5]);
+  const __m128i res3 = _mm_hadd_epi32(weighted_src[6], weighted_src[7]);
+  const __m128i result0 = _mm_hadd_epi32(res0, res1);
+  const __m128i result1 = _mm_hadd_epi32(res2, res3);
+  const __m128i result =
+      _mm_packus_epi32(RightShiftWithRounding_S32(result0, kFilterBits),
+                       RightShiftWithRounding_S32(result1, kFilterBits));
+  StoreLo8(dst_x, _mm_packus_epi16(result, result));
+}
+
+void ComputeSuperRes_SSE4_1(const void* source, const int upscaled_width,
+                            const int initial_subpixel_x, const int step,
+                            void* const dest) {
+  const auto* src = static_cast<const uint8_t*>(source);
+  auto* dst = static_cast<uint8_t*>(dest);
+  src -= kSuperResFilterTaps >> 1;
+
+  int p = initial_subpixel_x;
+  int x = 0;
+  for (; x < (upscaled_width & ~7); x += 8) {
+    ComputeSuperRes8(src, &dst[x], step, &p);
+  }
+  // The below code can overwrite at most 3 bytes and overread at most 7.
+  // kSuperResHorizontalBorder accounts for this.
+  for (; x < upscaled_width; x += 4) {
+    ComputeSuperRes4(src, &dst[x], step, &p);
+  }
+}
+
+void Init8bpp() {
+  Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  dsp->super_res_row = ComputeSuperRes_SSE4_1;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void SuperResInit_SSE4_1() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_SSE4_1
+
+namespace libgav1 {
+namespace dsp {
+
+void SuperResInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/super_res_sse4.h b/libgav1/src/dsp/x86/super_res_sse4.h
new file mode 100644
index 0000000..5673ca5
--- /dev/null
+++ b/libgav1/src/dsp/x86/super_res_sse4.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_X86_SUPER_RES_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_SUPER_RES_SSE4_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::super_res_row. This function is not thread-safe.
+void SuperResInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_SSE4_1
+#define LIBGAV1_Dsp8bpp_SuperRes LIBGAV1_CPU_SSE4_1
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_SUPER_RES_SSE4_H_
diff --git a/libgav1/src/dsp/x86/transpose_sse4.h b/libgav1/src/dsp/x86/transpose_sse4.h
index 993b42f..cd61c92 100644
--- a/libgav1/src/dsp/x86/transpose_sse4.h
+++ b/libgav1/src/dsp/x86/transpose_sse4.h
@@ -17,8 +17,8 @@
 #ifndef LIBGAV1_SRC_DSP_X86_TRANSPOSE_SSE4_H_
 #define LIBGAV1_SRC_DSP_X86_TRANSPOSE_SSE4_H_
 
-#include "src/dsp/dsp.h"
 #include "src/utils/compiler_attributes.h"
+#include "src/utils/cpu.h"
 
 #if LIBGAV1_ENABLE_SSE4_1
 #include <emmintrin.h>
@@ -27,7 +27,7 @@
 namespace dsp {
 
 LIBGAV1_ALWAYS_INLINE __m128i Transpose4x4_U8(const __m128i* const in) {
-  // Unpack 16 bit elements. Goes from:
+  // Unpack 8 bit elements. Goes from:
   // in[0]: 00 01 02 03
   // in[1]: 10 11 12 13
   // in[2]: 20 21 22 23
@@ -43,6 +43,46 @@
   return _mm_unpacklo_epi16(a0, a1);
 }
 
+LIBGAV1_ALWAYS_INLINE void Transpose8x8To4x16_U8(const __m128i* const in,
+                                                 __m128i* out) {
+  // Unpack 8 bit elements. Goes from:
+  // in[0]:  00 01 02 03 04 05 06 07
+  // in[1]:  10 11 12 13 14 15 16 17
+  // in[2]:  20 21 22 23 24 25 26 27
+  // in[3]:  30 31 32 33 34 35 36 37
+  // in[4]:  40 41 42 43 44 45 46 47
+  // in[5]:  50 51 52 53 54 55 56 57
+  // in[6]:  60 61 62 63 64 65 66 67
+  // in[7]:  70 71 72 73 74 75 76 77
+  // to:
+  // a0:     00 10 01 11  02 12 03 13  04 14 05 15  06 16 07 17
+  // a1:     20 30 21 31  22 32 23 33  24 34 25 35  26 36 27 37
+  // a2:     40 50 41 51  42 52 43 53  44 54 45 55  46 56 47 57
+  // a3:     60 70 61 71  62 72 63 73  64 74 65 75  66 76 67 77
+  const __m128i a0 = _mm_unpacklo_epi8(in[0], in[1]);
+  const __m128i a1 = _mm_unpacklo_epi8(in[2], in[3]);
+  const __m128i a2 = _mm_unpacklo_epi8(in[4], in[5]);
+  const __m128i a3 = _mm_unpacklo_epi8(in[6], in[7]);
+
+  // b0:     00 10 20 30  01 11 21 31  02 12 22 32  03 13 23 33
+  // b1:     40 50 60 70  41 51 61 71  42 52 62 72  43 53 63 73
+  // b2:     04 14 24 34  05 15 25 35  06 16 26 36  07 17 27 37
+  // b3:     44 54 64 74  45 55 65 75  46 56 66 76  47 57 67 77
+  const __m128i b0 = _mm_unpacklo_epi16(a0, a1);
+  const __m128i b1 = _mm_unpacklo_epi16(a2, a3);
+  const __m128i b2 = _mm_unpackhi_epi16(a0, a1);
+  const __m128i b3 = _mm_unpackhi_epi16(a2, a3);
+
+  // out[0]: 00 10 20 30  40 50 60 70  01 11 21 31  41 51 61 71
+  // out[1]: 02 12 22 32  42 52 62 72  03 13 23 33  43 53 63 73
+  // out[2]: 04 14 24 34  44 54 64 74  05 15 25 35  45 55 65 75
+  // out[3]: 06 16 26 36  46 56 66 76  07 17 27 37  47 57 67 77
+  out[0] = _mm_unpacklo_epi32(b0, b1);
+  out[1] = _mm_unpackhi_epi32(b0, b1);
+  out[2] = _mm_unpacklo_epi32(b2, b3);
+  out[3] = _mm_unpackhi_epi32(b2, b3);
+}
+
 LIBGAV1_ALWAYS_INLINE void Transpose4x4_U16(const __m128i* in, __m128i* out) {
   // Unpack 16 bit elements. Goes from:
   // in[0]: 00 01 02 03  XX XX XX XX
diff --git a/libgav1/src/dsp/x86/warp_sse4.cc b/libgav1/src/dsp/x86/warp_sse4.cc
new file mode 100644
index 0000000..4c9e716
--- /dev/null
+++ b/libgav1/src/dsp/x86/warp_sse4.cc
@@ -0,0 +1,525 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/warp.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <smmintrin.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <type_traits>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/x86/common_sse4.h"
+#include "src/dsp/x86/transpose_sse4.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+// Number of extra bits of precision in warped filtering.
+constexpr int kWarpedDiffPrecisionBits = 10;
+
+// This assumes the two filters contain filter[x] and filter[x+2].
+inline __m128i AccumulateFilter(const __m128i sum, const __m128i filter_0,
+                                const __m128i filter_1,
+                                const __m128i& src_window) {
+  const __m128i filter_taps = _mm_unpacklo_epi8(filter_0, filter_1);
+  const __m128i src =
+      _mm_unpacklo_epi8(src_window, _mm_srli_si128(src_window, 2));
+  return _mm_add_epi16(sum, _mm_maddubs_epi16(src, filter_taps));
+}
+
+constexpr int kFirstPassOffset = 1 << 14;
+constexpr int kOffsetRemoval =
+    (kFirstPassOffset >> kInterRoundBitsHorizontal) * 128;
+
+// Applies the horizontal filter to one source row and stores the result in
+// |intermediate_result_row|. |intermediate_result_row| is a row in the 15x8
+// |intermediate_result| two-dimensional array.
+inline void HorizontalFilter(const int sx4, const int16_t alpha,
+                             const __m128i src_row,
+                             int16_t intermediate_result_row[8]) {
+  int sx = sx4 - MultiplyBy4(alpha);
+  __m128i filter[8];
+  for (__m128i& f : filter) {
+    const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
+                       kWarpedPixelPrecisionShifts;
+    f = LoadLo8(kWarpedFilters8[offset]);
+    sx += alpha;
+  }
+  Transpose8x8To4x16_U8(filter, filter);
+  // |filter| now contains two filters per register.
+  // Staggered combinations allow us to take advantage of _mm_maddubs_epi16
+  // without overflowing the sign bit. The sign bit is hit only where two taps
+  // paired in a single madd add up to more than 128. This is only possible with
+  // two adjacent "inner" taps. Therefore, pairing odd with odd and even with
+  // even guarantees safety. |sum| is given a negative offset to allow for large
+  // intermediate values.
+  // k = 0, 2.
+  __m128i src_row_window = src_row;
+  __m128i sum = _mm_set1_epi16(-kFirstPassOffset);
+  sum = AccumulateFilter(sum, filter[0], filter[1], src_row_window);
+
+  // k = 1, 3.
+  src_row_window = _mm_srli_si128(src_row_window, 1);
+  sum = AccumulateFilter(sum, _mm_srli_si128(filter[0], 8),
+                         _mm_srli_si128(filter[1], 8), src_row_window);
+  // k = 4, 6.
+  src_row_window = _mm_srli_si128(src_row_window, 3);
+  sum = AccumulateFilter(sum, filter[2], filter[3], src_row_window);
+
+  // k = 5, 7.
+  src_row_window = _mm_srli_si128(src_row_window, 1);
+  sum = AccumulateFilter(sum, _mm_srli_si128(filter[2], 8),
+                         _mm_srli_si128(filter[3], 8), src_row_window);
+
+  sum = RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal);
+  StoreUnaligned16(intermediate_result_row, sum);
+}
+
+template <bool is_compound>
+inline void WriteVerticalFilter(const __m128i filter[8],
+                                const int16_t intermediate_result[15][8], int y,
+                                void* dst_row) {
+  constexpr int kRoundBitsVertical =
+      is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
+  __m128i sum_low = _mm_set1_epi32(kOffsetRemoval);
+  __m128i sum_high = sum_low;
+  for (int k = 0; k < 8; k += 2) {
+    const __m128i filters_low = _mm_unpacklo_epi16(filter[k], filter[k + 1]);
+    const __m128i filters_high = _mm_unpackhi_epi16(filter[k], filter[k + 1]);
+    const __m128i intermediate_0 = LoadUnaligned16(intermediate_result[y + k]);
+    const __m128i intermediate_1 =
+        LoadUnaligned16(intermediate_result[y + k + 1]);
+    const __m128i intermediate_low =
+        _mm_unpacklo_epi16(intermediate_0, intermediate_1);
+    const __m128i intermediate_high =
+        _mm_unpackhi_epi16(intermediate_0, intermediate_1);
+
+    const __m128i product_low = _mm_madd_epi16(filters_low, intermediate_low);
+    const __m128i product_high =
+        _mm_madd_epi16(filters_high, intermediate_high);
+    sum_low = _mm_add_epi32(sum_low, product_low);
+    sum_high = _mm_add_epi32(sum_high, product_high);
+  }
+  sum_low = RightShiftWithRounding_S32(sum_low, kRoundBitsVertical);
+  sum_high = RightShiftWithRounding_S32(sum_high, kRoundBitsVertical);
+  if (is_compound) {
+    const __m128i sum = _mm_packs_epi32(sum_low, sum_high);
+    StoreUnaligned16(static_cast<int16_t*>(dst_row), sum);
+  } else {
+    const __m128i sum = _mm_packus_epi32(sum_low, sum_high);
+    StoreLo8(static_cast<uint8_t*>(dst_row), _mm_packus_epi16(sum, sum));
+  }
+}
+
+template <bool is_compound>
+inline void WriteVerticalFilter(const __m128i filter[8],
+                                const int16_t* intermediate_result_column,
+                                void* dst_row) {
+  constexpr int kRoundBitsVertical =
+      is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
+  __m128i sum_low = _mm_setzero_si128();
+  __m128i sum_high = _mm_setzero_si128();
+  for (int k = 0; k < 8; k += 2) {
+    const __m128i filters_low = _mm_unpacklo_epi16(filter[k], filter[k + 1]);
+    const __m128i filters_high = _mm_unpackhi_epi16(filter[k], filter[k + 1]);
+    // Equivalent to unpacking two vectors made by duplicating int16_t values.
+    const __m128i intermediate =
+        _mm_set1_epi32((intermediate_result_column[k + 1] << 16) |
+                       intermediate_result_column[k]);
+    const __m128i product_low = _mm_madd_epi16(filters_low, intermediate);
+    const __m128i product_high = _mm_madd_epi16(filters_high, intermediate);
+    sum_low = _mm_add_epi32(sum_low, product_low);
+    sum_high = _mm_add_epi32(sum_high, product_high);
+  }
+  sum_low = RightShiftWithRounding_S32(sum_low, kRoundBitsVertical);
+  sum_high = RightShiftWithRounding_S32(sum_high, kRoundBitsVertical);
+  if (is_compound) {
+    const __m128i sum = _mm_packs_epi32(sum_low, sum_high);
+    StoreUnaligned16(static_cast<int16_t*>(dst_row), sum);
+  } else {
+    const __m128i sum = _mm_packus_epi32(sum_low, sum_high);
+    StoreLo8(static_cast<uint8_t*>(dst_row), _mm_packus_epi16(sum, sum));
+  }
+}
+
+template <bool is_compound, typename DestType>
+inline void VerticalFilter(const int16_t source[15][8], int y4, int gamma,
+                           int delta, DestType* dest_row,
+                           ptrdiff_t dest_stride) {
+  int sy4 = (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+  for (int y = 0; y < 8; ++y) {
+    int sy = sy4 - MultiplyBy4(gamma);
+    __m128i filter[8];
+    for (__m128i& f : filter) {
+      const int offset = RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+                         kWarpedPixelPrecisionShifts;
+      f = LoadUnaligned16(kWarpedFilters[offset]);
+      sy += gamma;
+    }
+    Transpose8x8_U16(filter, filter);
+    WriteVerticalFilter<is_compound>(filter, source, y, dest_row);
+    dest_row += dest_stride;
+    sy4 += delta;
+  }
+}
+
+template <bool is_compound, typename DestType>
+inline void VerticalFilter(const int16_t* source_cols, int y4, int gamma,
+                           int delta, DestType* dest_row,
+                           ptrdiff_t dest_stride) {
+  int sy4 = (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+  for (int y = 0; y < 8; ++y) {
+    int sy = sy4 - MultiplyBy4(gamma);
+    __m128i filter[8];
+    for (__m128i& f : filter) {
+      const int offset = RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+                         kWarpedPixelPrecisionShifts;
+      f = LoadUnaligned16(kWarpedFilters[offset]);
+      sy += gamma;
+    }
+    Transpose8x8_U16(filter, filter);
+    WriteVerticalFilter<is_compound>(filter, &source_cols[y], dest_row);
+    dest_row += dest_stride;
+    sy4 += delta;
+  }
+}
+
+template <bool is_compound, typename DestType>
+inline void WarpRegion1(const uint8_t* src, ptrdiff_t source_stride,
+                        int source_width, int source_height, int ix4, int iy4,
+                        DestType* dst_row, ptrdiff_t dest_stride) {
+  // Region 1
+  // Points to the left or right border of the first row of |src|.
+  const uint8_t* first_row_border =
+      (ix4 + 7 <= 0) ? src : src + source_width - 1;
+  // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+  //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+  // In two special cases, iy4 + y is clipped to either 0 or
+  // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+  // bounded and we can avoid clipping iy4 + y by relying on a reference
+  // frame's boundary extension on the top and bottom.
+  // Region 1.
+  // Every sample used to calculate the prediction block has the same
+  // value. So the whole prediction block has the same value.
+  const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+  const uint8_t row_border_pixel = first_row_border[row * source_stride];
+
+  if (is_compound) {
+    const __m128i sum =
+        _mm_set1_epi16(row_border_pixel << (kInterRoundBitsVertical -
+                                            kInterRoundBitsCompoundVertical));
+    StoreUnaligned16(dst_row, sum);
+  } else {
+    memset(dst_row, row_border_pixel, 8);
+  }
+  const DestType* const first_dst_row = dst_row;
+  dst_row += dest_stride;
+  for (int y = 1; y < 8; ++y) {
+    memcpy(dst_row, first_dst_row, 8 * sizeof(*dst_row));
+    dst_row += dest_stride;
+  }
+}
+
+template <bool is_compound, typename DestType>
+inline void WarpRegion2(const uint8_t* src, ptrdiff_t source_stride,
+                        int source_width, int y4, int ix4, int iy4, int gamma,
+                        int delta, int16_t intermediate_result_column[15],
+                        DestType* dst_row, ptrdiff_t dest_stride) {
+  // Region 2.
+  // Points to the left or right border of the first row of |src|.
+  const uint8_t* first_row_border =
+      (ix4 + 7 <= 0) ? src : src + source_width - 1;
+  // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+  //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+  // In two special cases, iy4 + y is clipped to either 0 or
+  // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+  // bounded and we can avoid clipping iy4 + y by relying on a reference
+  // frame's boundary extension on the top and bottom.
+
+  // Region 2.
+  // Horizontal filter.
+  // The input values in this region are generated by extending the border
+  // which makes them identical in the horizontal direction. This
+  // computation could be inlined in the vertical pass but most
+  // implementations will need a transpose of some sort.
+  // It is not necessary to use the offset values here because the
+  // horizontal pass is a simple shift and the vertical pass will always
+  // require using 32 bits.
+  for (int y = -7; y < 8; ++y) {
+    // We may over-read up to 13 pixels above the top source row, or up
+    // to 13 pixels below the bottom source row. This is proved in
+    // warp.cc.
+    const int row = iy4 + y;
+    int sum = first_row_border[row * source_stride];
+    sum <<= (kFilterBits - kInterRoundBitsHorizontal);
+    intermediate_result_column[y + 7] = sum;
+  }
+  // Region 2 vertical filter.
+  VerticalFilter<is_compound, DestType>(intermediate_result_column, y4, gamma,
+                                        delta, dst_row, dest_stride);
+}
+
+template <bool is_compound, typename DestType>
+inline void WarpRegion3(const uint8_t* src, ptrdiff_t source_stride,
+                        int source_height, int alpha, int beta, int x4, int ix4,
+                        int iy4, int16_t intermediate_result[15][8]) {
+  // Region 3
+  // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
+
+  // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+  //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+  // In two special cases, iy4 + y is clipped to either 0 or
+  // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+  // bounded and we can avoid clipping iy4 + y by relying on a reference
+  // frame's boundary extension on the top and bottom.
+  // Horizontal filter.
+  const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+  const uint8_t* const src_row = src + row * source_stride;
+  // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+  // read but is ignored.
+  //
+  // NOTE: This may read up to 13 bytes before src_row[0] or up to 14
+  // bytes after src_row[source_width - 1]. We assume the source frame
+  // has left and right borders of at least 13 bytes that extend the
+  // frame boundary pixels. We also assume there is at least one extra
+  // padding byte after the right border of the last source row.
+  const __m128i src_row_v = LoadUnaligned16(&src_row[ix4 - 7]);
+  int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+  for (int y = -7; y < 8; ++y) {
+    HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]);
+    sx4 += beta;
+  }
+}
+
+template <bool is_compound, typename DestType>
+inline void WarpRegion4(const uint8_t* src, ptrdiff_t source_stride, int alpha,
+                        int beta, int x4, int ix4, int iy4,
+                        int16_t intermediate_result[15][8]) {
+  // Region 4.
+  // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
+
+  // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+  //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+  // In two special cases, iy4 + y is clipped to either 0 or
+  // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+  // bounded and we can avoid clipping iy4 + y by relying on a reference
+  // frame's boundary extension on the top and bottom.
+  // Horizontal filter.
+  int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+  for (int y = -7; y < 8; ++y) {
+    // We may over-read up to 13 pixels above the top source row, or up
+    // to 13 pixels below the bottom source row. This is proved in
+    // warp.cc.
+    const int row = iy4 + y;
+    const uint8_t* const src_row = src + row * source_stride;
+    // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+    // read but is ignored.
+    //
+    // NOTE: This may read up to 13 bytes before src_row[0] or up to 14
+    // bytes after src_row[source_width - 1]. We assume the source frame
+    // has left and right borders of at least 13 bytes that extend the
+    // frame boundary pixels. We also assume there is at least one extra
+    // padding byte after the right border of the last source row.
+    const __m128i src_row_v = LoadUnaligned16(&src_row[ix4 - 7]);
+    // Convert src_row_v to int8 (subtract 128).
+    HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]);
+    sx4 += beta;
+  }
+}
+
+template <bool is_compound, typename DestType>
+inline void HandleWarpBlock(const uint8_t* src, ptrdiff_t source_stride,
+                            int source_width, int source_height,
+                            const int* warp_params, int subsampling_x,
+                            int subsampling_y, int src_x, int src_y,
+                            int16_t alpha, int16_t beta, int16_t gamma,
+                            int16_t delta, DestType* dst_row,
+                            ptrdiff_t dest_stride) {
+  union {
+    // Intermediate_result is the output of the horizontal filtering and
+    // rounding. The range is within 13 (= bitdepth + kFilterBits + 1 -
+    // kInterRoundBitsHorizontal) bits (unsigned). We use the signed int16_t
+    // type so that we can start with a negative offset and restore it on the
+    // final filter sum.
+    int16_t intermediate_result[15][8];  // 15 rows, 8 columns.
+    // In the simple special cases where the samples in each row are all the
+    // same, store one sample per row in a column vector.
+    int16_t intermediate_result_column[15];
+  };
+
+  const int dst_x =
+      src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0];
+  const int dst_y =
+      src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1];
+  const int x4 = dst_x >> subsampling_x;
+  const int y4 = dst_y >> subsampling_y;
+  const int ix4 = x4 >> kWarpedModelPrecisionBits;
+  const int iy4 = y4 >> kWarpedModelPrecisionBits;
+  // A prediction block may fall outside the frame's boundaries. If a
+  // prediction block is calculated using only samples outside the frame's
+  // boundary, the filtering can be simplified. We can divide the plane
+  // into several regions and handle them differently.
+  //
+  //                |           |
+  //            1   |     3     |   1
+  //                |           |
+  //         -------+-----------+-------
+  //                |***********|
+  //            2   |*****4*****|   2
+  //                |***********|
+  //         -------+-----------+-------
+  //                |           |
+  //            1   |     3     |   1
+  //                |           |
+  //
+  // At the center, region 4 represents the frame and is the general case.
+  //
+  // In regions 1 and 2, the prediction block is outside the frame's
+  // boundary horizontally. Therefore the horizontal filtering can be
+  // simplified. Furthermore, in the region 1 (at the four corners), the
+  // prediction is outside the frame's boundary both horizontally and
+  // vertically, so we get a constant prediction block.
+  //
+  // In region 3, the prediction block is outside the frame's boundary
+  // vertically. Unfortunately because we apply the horizontal filters
+  // first, by the time we apply the vertical filters, they no longer see
+  // simple inputs. So the only simplification is that all the rows are
+  // the same, but we still need to apply all the horizontal and vertical
+  // filters.
+
+  // Check for two simple special cases, where the horizontal filter can
+  // be significantly simplified.
+  //
+  // In general, for each row, the horizontal filter is calculated as
+  // follows:
+  //   for (int x = -4; x < 4; ++x) {
+  //     const int offset = ...;
+  //     int sum = first_pass_offset;
+  //     for (int k = 0; k < 8; ++k) {
+  //       const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
+  //       sum += kWarpedFilters[offset][k] * src_row[column];
+  //     }
+  //     ...
+  //   }
+  // The column index before clipping, ix4 + x + k - 3, varies in the range
+  // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1
+  // or ix4 + 7 <= 0, then all the column indexes are clipped to the same
+  // border index (source_width - 1 or 0, respectively). Then for each x,
+  // the inner for loop of the horizontal filter is reduced to multiplying
+  // the border pixel by the sum of the filter coefficients.
+  if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) {
+    if ((iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0)) {
+      // Outside the frame in both directions. One repeated value.
+      WarpRegion1<is_compound, DestType>(src, source_stride, source_width,
+                                         source_height, ix4, iy4, dst_row,
+                                         dest_stride);
+      return;
+    }
+    // Outside the frame horizontally. Rows repeated.
+    WarpRegion2<is_compound, DestType>(
+        src, source_stride, source_width, y4, ix4, iy4, gamma, delta,
+        intermediate_result_column, dst_row, dest_stride);
+    return;
+  }
+
+  if ((iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0)) {
+    // Outside the frame vertically.
+    WarpRegion3<is_compound, DestType>(src, source_stride, source_height, alpha,
+                                       beta, x4, ix4, iy4, intermediate_result);
+  } else {
+    // Inside the frame.
+    WarpRegion4<is_compound, DestType>(src, source_stride, alpha, beta, x4, ix4,
+                                       iy4, intermediate_result);
+  }
+  // Region 3 and 4 vertical filter.
+  VerticalFilter<is_compound, DestType>(intermediate_result, y4, gamma, delta,
+                                        dst_row, dest_stride);
+}
+
+template <bool is_compound>
+void Warp_SSE4_1(const void* source, ptrdiff_t source_stride, int source_width,
+                 int source_height, const int* warp_params, int subsampling_x,
+                 int subsampling_y, int block_start_x, int block_start_y,
+                 int block_width, int block_height, int16_t alpha, int16_t beta,
+                 int16_t gamma, int16_t delta, void* dest,
+                 ptrdiff_t dest_stride) {
+  const auto* const src = static_cast<const uint8_t*>(source);
+  using DestType =
+      typename std::conditional<is_compound, int16_t, uint8_t>::type;
+  auto* dst = static_cast<DestType*>(dest);
+
+  // Warp process applies for each 8x8 block.
+  assert(block_width >= 8);
+  assert(block_height >= 8);
+  const int block_end_x = block_start_x + block_width;
+  const int block_end_y = block_start_y + block_height;
+
+  const int start_x = block_start_x;
+  const int start_y = block_start_y;
+  int src_x = (start_x + 4) << subsampling_x;
+  int src_y = (start_y + 4) << subsampling_y;
+  const int end_x = (block_end_x + 4) << subsampling_x;
+  const int end_y = (block_end_y + 4) << subsampling_y;
+  do {
+    DestType* dst_row = dst;
+    src_x = (start_x + 4) << subsampling_x;
+    do {
+      HandleWarpBlock<is_compound, DestType>(
+          src, source_stride, source_width, source_height, warp_params,
+          subsampling_x, subsampling_y, src_x, src_y, alpha, beta, gamma, delta,
+          dst_row, dest_stride);
+      src_x += (8 << subsampling_x);
+      dst_row += 8;
+    } while (src_x < end_x);
+    dst += 8 * dest_stride;
+    src_y += (8 << subsampling_y);
+  } while (src_y < end_y);
+}
+
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->warp = Warp_SSE4_1</*is_compound=*/false>;
+  dsp->warp_compound = Warp_SSE4_1</*is_compound=*/true>;
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void WarpInit_SSE4_1() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+#else  // !LIBGAV1_ENABLE_SSE4_1
+
+namespace libgav1 {
+namespace dsp {
+
+void WarpInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/warp_sse4.h b/libgav1/src/dsp/x86/warp_sse4.h
new file mode 100644
index 0000000..51fbf43
--- /dev/null
+++ b/libgav1/src/dsp/x86/warp_sse4.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_X86_WARP_SSE4_H_
+#define LIBGAV1_SRC_DSP_X86_WARP_SSE4_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::warp. This function is not thread-safe.
+void WarpInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_SSE4_1
+#define LIBGAV1_Dsp8bpp_Warp LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WarpCompound LIBGAV1_CPU_SSE4_1
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_X86_WARP_SSE4_H_
diff --git a/libgav1/src/dsp/x86/weight_mask_sse4.cc b/libgav1/src/dsp/x86/weight_mask_sse4.cc
new file mode 100644
index 0000000..9d9d9c4
--- /dev/null
+++ b/libgav1/src/dsp/x86/weight_mask_sse4.cc
@@ -0,0 +1,464 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/x86/weight_mask_sse4.h"
+
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_SSE4_1
+
+#include <smmintrin.h>
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/x86/common_sse4.h"
+#include "src/utils/common.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace low_bitdepth {
+namespace {
+
+constexpr int kRoundingBits8bpp = 4;
+
+template <bool mask_is_inverse>
+inline void WeightMask8_SSE4(const int16_t* prediction_0,
+                             const int16_t* prediction_1, uint8_t* mask) {
+  const __m128i pred_0 = LoadAligned16(prediction_0);
+  const __m128i pred_1 = LoadAligned16(prediction_1);
+  const __m128i difference = RightShiftWithRounding_U16(
+      _mm_abs_epi16(_mm_sub_epi16(pred_0, pred_1)), kRoundingBits8bpp);
+  const __m128i scaled_difference = _mm_srli_epi16(difference, 4);
+  const __m128i difference_offset = _mm_set1_epi8(38);
+  const __m128i adjusted_difference =
+      _mm_adds_epu8(_mm_packus_epi16(scaled_difference, scaled_difference),
+                    difference_offset);
+  const __m128i mask_ceiling = _mm_set1_epi8(64);
+  const __m128i mask_value = _mm_min_epi8(adjusted_difference, mask_ceiling);
+  if (mask_is_inverse) {
+    const __m128i inverted_mask_value = _mm_sub_epi8(mask_ceiling, mask_value);
+    StoreLo8(mask, inverted_mask_value);
+  } else {
+    StoreLo8(mask, mask_value);
+  }
+}
+
+#define WEIGHT8_WITHOUT_STRIDE \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask)
+
+#define WEIGHT8_AND_STRIDE \
+  WEIGHT8_WITHOUT_STRIDE;  \
+  pred_0 += 8;             \
+  pred_1 += 8;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask8x8_SSE4(const void* prediction_0, const void* prediction_1,
+                        uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y = 0;
+  do {
+    WEIGHT8_AND_STRIDE;
+  } while (++y < 7);
+  WEIGHT8_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask8x16_SSE4(const void* prediction_0, const void* prediction_1,
+                         uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT8_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask8x32_SSE4(const void* prediction_0, const void* prediction_1,
+                         uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+    WEIGHT8_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT8_AND_STRIDE;
+  WEIGHT8_WITHOUT_STRIDE;
+}
+
+#define WEIGHT16_WITHOUT_STRIDE                            \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask); \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8)
+
+#define WEIGHT16_AND_STRIDE \
+  WEIGHT16_WITHOUT_STRIDE;  \
+  pred_0 += 16;             \
+  pred_1 += 16;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask16x8_SSE4(const void* prediction_0, const void* prediction_1,
+                         uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+  } while (++y < 7);
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x16_SSE4(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x32_SSE4(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT16_AND_STRIDE;
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask16x64_SSE4(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+    WEIGHT16_AND_STRIDE;
+  } while (++y3 < 21);
+  WEIGHT16_WITHOUT_STRIDE;
+}
+
+#define WEIGHT32_WITHOUT_STRIDE                                           \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask);                \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8);    \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24)
+
+#define WEIGHT32_AND_STRIDE \
+  WEIGHT32_WITHOUT_STRIDE;  \
+  pred_0 += 32;             \
+  pred_1 += 32;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask32x8_SSE4(const void* prediction_0, const void* prediction_1,
+                         uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x16_SSE4(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x32_SSE4(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT32_AND_STRIDE;
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask32x64_SSE4(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+    WEIGHT32_AND_STRIDE;
+  } while (++y3 < 21);
+  WEIGHT32_WITHOUT_STRIDE;
+}
+
+#define WEIGHT64_WITHOUT_STRIDE                                           \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0, pred_1, mask);                \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8);    \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24); \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 32, pred_1 + 32, mask + 32); \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 40, pred_1 + 40, mask + 40); \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 48, pred_1 + 48, mask + 48); \
+  WeightMask8_SSE4<mask_is_inverse>(pred_0 + 56, pred_1 + 56, mask + 56)
+
+#define WEIGHT64_AND_STRIDE \
+  WEIGHT64_WITHOUT_STRIDE;  \
+  pred_0 += 64;             \
+  pred_1 += 64;             \
+  mask += mask_stride
+
+template <bool mask_is_inverse>
+void WeightMask64x16_SSE4(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y3 < 5);
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x32_SSE4(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y5 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y5 < 6);
+  WEIGHT64_AND_STRIDE;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x64_SSE4(const void* prediction_0, const void* prediction_1,
+                          uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y3 < 21);
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask64x128_SSE4(const void* prediction_0, const void* prediction_1,
+                           uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  do {
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+    WEIGHT64_AND_STRIDE;
+  } while (++y3 < 42);
+  WEIGHT64_AND_STRIDE;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask128x64_SSE4(const void* prediction_0, const void* prediction_1,
+                           uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
+  do {
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+  } while (++y3 < 21);
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += 64;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+template <bool mask_is_inverse>
+void WeightMask128x128_SSE4(const void* prediction_0, const void* prediction_1,
+                            uint8_t* mask, ptrdiff_t mask_stride) {
+  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
+  int y3 = 0;
+  const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
+  do {
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += 64;
+    WEIGHT64_WITHOUT_STRIDE;
+    pred_0 += 64;
+    pred_1 += 64;
+    mask += adjusted_mask_stride;
+  } while (++y3 < 42);
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += 64;
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += adjusted_mask_stride;
+
+  WEIGHT64_WITHOUT_STRIDE;
+  pred_0 += 64;
+  pred_1 += 64;
+  mask += 64;
+  WEIGHT64_WITHOUT_STRIDE;
+}
+
+#define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \
+  dsp->weight_mask[w_index][h_index][0] =                      \
+      WeightMask##width##x##height##_SSE4<0>;                  \
+  dsp->weight_mask[w_index][h_index][1] = WeightMask##width##x##height##_SSE4<1>
+void Init8bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0);
+  INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1);
+  INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2);
+  INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0);
+  INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1);
+  INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2);
+  INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3);
+  INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0);
+  INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1);
+  INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2);
+  INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3);
+  INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1);
+  INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2);
+  INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3);
+  INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4);
+  INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3);
+  INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4);
+}
+
+}  // namespace
+}  // namespace low_bitdepth
+
+void WeightMaskInit_SSE4_1() { low_bitdepth::Init8bpp(); }
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else  // !LIBGAV1_ENABLE_SSE4_1
+
+namespace libgav1 {
+namespace dsp {
+
+void WeightMaskInit_SSE4_1() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_SSE4_1
diff --git a/libgav1/src/dsp/x86/weight_mask_sse4.h b/libgav1/src/dsp/x86/weight_mask_sse4.h
new file mode 100644
index 0000000..841dd5a
--- /dev/null
+++ b/libgav1/src/dsp/x86/weight_mask_sse4.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_SSE4_H_
+#define LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_SSE4_H_
+
+#include "src/dsp/dsp.h"
+#include "src/utils/cpu.h"
+
+namespace libgav1 {
+namespace dsp {
+
+// Initializes Dsp::weight_mask. This function is not thread-safe.
+void WeightMaskInit_SSE4_1();
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#if LIBGAV1_ENABLE_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_8x8 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_8x16 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_8x32 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_16x8 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_16x16 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_16x32 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_16x64 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_32x8 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_32x16 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_32x32 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_32x64 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_64x16 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_64x32 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_64x64 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_64x128 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_128x64 LIBGAV1_CPU_SSE4_1
+#define LIBGAV1_Dsp8bpp_WeightMask_128x128 LIBGAV1_CPU_SSE4_1
+#endif  // LIBGAV1_ENABLE_SSE4_1
+
+#endif  // LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_SSE4_H_
diff --git a/libgav1/src/film_grain.cc b/libgav1/src/film_grain.cc
new file mode 100644
index 0000000..15ae956
--- /dev/null
+++ b/libgav1/src/film_grain.cc
@@ -0,0 +1,819 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/film_grain.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <new>
+
+#include "src/dsp/common.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/array_2d.h"
+#include "src/utils/blocking_counter.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+#include "src/utils/logging.h"
+#include "src/utils/threadpool.h"
+
+namespace libgav1 {
+
+namespace {
+
+// The kGaussianSequence array contains random samples from a Gaussian
+// distribution with zero mean and standard deviation of about 512 clipped to
+// the range of [-2048, 2047] (representable by a signed integer using 12 bits
+// of precision) and rounded to the nearest multiple of 4.
+//
+// Note: It is important that every element in the kGaussianSequence array be
+// less than 2040, so that RightShiftWithRounding(kGaussianSequence[i], 4) is
+// less than 128 for bitdepth=8 (GrainType=int8_t).
+constexpr int16_t kGaussianSequence[/*2048*/] = {
+    56,    568,   -180,  172,   124,   -84,   172,   -64,   -900,  24,   820,
+    224,   1248,  996,   272,   -8,    -916,  -388,  -732,  -104,  -188, 800,
+    112,   -652,  -320,  -376,  140,   -252,  492,   -168,  44,    -788, 588,
+    -584,  500,   -228,  12,    680,   272,   -476,  972,   -100,  652,  368,
+    432,   -196,  -720,  -192,  1000,  -332,  652,   -136,  -552,  -604, -4,
+    192,   -220,  -136,  1000,  -52,   372,   -96,   -624,  124,   -24,  396,
+    540,   -12,   -104,  640,   464,   244,   -208,  -84,   368,   -528, -740,
+    248,   -968,  -848,  608,   376,   -60,   -292,  -40,   -156,  252,  -292,
+    248,   224,   -280,  400,   -244,  244,   -60,   76,    -80,   212,  532,
+    340,   128,   -36,   824,   -352,  -60,   -264,  -96,   -612,  416,  -704,
+    220,   -204,  640,   -160,  1220,  -408,  900,   336,   20,    -336, -96,
+    -792,  304,   48,    -28,   -1232, -1172, -448,  104,   -292,  -520, 244,
+    60,    -948,  0,     -708,  268,   108,   356,   -548,  488,   -344, -136,
+    488,   -196,  -224,  656,   -236,  -1128, 60,    4,     140,   276,  -676,
+    -376,  168,   -108,  464,   8,     564,   64,    240,   308,   -300, -400,
+    -456,  -136,  56,    120,   -408,  -116,  436,   504,   -232,  328,  844,
+    -164,  -84,   784,   -168,  232,   -224,  348,   -376,  128,   568,  96,
+    -1244, -288,  276,   848,   832,   -360,  656,   464,   -384,  -332, -356,
+    728,   -388,  160,   -192,  468,   296,   224,   140,   -776,  -100, 280,
+    4,     196,   44,    -36,   -648,  932,   16,    1428,  28,    528,  808,
+    772,   20,    268,   88,    -332,  -284,  124,   -384,  -448,  208,  -228,
+    -1044, -328,  660,   380,   -148,  -300,  588,   240,   540,   28,   136,
+    -88,   -436,  256,   296,   -1000, 1400,  0,     -48,   1056,  -136, 264,
+    -528,  -1108, 632,   -484,  -592,  -344,  796,   124,   -668,  -768, 388,
+    1296,  -232,  -188,  -200,  -288,  -4,    308,   100,   -168,  256,  -500,
+    204,   -508,  648,   -136,  372,   -272,  -120,  -1004, -552,  -548, -384,
+    548,   -296,  428,   -108,  -8,    -912,  -324,  -224,  -88,   -112, -220,
+    -100,  996,   -796,  548,   360,   -216,  180,   428,   -200,  -212, 148,
+    96,    148,   284,   216,   -412,  -320,  120,   -300,  -384,  -604, -572,
+    -332,  -8,    -180,  -176,  696,   116,   -88,   628,   76,    44,   -516,
+    240,   -208,  -40,   100,   -592,  344,   -308,  -452,  -228,  20,   916,
+    -1752, -136,  -340,  -804,  140,   40,    512,   340,   248,   184,  -492,
+    896,   -156,  932,   -628,  328,   -688,  -448,  -616,  -752,  -100, 560,
+    -1020, 180,   -800,  -64,   76,    576,   1068,  396,   660,   552,  -108,
+    -28,   320,   -628,  312,   -92,   -92,   -472,  268,   16,    560,  516,
+    -672,  -52,   492,   -100,  260,   384,   284,   292,   304,   -148, 88,
+    -152,  1012,  1064,  -228,  164,   -376,  -684,  592,   -392,  156,  196,
+    -524,  -64,   -884,  160,   -176,  636,   648,   404,   -396,  -436, 864,
+    424,   -728,  988,   -604,  904,   -592,  296,   -224,  536,   -176, -920,
+    436,   -48,   1176,  -884,  416,   -776,  -824,  -884,  524,   -548, -564,
+    -68,   -164,  -96,   692,   364,   -692,  -1012, -68,   260,   -480, 876,
+    -1116, 452,   -332,  -352,  892,   -1088, 1220,  -676,  12,    -292, 244,
+    496,   372,   -32,   280,   200,   112,   -440,  -96,   24,    -644, -184,
+    56,    -432,  224,   -980,  272,   -260,  144,   -436,  420,   356,  364,
+    -528,  76,    172,   -744,  -368,  404,   -752,  -416,  684,   -688, 72,
+    540,   416,   92,    444,   480,   -72,   -1416, 164,   -1172, -68,  24,
+    424,   264,   1040,  128,   -912,  -524,  -356,  64,    876,   -12,  4,
+    -88,   532,   272,   -524,  320,   276,   -508,  940,   24,    -400, -120,
+    756,   60,    236,   -412,  100,   376,   -484,  400,   -100,  -740, -108,
+    -260,  328,   -268,  224,   -200,  -416,  184,   -604,  -564,  -20,  296,
+    60,    892,   -888,  60,    164,   68,    -760,  216,   -296,  904,  -336,
+    -28,   404,   -356,  -568,  -208,  -1480, -512,  296,   328,   -360, -164,
+    -1560, -776,  1156,  -428,  164,   -504,  -112,  120,   -216,  -148, -264,
+    308,   32,    64,    -72,   72,    116,   176,   -64,   -272,  460,  -536,
+    -784,  -280,  348,   108,   -752,  -132,  524,   -540,  -776,  116,  -296,
+    -1196, -288,  -560,  1040,  -472,  116,   -848,  -1116, 116,   636,  696,
+    284,   -176,  1016,  204,   -864,  -648,  -248,  356,   972,   -584, -204,
+    264,   880,   528,   -24,   -184,  116,   448,   -144,  828,   524,  212,
+    -212,  52,    12,    200,   268,   -488,  -404,  -880,  824,   -672, -40,
+    908,   -248,  500,   716,   -576,  492,   -576,  16,    720,   -108, 384,
+    124,   344,   280,   576,   -500,  252,   104,   -308,  196,   -188, -8,
+    1268,  296,   1032,  -1196, 436,   316,   372,   -432,  -200,  -660, 704,
+    -224,  596,   -132,  268,   32,    -452,  884,   104,   -1008, 424,  -1348,
+    -280,  4,     -1168, 368,   476,   696,   300,   -8,    24,    180,  -592,
+    -196,  388,   304,   500,   724,   -160,  244,   -84,   272,   -256, -420,
+    320,   208,   -144,  -156,  156,   364,   452,   28,    540,   316,  220,
+    -644,  -248,  464,   72,    360,   32,    -388,  496,   -680,  -48,  208,
+    -116,  -408,  60,    -604,  -392,  548,   -840,  784,   -460,  656,  -544,
+    -388,  -264,  908,   -800,  -628,  -612,  -568,  572,   -220,  164,  288,
+    -16,   -308,  308,   -112,  -636,  -760,  280,   -668,  432,   364,  240,
+    -196,  604,   340,   384,   196,   592,   -44,   -500,  432,   -580, -132,
+    636,   -76,   392,   4,     -412,  540,   508,   328,   -356,  -36,  16,
+    -220,  -64,   -248,  -60,   24,    -192,  368,   1040,  92,    -24,  -1044,
+    -32,   40,    104,   148,   192,   -136,  -520,  56,    -816,  -224, 732,
+    392,   356,   212,   -80,   -424,  -1008, -324,  588,   -1496, 576,  460,
+    -816,  -848,  56,    -580,  -92,   -1372, -112,  -496,  200,   364,  52,
+    -140,  48,    -48,   -60,   84,    72,    40,    132,   -356,  -268, -104,
+    -284,  -404,  732,   -520,  164,   -304,  -540,  120,   328,   -76,  -460,
+    756,   388,   588,   236,   -436,  -72,   -176,  -404,  -316,  -148, 716,
+    -604,  404,   -72,   -88,   -888,  -68,   944,   88,    -220,  -344, 960,
+    472,   460,   -232,  704,   120,   832,   -228,  692,   -508,  132,  -476,
+    844,   -748,  -364,  -44,   1116,  -1104, -1056, 76,    428,   552,  -692,
+    60,    356,   96,    -384,  -188,  -612,  -576,  736,   508,   892,  352,
+    -1132, 504,   -24,   -352,  324,   332,   -600,  -312,  292,   508,  -144,
+    -8,    484,   48,    284,   -260,  -240,  256,   -100,  -292,  -204, -44,
+    472,   -204,  908,   -188,  -1000, -256,  92,    1164,  -392,  564,  356,
+    652,   -28,   -884,  256,   484,   -192,  760,   -176,  376,   -524, -452,
+    -436,  860,   -736,  212,   124,   504,   -476,  468,   76,    -472, 552,
+    -692,  -944,  -620,  740,   -240,  400,   132,   20,    192,   -196, 264,
+    -668,  -1012, -60,   296,   -316,  -828,  76,    -156,  284,   -768, -448,
+    -832,  148,   248,   652,   616,   1236,  288,   -328,  -400,  -124, 588,
+    220,   520,   -696,  1032,  768,   -740,  -92,   -272,  296,   448,  -464,
+    412,   -200,  392,   440,   -200,  264,   -152,  -260,  320,   1032, 216,
+    320,   -8,    -64,   156,   -1016, 1084,  1172,  536,   484,   -432, 132,
+    372,   -52,   -256,  84,    116,   -352,  48,    116,   304,   -384, 412,
+    924,   -300,  528,   628,   180,   648,   44,    -980,  -220,  1320, 48,
+    332,   748,   524,   -268,  -720,  540,   -276,  564,   -344,  -208, -196,
+    436,   896,   88,    -392,  132,   80,    -964,  -288,  568,   56,   -48,
+    -456,  888,   8,     552,   -156,  -292,  948,   288,   128,   -716, -292,
+    1192,  -152,  876,   352,   -600,  -260,  -812,  -468,  -28,   -120, -32,
+    -44,   1284,  496,   192,   464,   312,   -76,   -516,  -380,  -456, -1012,
+    -48,   308,   -156,  36,    492,   -156,  -808,  188,   1652,  68,   -120,
+    -116,  316,   160,   -140,  352,   808,   -416,  592,   316,   -480, 56,
+    528,   -204,  -568,  372,   -232,  752,   -344,  744,   -4,    324,  -416,
+    -600,  768,   268,   -248,  -88,   -132,  -420,  -432,  80,    -288, 404,
+    -316,  -1216, -588,  520,   -108,  92,    -320,  368,   -480,  -216, -92,
+    1688,  -300,  180,   1020,  -176,  820,   -68,   -228,  -260,  436,  -904,
+    20,    40,    -508,  440,   -736,  312,   332,   204,   760,   -372, 728,
+    96,    -20,   -632,  -520,  -560,  336,   1076,  -64,   -532,  776,  584,
+    192,   396,   -728,  -520,  276,   -188,  80,    -52,   -612,  -252, -48,
+    648,   212,   -688,  228,   -52,   -260,  428,   -412,  -272,  -404, 180,
+    816,   -796,  48,    152,   484,   -88,   -216,  988,   696,   188,  -528,
+    648,   -116,  -180,  316,   476,   12,    -564,  96,    476,   -252, -364,
+    -376,  -392,  556,   -256,  -576,  260,   -352,  120,   -16,   -136, -260,
+    -492,  72,    556,   660,   580,   616,   772,   436,   424,   -32,  -324,
+    -1268, 416,   -324,  -80,   920,   160,   228,   724,   32,    -516, 64,
+    384,   68,    -128,  136,   240,   248,   -204,  -68,   252,   -932, -120,
+    -480,  -628,  -84,   192,   852,   -404,  -288,  -132,  204,   100,  168,
+    -68,   -196,  -868,  460,   1080,  380,   -80,   244,   0,     484,  -888,
+    64,    184,   352,   600,   460,   164,   604,   -196,  320,   -64,  588,
+    -184,  228,   12,    372,   48,    -848,  -344,  224,   208,   -200, 484,
+    128,   -20,   272,   -468,  -840,  384,   256,   -720,  -520,  -464, -580,
+    112,   -120,  644,   -356,  -208,  -608,  -528,  704,   560,   -424, 392,
+    828,   40,    84,    200,   -152,  0,     -144,  584,   280,   -120, 80,
+    -556,  -972,  -196,  -472,  724,   80,    168,   -32,   88,    160,  -688,
+    0,     160,   356,   372,   -776,  740,   -128,  676,   -248,  -480, 4,
+    -364,  96,    544,   232,   -1032, 956,   236,   356,   20,    -40,  300,
+    24,    -676,  -596,  132,   1120,  -104,  532,   -1096, 568,   648,  444,
+    508,   380,   188,   -376,  -604,  1488,  424,   24,    756,   -220, -192,
+    716,   120,   920,   688,   168,   44,    -460,  568,   284,   1144, 1160,
+    600,   424,   888,   656,   -356,  -320,  220,   316,   -176,  -724, -188,
+    -816,  -628,  -348,  -228,  -380,  1012,  -452,  -660,  736,   928,  404,
+    -696,  -72,   -268,  -892,  128,   184,   -344,  -780,  360,   336,  400,
+    344,   428,   548,   -112,  136,   -228,  -216,  -820,  -516,  340,  92,
+    -136,  116,   -300,  376,   -244,  100,   -316,  -520,  -284,  -12,  824,
+    164,   -548,  -180,  -128,  116,   -924,  -828,  268,   -368,  -580, 620,
+    192,   160,   0,     -1676, 1068,  424,   -56,   -360,  468,   -156, 720,
+    288,   -528,  556,   -364,  548,   -148,  504,   316,   152,   -648, -620,
+    -684,  -24,   -376,  -384,  -108,  -920,  -1032, 768,   180,   -264, -508,
+    -1268, -260,  -60,   300,   -240,  988,   724,   -376,  -576,  -212, -736,
+    556,   192,   1092,  -620,  -880,  376,   -56,   -4,    -216,  -32,  836,
+    268,   396,   1332,  864,   -600,  100,   56,    -412,  -92,   356,  180,
+    884,   -468,  -436,  292,   -388,  -804,  -704,  -840,  368,   -348, 140,
+    -724,  1536,  940,   372,   112,   -372,  436,   -480,  1136,  296,  -32,
+    -228,  132,   -48,   -220,  868,   -1016, -60,   -1044, -464,  328,  916,
+    244,   12,    -736,  -296,  360,   468,   -376,  -108,  -92,   788,  368,
+    -56,   544,   400,   -672,  -420,  728,   16,    320,   44,    -284, -380,
+    -796,  488,   132,   204,   -596,  -372,  88,    -152,  -908,  -636, -572,
+    -624,  -116,  -692,  -200,  -56,   276,   -88,   484,   -324,  948,  864,
+    1000,  -456,  -184,  -276,  292,   -296,  156,   676,   320,   160,  908,
+    -84,   -1236, -288,  -116,  260,   -372,  -644,  732,   -756,  -96,  84,
+    344,   -520,  348,   -688,  240,   -84,   216,   -1044, -136,  -676, -396,
+    -1500, 960,   -40,   176,   168,   1516,  420,   -504,  -344,  -364, -360,
+    1216,  -940,  -380,  -212,  252,   -660,  -708,  484,   -444,  -152, 928,
+    -120,  1112,  476,   -260,  560,   -148,  -344,  108,   -196,  228,  -288,
+    504,   560,   -328,  -88,   288,   -1008, 460,   -228,  468,   -836, -196,
+    76,    388,   232,   412,   -1168, -716,  -644,  756,   -172,  -356, -504,
+    116,   432,   528,   48,    476,   -168,  -608,  448,   160,   -532, -272,
+    28,    -676,  -12,   828,   980,   456,   520,   104,   -104,  256,  -344,
+    -4,    -28,   -368,  -52,   -524,  -572,  -556,  -200,  768,   1124, -208,
+    -512,  176,   232,   248,   -148,  -888,  604,   -600,  -304,  804,  -156,
+    -212,  488,   -192,  -804,  -256,  368,   -360,  -916,  -328,  228,  -240,
+    -448,  -472,  856,   -556,  -364,  572,   -12,   -156,  -368,  -340, 432,
+    252,   -752,  -152,  288,   268,   -580,  -848,  -592,  108,   -76,  244,
+    312,   -716,  592,   -80,   436,   360,   4,     -248,  160,   516,  584,
+    732,   44,    -468,  -280,  -292,  -156,  -588,  28,    308,   912,  24,
+    124,   156,   180,   -252,  944,   -924,  -772,  -520,  -428,  -624, 300,
+    -212,  -1144, 32,    -724,  800,   -1128, -212,  -1288, -848,  180,  -416,
+    440,   192,   -576,  -792,  -76,   -1080, 80,    -532,  -352,  -132, 380,
+    -820,  148,   1112,  128,   164,   456,   700,   -924,  144,   -668, -384,
+    648,   -832,  508,   552,   -52,   -100,  -656,  208,   -568,  748,  -88,
+    680,   232,   300,   192,   -408,  -1012, -152,  -252,  -268,  272,  -876,
+    -664,  -648,  -332,  -136,  16,    12,    1152,  -28,   332,   -536, 320,
+    -672,  -460,  -316,  532,   -260,  228,   -40,   1052,  -816,  180,  88,
+    -496,  -556,  -672,  -368,  428,   92,    356,   404,   -408,  252,  196,
+    -176,  -556,  792,   268,   32,    372,   40,    96,    -332,  328,  120,
+    372,   -900,  -40,   472,   -264,  -592,  952,   128,   656,   112,  664,
+    -232,  420,   4,     -344,  -464,  556,   244,   -416,  -32,   252,  0,
+    -412,  188,   -696,  508,   -476,  324,   -1096, 656,   -312,  560,  264,
+    -136,  304,   160,   -64,   -580,  248,   336,   -720,  560,   -348, -288,
+    -276,  -196,  -500,  852,   -544,  -236,  -1128, -992,  -776,  116,  56,
+    52,    860,   884,   212,   -12,   168,   1020,  512,   -552,  924,  -148,
+    716,   188,   164,   -340,  -520,  -184,  880,   -152,  -680,  -208, -1156,
+    -300,  -528,  -472,  364,   100,   -744,  -1056, -32,   540,   280,  144,
+    -676,  -32,   -232,  -280,  -224,  96,    568,   -76,   172,   148,  148,
+    104,   32,    -296,  -32,   788,   -80,   32,    -16,   280,   288,  944,
+    428,   -484};
+static_assert(sizeof(kGaussianSequence) / sizeof(kGaussianSequence[0]) == 2048,
+              "");
+
+// The number of rows in a contiguous group computed by a single worker thread
+// before checking for the next available group.
+constexpr int kFrameChunkHeight = 8;
+
+// |width| and |height| refer to the plane, not the frame, meaning any
+// subsampling should be applied by the caller.
+template <typename Pixel>
+inline void CopyImagePlane(const uint8_t* source_plane, ptrdiff_t source_stride,
+                           int width, int height, uint8_t* dest_plane,
+                           ptrdiff_t dest_stride) {
+  // If it's the same buffer there's nothing to do.
+  if (source_plane == dest_plane) return;
+
+  int y = 0;
+  do {
+    memcpy(dest_plane, source_plane, width * sizeof(Pixel));
+    source_plane += source_stride;
+    dest_plane += dest_stride;
+  } while (++y < height);
+}
+
+}  // namespace
+
+template <int bitdepth>
+FilmGrain<bitdepth>::FilmGrain(const FilmGrainParams& params,
+                               bool is_monochrome,
+                               bool color_matrix_is_identity, int subsampling_x,
+                               int subsampling_y, int width, int height,
+                               ThreadPool* thread_pool)
+    : params_(params),
+      is_monochrome_(is_monochrome),
+      color_matrix_is_identity_(color_matrix_is_identity),
+      subsampling_x_(subsampling_x),
+      subsampling_y_(subsampling_y),
+      width_(width),
+      height_(height),
+      template_uv_width_((subsampling_x != 0) ? kMinChromaWidth
+                                              : kMaxChromaWidth),
+      template_uv_height_((subsampling_y != 0) ? kMinChromaHeight
+                                               : kMaxChromaHeight),
+      thread_pool_(thread_pool) {}
+
+template <int bitdepth>
+bool FilmGrain<bitdepth>::Init() {
+  // Section 7.18.3.3. Generate grain process.
+  const dsp::Dsp& dsp = *dsp::GetDspTable(bitdepth);
+  // If params_.num_y_points is 0, luma_grain_ will never be read, so we don't
+  // need to generate it.
+  const bool use_luma = params_.num_y_points > 0;
+  if (use_luma) {
+    GenerateLumaGrain(params_, luma_grain_);
+    // If params_.auto_regression_coeff_lag is 0, the filter is the identity
+    // filter and therefore can be skipped.
+    if (params_.auto_regression_coeff_lag > 0) {
+      dsp.film_grain
+          .luma_auto_regression[params_.auto_regression_coeff_lag - 1](
+              params_, luma_grain_);
+    }
+  } else {
+    // Have AddressSanitizer warn if luma_grain_ is used.
+    ASAN_POISON_MEMORY_REGION(luma_grain_, sizeof(luma_grain_));
+  }
+  if (!is_monochrome_) {
+    GenerateChromaGrains(params_, template_uv_width_, template_uv_height_,
+                         u_grain_, v_grain_);
+    if (params_.auto_regression_coeff_lag > 0 || use_luma) {
+      dsp.film_grain.chroma_auto_regression[static_cast<int>(
+          use_luma)][params_.auto_regression_coeff_lag](
+          params_, luma_grain_, subsampling_x_, subsampling_y_, u_grain_,
+          v_grain_);
+    }
+  }
+
+  // Section 7.18.3.4. Scaling lookup initialization process.
+
+  // Initialize scaling_lut_y_. If params_.num_y_points > 0, scaling_lut_y_
+  // is used for the Y plane. If params_.chroma_scaling_from_luma is true,
+  // scaling_lut_u_ and scaling_lut_v_ are the same as scaling_lut_y_ and are
+  // set up as aliases. So we need to initialize scaling_lut_y_ under these
+  // two conditions.
+  //
+  // Note: Although it does not seem to make sense, there are test vectors
+  // with chroma_scaling_from_luma=true and params_.num_y_points=0.
+  if (use_luma || params_.chroma_scaling_from_luma) {
+    dsp.film_grain.initialize_scaling_lut(
+        params_.num_y_points, params_.point_y_value, params_.point_y_scaling,
+        scaling_lut_y_);
+  } else {
+    ASAN_POISON_MEMORY_REGION(scaling_lut_y_, sizeof(scaling_lut_y_));
+  }
+  if (!is_monochrome_) {
+    if (params_.chroma_scaling_from_luma) {
+      scaling_lut_u_ = scaling_lut_y_;
+      scaling_lut_v_ = scaling_lut_y_;
+    } else if (params_.num_u_points > 0 || params_.num_v_points > 0) {
+      const size_t buffer_size =
+          (kScalingLookupTableSize + kScalingLookupTablePadding) *
+          (static_cast<int>(params_.num_u_points > 0) +
+           static_cast<int>(params_.num_v_points > 0));
+      scaling_lut_chroma_buffer_.reset(new (std::nothrow) uint8_t[buffer_size]);
+      if (scaling_lut_chroma_buffer_ == nullptr) return false;
+
+      uint8_t* buffer = scaling_lut_chroma_buffer_.get();
+      if (params_.num_u_points > 0) {
+        scaling_lut_u_ = buffer;
+        dsp.film_grain.initialize_scaling_lut(
+            params_.num_u_points, params_.point_u_value,
+            params_.point_u_scaling, scaling_lut_u_);
+        buffer += kScalingLookupTableSize + kScalingLookupTablePadding;
+      }
+      if (params_.num_v_points > 0) {
+        scaling_lut_v_ = buffer;
+        dsp.film_grain.initialize_scaling_lut(
+            params_.num_v_points, params_.point_v_value,
+            params_.point_v_scaling, scaling_lut_v_);
+      }
+    }
+  }
+  return true;
+}
+
+template <int bitdepth>
+void FilmGrain<bitdepth>::GenerateLumaGrain(const FilmGrainParams& params,
+                                            GrainType* luma_grain) {
+  // If params.num_y_points is equal to 0, Section 7.18.3.3 specifies we set
+  // the luma_grain array to all zeros. But the Note at the end of Section
+  // 7.18.3.3 says luma_grain "will never be read in this case". So we don't
+  // call GenerateLumaGrain if params.num_y_points is equal to 0.
+  assert(params.num_y_points > 0);
+  const int shift = 12 - bitdepth + params.grain_scale_shift;
+  uint16_t seed = params.grain_seed;
+  GrainType* luma_grain_row = luma_grain;
+  for (int y = 0; y < kLumaHeight; ++y) {
+    for (int x = 0; x < kLumaWidth; ++x) {
+      luma_grain_row[x] = RightShiftWithRounding(
+          kGaussianSequence[GetFilmGrainRandomNumber(11, &seed)], shift);
+    }
+    luma_grain_row += kLumaWidth;
+  }
+}
+
+template <int bitdepth>
+void FilmGrain<bitdepth>::GenerateChromaGrains(const FilmGrainParams& params,
+                                               int chroma_width,
+                                               int chroma_height,
+                                               GrainType* u_grain,
+                                               GrainType* v_grain) {
+  const int shift = 12 - bitdepth + params.grain_scale_shift;
+  if (params.num_u_points == 0 && !params.chroma_scaling_from_luma) {
+    memset(u_grain, 0, chroma_height * chroma_width * sizeof(*u_grain));
+  } else {
+    uint16_t seed = params.grain_seed ^ 0xb524;
+    GrainType* u_grain_row = u_grain;
+    assert(chroma_width > 0);
+    assert(chroma_height > 0);
+    int y = 0;
+    do {
+      int x = 0;
+      do {
+        u_grain_row[x] = RightShiftWithRounding(
+            kGaussianSequence[GetFilmGrainRandomNumber(11, &seed)], shift);
+      } while (++x < chroma_width);
+
+      u_grain_row += chroma_width;
+    } while (++y < chroma_height);
+  }
+  if (params.num_v_points == 0 && !params.chroma_scaling_from_luma) {
+    memset(v_grain, 0, chroma_height * chroma_width * sizeof(*v_grain));
+  } else {
+    GrainType* v_grain_row = v_grain;
+    uint16_t seed = params.grain_seed ^ 0x49d8;
+    int y = 0;
+    do {
+      int x = 0;
+      do {
+        v_grain_row[x] = RightShiftWithRounding(
+            kGaussianSequence[GetFilmGrainRandomNumber(11, &seed)], shift);
+      } while (++x < chroma_width);
+
+      v_grain_row += chroma_width;
+    } while (++y < chroma_height);
+  }
+}
+
+template <int bitdepth>
+bool FilmGrain<bitdepth>::AllocateNoiseStripes() {
+  const int half_height = DivideBy2(height_ + 1);
+  assert(half_height > 0);
+  // ceil(half_height / 16.0)
+  const int max_luma_num = DivideBy16(half_height + 15);
+  constexpr int kNoiseStripeHeight = 34;
+  size_t noise_buffer_size = kNoiseStripePadding;
+  if (params_.num_y_points > 0) {
+    noise_buffer_size += max_luma_num * kNoiseStripeHeight * width_;
+  }
+  if (!is_monochrome_) {
+    noise_buffer_size += 2 * max_luma_num *
+                         (kNoiseStripeHeight >> subsampling_y_) *
+                         RightShiftWithRounding(width_, subsampling_x_);
+  }
+  noise_buffer_.reset(new (std::nothrow) GrainType[noise_buffer_size]);
+  if (noise_buffer_ == nullptr) return false;
+  GrainType* noise_buffer = noise_buffer_.get();
+  if (params_.num_y_points > 0) {
+    noise_stripes_[kPlaneY].Reset(max_luma_num, kNoiseStripeHeight * width_,
+                                  noise_buffer);
+    noise_buffer += max_luma_num * kNoiseStripeHeight * width_;
+  }
+  if (!is_monochrome_) {
+    noise_stripes_[kPlaneU].Reset(
+        max_luma_num,
+        (kNoiseStripeHeight >> subsampling_y_) *
+            RightShiftWithRounding(width_, subsampling_x_),
+        noise_buffer);
+    noise_buffer += max_luma_num * (kNoiseStripeHeight >> subsampling_y_) *
+                    RightShiftWithRounding(width_, subsampling_x_);
+    noise_stripes_[kPlaneV].Reset(
+        max_luma_num,
+        (kNoiseStripeHeight >> subsampling_y_) *
+            RightShiftWithRounding(width_, subsampling_x_),
+        noise_buffer);
+  }
+  return true;
+}
+
+template <int bitdepth>
+bool FilmGrain<bitdepth>::AllocateNoiseImage() {
+  if (params_.num_y_points > 0 &&
+      !noise_image_[kPlaneY].Reset(height_, width_ + kNoiseImagePadding,
+                                   /*zero_initialize=*/false)) {
+    return false;
+  }
+  if (!is_monochrome_) {
+    if (!noise_image_[kPlaneU].Reset(
+            (height_ + subsampling_y_) >> subsampling_y_,
+            ((width_ + subsampling_x_) >> subsampling_x_) + kNoiseImagePadding,
+            /*zero_initialize=*/false)) {
+      return false;
+    }
+    if (!noise_image_[kPlaneV].Reset(
+            (height_ + subsampling_y_) >> subsampling_y_,
+            ((width_ + subsampling_x_) >> subsampling_x_) + kNoiseImagePadding,
+            /*zero_initialize=*/false)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Uses |overlap_flag| to skip rows that are covered by the overlap computation.
+template <int bitdepth>
+void FilmGrain<bitdepth>::ConstructNoiseImage(
+    const Array2DView<GrainType>* noise_stripes, int width, int height,
+    int subsampling_x, int subsampling_y, int stripe_start_offset,
+    Array2D<GrainType>* noise_image) {
+  const int plane_width = (width + subsampling_x) >> subsampling_x;
+  const int plane_height = (height + subsampling_y) >> subsampling_y;
+  const int stripe_height = 32 >> subsampling_y;
+  const int stripe_mask = stripe_height - 1;
+  int y = 0;
+  // |luma_num| = y >> (5 - |subsampling_y|). Hence |luma_num| == 0 for all y up
+  // to either 16 or 32.
+  const GrainType* first_noise_stripe = (*noise_stripes)[0];
+  do {
+    memcpy((*noise_image)[y], first_noise_stripe + y * plane_width,
+           plane_width * sizeof(first_noise_stripe[0]));
+  } while (++y < std::min(stripe_height, plane_height));
+  // End special iterations for luma_num == 0.
+
+  int luma_num = 1;
+  for (; y < (plane_height & ~stripe_mask); ++luma_num, y += stripe_height) {
+    const GrainType* noise_stripe = (*noise_stripes)[luma_num];
+    int i = stripe_start_offset;
+    do {
+      memcpy((*noise_image)[y + i], noise_stripe + i * plane_width,
+             plane_width * sizeof(noise_stripe[0]));
+    } while (++i < stripe_height);
+  }
+
+  // If there is a partial stripe, copy any rows beyond the overlap rows.
+  const int remaining_height = plane_height - y;
+  if (remaining_height > stripe_start_offset) {
+    assert(luma_num < noise_stripes->rows());
+    const GrainType* noise_stripe = (*noise_stripes)[luma_num];
+    int i = stripe_start_offset;
+    do {
+      memcpy((*noise_image)[y + i], noise_stripe + i * plane_width,
+             plane_width * sizeof(noise_stripe[0]));
+    } while (++i < remaining_height);
+  }
+}
+
+template <int bitdepth>
+void FilmGrain<bitdepth>::BlendNoiseChromaWorker(
+    const dsp::Dsp& dsp, const Plane* planes, int num_planes,
+    std::atomic<int>* job_counter, int min_value, int max_chroma,
+    const uint8_t* source_plane_y, ptrdiff_t source_stride_y,
+    const uint8_t* source_plane_u, const uint8_t* source_plane_v,
+    ptrdiff_t source_stride_uv, uint8_t* dest_plane_u, uint8_t* dest_plane_v,
+    ptrdiff_t dest_stride_uv) {
+  assert(num_planes > 0);
+  const int full_jobs_per_plane = height_ / kFrameChunkHeight;
+  const int remainder_job_height = height_ & (kFrameChunkHeight - 1);
+  const int total_full_jobs = full_jobs_per_plane * num_planes;
+  // If the frame height is not a multiple of kFrameChunkHeight, one job with
+  // a smaller number of rows is necessary at the end of each plane.
+  const int total_jobs =
+      total_full_jobs + ((remainder_job_height == 0) ? 0 : num_planes);
+  int job_index;
+  // Each job corresponds to a slice of kFrameChunkHeight rows in the luma
+  // plane. dsp->blend_noise_chroma handles subsampling.
+  // This loop body handles a slice of one plane or the other, depending on
+  // which are active. That way, threads working on consecutive jobs will keep
+  // the same region of luma source in working memory.
+  while ((job_index = job_counter->fetch_add(1, std::memory_order_relaxed)) <
+         total_jobs) {
+    const Plane plane = planes[job_index % num_planes];
+    const int slice_index = job_index / num_planes;
+    const int start_height = slice_index * kFrameChunkHeight;
+    const int job_height = std::min(height_ - start_height, kFrameChunkHeight);
+
+    const auto* source_cursor_y = reinterpret_cast<const Pixel*>(
+        source_plane_y + start_height * source_stride_y);
+    const uint8_t* scaling_lut_uv;
+    const uint8_t* source_plane_uv;
+    uint8_t* dest_plane_uv;
+
+    if (plane == kPlaneU) {
+      scaling_lut_uv = scaling_lut_u_;
+      source_plane_uv = source_plane_u;
+      dest_plane_uv = dest_plane_u;
+    } else {
+      assert(plane == kPlaneV);
+      scaling_lut_uv = scaling_lut_v_;
+      source_plane_uv = source_plane_v;
+      dest_plane_uv = dest_plane_v;
+    }
+    const auto* source_cursor_uv = reinterpret_cast<const Pixel*>(
+        source_plane_uv + (start_height >> subsampling_y_) * source_stride_uv);
+    auto* dest_cursor_uv = reinterpret_cast<Pixel*>(
+        dest_plane_uv + (start_height >> subsampling_y_) * dest_stride_uv);
+    dsp.film_grain.blend_noise_chroma[params_.chroma_scaling_from_luma](
+        plane, params_, noise_image_, min_value, max_chroma, width_, job_height,
+        start_height, subsampling_x_, subsampling_y_, scaling_lut_uv,
+        source_cursor_y, source_stride_y, source_cursor_uv, source_stride_uv,
+        dest_cursor_uv, dest_stride_uv);
+  }
+}
+
+template <int bitdepth>
+void FilmGrain<bitdepth>::BlendNoiseLumaWorker(
+    const dsp::Dsp& dsp, std::atomic<int>* job_counter, int min_value,
+    int max_luma, const uint8_t* source_plane_y, ptrdiff_t source_stride_y,
+    uint8_t* dest_plane_y, ptrdiff_t dest_stride_y) {
+  const int total_full_jobs = height_ / kFrameChunkHeight;
+  const int remainder_job_height = height_ & (kFrameChunkHeight - 1);
+  const int total_jobs =
+      total_full_jobs + static_cast<int>(remainder_job_height > 0);
+  int job_index;
+  // Each job is some number of rows in a plane.
+  while ((job_index = job_counter->fetch_add(1, std::memory_order_relaxed)) <
+         total_jobs) {
+    const int start_height = job_index * kFrameChunkHeight;
+    const int job_height = std::min(height_ - start_height, kFrameChunkHeight);
+
+    const auto* source_cursor_y = reinterpret_cast<const Pixel*>(
+        source_plane_y + start_height * source_stride_y);
+    auto* dest_cursor_y =
+        reinterpret_cast<Pixel*>(dest_plane_y + start_height * dest_stride_y);
+    dsp.film_grain.blend_noise_luma(
+        noise_image_, min_value, max_luma, params_.chroma_scaling, width_,
+        job_height, start_height, scaling_lut_y_, source_cursor_y,
+        source_stride_y, dest_cursor_y, dest_stride_y);
+  }
+}
+
+template <int bitdepth>
+bool FilmGrain<bitdepth>::AddNoise(
+    const uint8_t* source_plane_y, ptrdiff_t source_stride_y,
+    const uint8_t* source_plane_u, const uint8_t* source_plane_v,
+    ptrdiff_t source_stride_uv, uint8_t* dest_plane_y, ptrdiff_t dest_stride_y,
+    uint8_t* dest_plane_u, uint8_t* dest_plane_v, ptrdiff_t dest_stride_uv) {
+  if (!Init()) {
+    LIBGAV1_DLOG(ERROR, "Init() failed.");
+    return false;
+  }
+  if (!AllocateNoiseStripes()) {
+    LIBGAV1_DLOG(ERROR, "AllocateNoiseStripes() failed.");
+    return false;
+  }
+
+  const dsp::Dsp& dsp = *dsp::GetDspTable(bitdepth);
+  const bool use_luma = params_.num_y_points > 0;
+
+  // Construct noise stripes.
+  if (use_luma) {
+    // The luma plane is never subsampled.
+    dsp.film_grain
+        .construct_noise_stripes[static_cast<int>(params_.overlap_flag)](
+            luma_grain_, params_.grain_seed, width_, height_,
+            /*subsampling_x=*/0, /*subsampling_y=*/0, &noise_stripes_[kPlaneY]);
+  }
+  if (!is_monochrome_) {
+    dsp.film_grain
+        .construct_noise_stripes[static_cast<int>(params_.overlap_flag)](
+            u_grain_, params_.grain_seed, width_, height_, subsampling_x_,
+            subsampling_y_, &noise_stripes_[kPlaneU]);
+    dsp.film_grain
+        .construct_noise_stripes[static_cast<int>(params_.overlap_flag)](
+            v_grain_, params_.grain_seed, width_, height_, subsampling_x_,
+            subsampling_y_, &noise_stripes_[kPlaneV]);
+  }
+
+  if (!AllocateNoiseImage()) {
+    LIBGAV1_DLOG(ERROR, "AllocateNoiseImage() failed.");
+    return false;
+  }
+
+  // Construct noise image.
+  if (use_luma) {
+    ConstructNoiseImage(
+        &noise_stripes_[kPlaneY], width_, height_, /*subsampling_x=*/0,
+        /*subsampling_y=*/0, static_cast<int>(params_.overlap_flag) << 1,
+        &noise_image_[kPlaneY]);
+    if (params_.overlap_flag) {
+      dsp.film_grain.construct_noise_image_overlap(
+          &noise_stripes_[kPlaneY], width_, height_, /*subsampling_x=*/0,
+          /*subsampling_y=*/0, &noise_image_[kPlaneY]);
+    }
+  }
+  if (!is_monochrome_) {
+    ConstructNoiseImage(&noise_stripes_[kPlaneU], width_, height_,
+                        subsampling_x_, subsampling_y_,
+                        static_cast<int>(params_.overlap_flag)
+                            << (1 - subsampling_y_),
+                        &noise_image_[kPlaneU]);
+    ConstructNoiseImage(&noise_stripes_[kPlaneV], width_, height_,
+                        subsampling_x_, subsampling_y_,
+                        static_cast<int>(params_.overlap_flag)
+                            << (1 - subsampling_y_),
+                        &noise_image_[kPlaneV]);
+    if (params_.overlap_flag) {
+      dsp.film_grain.construct_noise_image_overlap(
+          &noise_stripes_[kPlaneU], width_, height_, subsampling_x_,
+          subsampling_y_, &noise_image_[kPlaneU]);
+      dsp.film_grain.construct_noise_image_overlap(
+          &noise_stripes_[kPlaneV], width_, height_, subsampling_x_,
+          subsampling_y_, &noise_image_[kPlaneV]);
+    }
+  }
+
+  // Blend noise image.
+  int min_value;
+  int max_luma;
+  int max_chroma;
+  if (params_.clip_to_restricted_range) {
+    min_value = 16 << (bitdepth - 8);
+    max_luma = 235 << (bitdepth - 8);
+    if (color_matrix_is_identity_) {
+      max_chroma = max_luma;
+    } else {
+      max_chroma = 240 << (bitdepth - 8);
+    }
+  } else {
+    min_value = 0;
+    max_luma = (256 << (bitdepth - 8)) - 1;
+    max_chroma = max_luma;
+  }
+
+  // Handle all chroma planes first because luma source may be altered in place.
+  if (!is_monochrome_) {
+    // This is done in a strange way but Vector can't be passed by copy to the
+    // lambda capture that spawns the thread.
+    Plane planes_to_blend[2];
+    int num_planes = 0;
+    if (params_.chroma_scaling_from_luma) {
+      // Both noise planes are computed from the luma scaling lookup table.
+      planes_to_blend[num_planes++] = kPlaneU;
+      planes_to_blend[num_planes++] = kPlaneV;
+    } else {
+      const int height_uv = RightShiftWithRounding(height_, subsampling_y_);
+      const int width_uv = RightShiftWithRounding(width_, subsampling_x_);
+
+      // Noise is applied according to a lookup table defined by pieceiwse
+      // linear "points." If the lookup table is empty, that corresponds to
+      // outputting zero noise.
+      if (params_.num_u_points == 0) {
+        CopyImagePlane<Pixel>(source_plane_u, source_stride_uv, width_uv,
+                              height_uv, dest_plane_u, dest_stride_uv);
+      } else {
+        planes_to_blend[num_planes++] = kPlaneU;
+      }
+      if (params_.num_v_points == 0) {
+        CopyImagePlane<Pixel>(source_plane_v, source_stride_uv, width_uv,
+                              height_uv, dest_plane_v, dest_stride_uv);
+      } else {
+        planes_to_blend[num_planes++] = kPlaneV;
+      }
+    }
+    if (thread_pool_ != nullptr && num_planes > 0) {
+      const int num_workers = thread_pool_->num_threads();
+      BlockingCounter pending_workers(num_workers);
+      std::atomic<int> job_counter(0);
+      for (int i = 0; i < num_workers; ++i) {
+        thread_pool_->Schedule([this, dsp, &pending_workers, &planes_to_blend,
+                                num_planes, &job_counter, min_value, max_chroma,
+                                source_plane_y, source_stride_y, source_plane_u,
+                                source_plane_v, source_stride_uv, dest_plane_u,
+                                dest_plane_v, dest_stride_uv]() {
+          BlendNoiseChromaWorker(dsp, planes_to_blend, num_planes, &job_counter,
+                                 min_value, max_chroma, source_plane_y,
+                                 source_stride_y, source_plane_u,
+                                 source_plane_v, source_stride_uv, dest_plane_u,
+                                 dest_plane_v, dest_stride_uv);
+          pending_workers.Decrement();
+        });
+      }
+      BlendNoiseChromaWorker(
+          dsp, planes_to_blend, num_planes, &job_counter, min_value, max_chroma,
+          source_plane_y, source_stride_y, source_plane_u, source_plane_v,
+          source_stride_uv, dest_plane_u, dest_plane_v, dest_stride_uv);
+
+      pending_workers.Wait();
+    } else {
+      // Single threaded.
+      if (params_.num_u_points > 0 || params_.chroma_scaling_from_luma) {
+        dsp.film_grain.blend_noise_chroma[params_.chroma_scaling_from_luma](
+            kPlaneU, params_, noise_image_, min_value, max_chroma, width_,
+            height_, /*start_height=*/0, subsampling_x_, subsampling_y_,
+            scaling_lut_u_, source_plane_y, source_stride_y, source_plane_u,
+            source_stride_uv, dest_plane_u, dest_stride_uv);
+      }
+      if (params_.num_v_points > 0 || params_.chroma_scaling_from_luma) {
+        dsp.film_grain.blend_noise_chroma[params_.chroma_scaling_from_luma](
+            kPlaneV, params_, noise_image_, min_value, max_chroma, width_,
+            height_, /*start_height=*/0, subsampling_x_, subsampling_y_,
+            scaling_lut_v_, source_plane_y, source_stride_y, source_plane_v,
+            source_stride_uv, dest_plane_v, dest_stride_uv);
+      }
+    }
+  }
+  if (use_luma) {
+    if (thread_pool_ != nullptr) {
+      const int num_workers = thread_pool_->num_threads();
+      BlockingCounter pending_workers(num_workers);
+      std::atomic<int> job_counter(0);
+      for (int i = 0; i < num_workers; ++i) {
+        thread_pool_->Schedule(
+            [this, dsp, &pending_workers, &job_counter, min_value, max_luma,
+             source_plane_y, source_stride_y, dest_plane_y, dest_stride_y]() {
+              BlendNoiseLumaWorker(dsp, &job_counter, min_value, max_luma,
+                                   source_plane_y, source_stride_y,
+                                   dest_plane_y, dest_stride_y);
+              pending_workers.Decrement();
+            });
+      }
+
+      BlendNoiseLumaWorker(dsp, &job_counter, min_value, max_luma,
+                           source_plane_y, source_stride_y, dest_plane_y,
+                           dest_stride_y);
+      pending_workers.Wait();
+    } else {
+      dsp.film_grain.blend_noise_luma(
+          noise_image_, min_value, max_luma, params_.chroma_scaling, width_,
+          height_, /*start_height=*/0, scaling_lut_y_, source_plane_y,
+          source_stride_y, dest_plane_y, dest_stride_y);
+    }
+  } else {
+    CopyImagePlane<Pixel>(source_plane_y, source_stride_y, width_, height_,
+                          dest_plane_y, dest_stride_y);
+  }
+
+  return true;
+}
+
+// Explicit instantiations.
+template class FilmGrain<8>;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+template class FilmGrain<10>;
+#endif
+
+}  // namespace libgav1
diff --git a/libgav1/src/film_grain.h b/libgav1/src/film_grain.h
new file mode 100644
index 0000000..6757214
--- /dev/null
+++ b/libgav1/src/film_grain.h
@@ -0,0 +1,193 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_FILM_GRAIN_H_
+#define LIBGAV1_SRC_FILM_GRAIN_H_
+
+#include <atomic>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <type_traits>
+
+#include "src/dsp/common.h"
+#include "src/dsp/dsp.h"
+#include "src/dsp/film_grain_common.h"
+#include "src/utils/array_2d.h"
+#include "src/utils/constants.h"
+#include "src/utils/cpu.h"
+#include "src/utils/threadpool.h"
+#include "src/utils/types.h"
+#include "src/utils/vector.h"
+
+namespace libgav1 {
+
+// Film grain synthesis function signature. Section 7.18.3.
+// This function generates film grain noise and blends the noise with the
+// decoded frame.
+// |source_plane_y|, |source_plane_u|, and |source_plane_v| are the plane
+// buffers of the decoded frame. They are blended with the film grain noise and
+// written to |dest_plane_y|, |dest_plane_u|, and |dest_plane_v| as final
+// output for display. |source_plane_p| and |dest_plane_p| (where p is y, u, or
+// v) may point to the same buffer, in which case the film grain noise is added
+// in place.
+// |film_grain_params| are parameters read from frame header.
+// |is_monochrome| is true indicates only Y plane needs to be processed.
+// |color_matrix_is_identity| is true if the matrix_coefficients field in the
+// sequence header's color config is is MC_IDENTITY.
+// |width| is the upscaled width of the frame.
+// |height| is the frame height.
+// |subsampling_x| and |subsampling_y| are subsamplings for UV planes, not used
+// if |is_monochrome| is true.
+// Returns true on success, or false on failure (e.g., out of memory).
+using FilmGrainSynthesisFunc = bool (*)(
+    const void* source_plane_y, ptrdiff_t source_stride_y,
+    const void* source_plane_u, ptrdiff_t source_stride_u,
+    const void* source_plane_v, ptrdiff_t source_stride_v,
+    const FilmGrainParams& film_grain_params, bool is_monochrome,
+    bool color_matrix_is_identity, int width, int height, int subsampling_x,
+    int subsampling_y, void* dest_plane_y, ptrdiff_t dest_stride_y,
+    void* dest_plane_u, ptrdiff_t dest_stride_u, void* dest_plane_v,
+    ptrdiff_t dest_stride_v);
+
+// Section 7.18.3.5. Add noise synthesis process.
+template <int bitdepth>
+class FilmGrain {
+ public:
+  using GrainType =
+      typename std::conditional<bitdepth == 8, int8_t, int16_t>::type;
+
+  FilmGrain(const FilmGrainParams& params, bool is_monochrome,
+            bool color_matrix_is_identity, int subsampling_x, int subsampling_y,
+            int width, int height, ThreadPool* thread_pool);
+
+  // Note: These static methods are declared public so that the unit tests can
+  // call them.
+
+  static void GenerateLumaGrain(const FilmGrainParams& params,
+                                GrainType* luma_grain);
+
+  // Generates white noise arrays u_grain and v_grain chroma_width samples wide
+  // and chroma_height samples high.
+  static void GenerateChromaGrains(const FilmGrainParams& params,
+                                   int chroma_width, int chroma_height,
+                                   GrainType* u_grain, GrainType* v_grain);
+
+  // Copies rows from |noise_stripes| to |noise_image|, skipping rows that are
+  // subject to overlap.
+  static void ConstructNoiseImage(const Array2DView<GrainType>* noise_stripes,
+                                  int width, int height, int subsampling_x,
+                                  int subsampling_y, int stripe_start_offset,
+                                  Array2D<GrainType>* noise_image);
+
+  // Combines the film grain with the image data.
+  bool AddNoise(const uint8_t* source_plane_y, ptrdiff_t source_stride_y,
+                const uint8_t* source_plane_u, const uint8_t* source_plane_v,
+                ptrdiff_t source_stride_uv, uint8_t* dest_plane_y,
+                ptrdiff_t dest_stride_y, uint8_t* dest_plane_u,
+                uint8_t* dest_plane_v, ptrdiff_t dest_stride_uv);
+
+ private:
+  using Pixel =
+      typename std::conditional<bitdepth == 8, uint8_t, uint16_t>::type;
+
+  bool Init();
+
+  // Allocates noise_stripes_.
+  bool AllocateNoiseStripes();
+
+  bool AllocateNoiseImage();
+
+  void BlendNoiseChromaWorker(const dsp::Dsp& dsp, const Plane* planes,
+                              int num_planes, std::atomic<int>* job_counter,
+                              int min_value, int max_chroma,
+                              const uint8_t* source_plane_y,
+                              ptrdiff_t source_stride_y,
+                              const uint8_t* source_plane_u,
+                              const uint8_t* source_plane_v,
+                              ptrdiff_t source_stride_uv, uint8_t* dest_plane_u,
+                              uint8_t* dest_plane_v, ptrdiff_t dest_stride_uv);
+
+  void BlendNoiseLumaWorker(const dsp::Dsp& dsp, std::atomic<int>* job_counter,
+                            int min_value, int max_luma,
+                            const uint8_t* source_plane_y,
+                            ptrdiff_t source_stride_y, uint8_t* dest_plane_y,
+                            ptrdiff_t dest_stride_y);
+
+  const FilmGrainParams& params_;
+  const bool is_monochrome_;
+  const bool color_matrix_is_identity_;
+  const int subsampling_x_;
+  const int subsampling_y_;
+  // Frame width and height.
+  const int width_;
+  const int height_;
+  // Section 7.18.3.3, Dimensions of the noise templates for chroma, which are
+  // known as CbGrain and CrGrain.
+  // These templates are used to construct the noise image for each plane by
+  // copying 32x32 blocks with pseudorandom offsets, into "noise stripes."
+  // The noise template known as LumaGrain array is an 82x73 block.
+  // The height and width of the templates for chroma become 44 and 38 under
+  // subsampling, respectively.
+  //  For more details see:
+  // A. Norkin and N. Birkbeck, "Film Grain Synthesis for AV1 Video Codec," 2018
+  // Data Compression Conference, Snowbird, UT, 2018, pp. 3-12.
+  const int template_uv_width_;
+  const int template_uv_height_;
+  // LumaGrain. The luma_grain array contains white noise generated for luma.
+  // The array size is fixed but subject to further optimization for SIMD.
+  GrainType luma_grain_[kLumaHeight * kLumaWidth];
+  // CbGrain and CrGrain. The maximum size of the u_grain and v_grain arrays is
+  // kMaxChromaHeight * kMaxChromaWidth. The actual size is
+  // template_uv_height_ * template_uv_width_.
+  GrainType u_grain_[kMaxChromaHeight * kMaxChromaWidth];
+  GrainType v_grain_[kMaxChromaHeight * kMaxChromaWidth];
+  // Scaling lookup tables.
+  uint8_t scaling_lut_y_[kScalingLookupTableSize + kScalingLookupTablePadding];
+  uint8_t* scaling_lut_u_ = nullptr;
+  uint8_t* scaling_lut_v_ = nullptr;
+  // If allocated, this buffer is 256 * 2 bytes long and scaling_lut_u_ and
+  // scaling_lut_v_ point into this buffer. Otherwise, scaling_lut_u_ and
+  // scaling_lut_v_ point to scaling_lut_y_.
+  std::unique_ptr<uint8_t[]> scaling_lut_chroma_buffer_;
+
+  // A two-dimensional array of noise data for each plane. Generated for each 32
+  // luma sample high stripe of the image. The first dimension is called
+  // luma_num. The second dimension is the size of one noise stripe.
+  //
+  // Each row of the Array2DView noise_stripes_[plane] is a conceptually
+  // two-dimensional array of |GrainType|s. The two-dimensional array of
+  // |GrainType|s is flattened into a one-dimensional buffer in this
+  // implementation.
+  //
+  // noise_stripes_[kPlaneY][luma_num] is an array that has 34 rows and
+  // |width_| columns and contains noise for the luma component.
+  //
+  // noise_stripes_[kPlaneU][luma_num] or noise_stripes_[kPlaneV][luma_num]
+  // is an array that has (34 >> subsampling_y_) rows and
+  // RightShiftWithRounding(width_, subsampling_x_) columns and contains noise
+  // for the chroma components.
+  Array2DView<GrainType> noise_stripes_[kMaxPlanes];
+  // Owns the memory that the elements of noise_stripes_ point to.
+  std::unique_ptr<GrainType[]> noise_buffer_;
+
+  Array2D<GrainType> noise_image_[kMaxPlanes];
+  ThreadPool* const thread_pool_;
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_FILM_GRAIN_H_
diff --git a/libgav1/src/frame_buffer.cc b/libgav1/src/frame_buffer.cc
new file mode 100644
index 0000000..50c7756
--- /dev/null
+++ b/libgav1/src/frame_buffer.cc
@@ -0,0 +1,151 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/gav1/frame_buffer.h"
+
+#include <cstdint>
+
+#include "src/frame_buffer_utils.h"
+#include "src/utils/common.h"
+
+extern "C" {
+
+Libgav1StatusCode Libgav1ComputeFrameBufferInfo(
+    int bitdepth, Libgav1ImageFormat image_format, int width, int height,
+    int left_border, int right_border, int top_border, int bottom_border,
+    int stride_alignment, Libgav1FrameBufferInfo* info) {
+  switch (bitdepth) {
+    case 8:
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    case 10:
+#endif
+#if LIBGAV1_MAX_BITDEPTH == 12
+    case 12:
+#endif
+      break;
+    default:
+      return kLibgav1StatusInvalidArgument;
+  }
+  switch (image_format) {
+    case kLibgav1ImageFormatYuv420:
+    case kLibgav1ImageFormatYuv422:
+    case kLibgav1ImageFormatYuv444:
+    case kLibgav1ImageFormatMonochrome400:
+      break;
+    default:
+      return kLibgav1StatusInvalidArgument;
+  }
+  // All int arguments must be nonnegative. Borders must be a multiple of 2.
+  // |stride_alignment| must be a power of 2.
+  if ((width | height | left_border | right_border | top_border |
+       bottom_border | stride_alignment) < 0 ||
+      ((left_border | right_border | top_border | bottom_border) & 1) != 0 ||
+      (stride_alignment & (stride_alignment - 1)) != 0 || info == nullptr) {
+    return kLibgav1StatusInvalidArgument;
+  }
+
+  bool is_monochrome;
+  int8_t subsampling_x;
+  int8_t subsampling_y;
+  libgav1::DecomposeImageFormat(image_format, &is_monochrome, &subsampling_x,
+                                &subsampling_y);
+
+  // Calculate y_stride (in bytes). It is padded to a multiple of
+  // |stride_alignment| bytes.
+  int y_stride = width + left_border + right_border;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  if (bitdepth > 8) y_stride *= sizeof(uint16_t);
+#endif
+  y_stride = libgav1::Align(y_stride, stride_alignment);
+  // Size of the Y buffer in bytes.
+  const uint64_t y_buffer_size =
+      (height + top_border + bottom_border) * static_cast<uint64_t>(y_stride) +
+      (stride_alignment - 1);
+
+  const int uv_width =
+      is_monochrome ? 0 : libgav1::SubsampledValue(width, subsampling_x);
+  const int uv_height =
+      is_monochrome ? 0 : libgav1::SubsampledValue(height, subsampling_y);
+  const int uv_left_border = is_monochrome ? 0 : left_border >> subsampling_x;
+  const int uv_right_border = is_monochrome ? 0 : right_border >> subsampling_x;
+  const int uv_top_border = is_monochrome ? 0 : top_border >> subsampling_y;
+  const int uv_bottom_border =
+      is_monochrome ? 0 : bottom_border >> subsampling_y;
+
+  // Calculate uv_stride (in bytes). It is padded to a multiple of
+  // |stride_alignment| bytes.
+  int uv_stride = uv_width + uv_left_border + uv_right_border;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  if (bitdepth > 8) uv_stride *= sizeof(uint16_t);
+#endif
+  uv_stride = libgav1::Align(uv_stride, stride_alignment);
+  // Size of the U or V buffer in bytes.
+  const uint64_t uv_buffer_size =
+      is_monochrome ? 0
+                    : (uv_height + uv_top_border + uv_bottom_border) *
+                              static_cast<uint64_t>(uv_stride) +
+                          (stride_alignment - 1);
+
+  // Check if it is safe to cast y_buffer_size and uv_buffer_size to size_t.
+  if (y_buffer_size > SIZE_MAX || uv_buffer_size > SIZE_MAX) {
+    return kLibgav1StatusInvalidArgument;
+  }
+
+  int left_border_bytes = left_border;
+  int uv_left_border_bytes = uv_left_border;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  if (bitdepth > 8) {
+    left_border_bytes *= sizeof(uint16_t);
+    uv_left_border_bytes *= sizeof(uint16_t);
+  }
+#endif
+
+  info->y_stride = y_stride;
+  info->uv_stride = uv_stride;
+  info->y_buffer_size = static_cast<size_t>(y_buffer_size);
+  info->uv_buffer_size = static_cast<size_t>(uv_buffer_size);
+  info->y_plane_offset = top_border * y_stride + left_border_bytes;
+  info->uv_plane_offset = uv_top_border * uv_stride + uv_left_border_bytes;
+  info->stride_alignment = stride_alignment;
+  return kLibgav1StatusOk;
+}
+
+Libgav1StatusCode Libgav1SetFrameBuffer(const Libgav1FrameBufferInfo* info,
+                                        uint8_t* y_buffer, uint8_t* u_buffer,
+                                        uint8_t* v_buffer,
+                                        void* buffer_private_data,
+                                        Libgav1FrameBuffer* frame_buffer) {
+  if (info == nullptr ||
+      (info->uv_buffer_size == 0 &&
+       (u_buffer != nullptr || v_buffer != nullptr)) ||
+      frame_buffer == nullptr) {
+    return kLibgav1StatusInvalidArgument;
+  }
+  if (y_buffer == nullptr || (info->uv_buffer_size != 0 &&
+                              (u_buffer == nullptr || v_buffer == nullptr))) {
+    return kLibgav1StatusOutOfMemory;
+  }
+  frame_buffer->plane[0] = libgav1::AlignAddr(y_buffer + info->y_plane_offset,
+                                              info->stride_alignment);
+  frame_buffer->plane[1] = libgav1::AlignAddr(u_buffer + info->uv_plane_offset,
+                                              info->stride_alignment);
+  frame_buffer->plane[2] = libgav1::AlignAddr(v_buffer + info->uv_plane_offset,
+                                              info->stride_alignment);
+  frame_buffer->stride[0] = info->y_stride;
+  frame_buffer->stride[1] = frame_buffer->stride[2] = info->uv_stride;
+  frame_buffer->private_data = buffer_private_data;
+  return kLibgav1StatusOk;
+}
+
+}  // extern "C"
diff --git a/libgav1/src/frame_buffer.h b/libgav1/src/frame_buffer.h
deleted file mode 100644
index 8eadccb..0000000
--- a/libgav1/src/frame_buffer.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright 2019 The libgav1 Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBGAV1_SRC_FRAME_BUFFER_H_
-#define LIBGAV1_SRC_FRAME_BUFFER_H_
-
-// All the declarations in this file are part of the public ABI. This file may
-// be included by both C and C++ files.
-
-#include <stddef.h>
-#include <stdint.h>
-
-// The callback functions use the C linkage conventions.
-#if defined(__cplusplus)
-extern "C" {
-#endif
-
-// This structure represents an allocated frame buffer.
-struct Libgav1FrameBuffer {
-  // In the |data| and |size| arrays, the elements at indexes 0, 1, and 2 are
-  // for the Y, U, and V planes, respectively.
-  uint8_t* data[3];    // Pointers to the data buffers.
-  size_t size[3];      // Sizes of the data buffers in bytes.
-  void* private_data;  // Frame buffer's private data. Available for use by the
-                       // release frame buffer callback. Also copied to the
-                       // |buffer_private_data| field of DecoderBuffer for use
-                       // by the consumer of a DecoderBuffer.
-};
-
-// This callback is invoked by the decoder to allocate a frame buffer, which
-// consists of three data buffers, for the Y, U, and V planes, respectively.
-// |y_plane_min_size| specifies the minimum size in bytes of the Y plane data
-// buffer, and |uv_plane_min_size| specifies the minimum size in bytes of the
-// U and V plane data buffers.
-//
-// The callback must set |frame_buffer->data[i]| to point to the data buffers,
-// and set |frame_buffer->size[i]| to the actual sizes of the data buffers. The
-// callback may set |frame_buffer->private_data| to a value that will be useful
-// to the release frame buffer callback and the consumer of a DecoderBuffer.
-//
-// Returns 0 on success, -1 on failure.
-typedef int (*Libgav1GetFrameBufferCallback)(void* private_data,
-                                             size_t y_plane_min_size,
-                                             size_t uv_plane_min_size,
-                                             Libgav1FrameBuffer* frame_buffer);
-
-// This callback is invoked by the decoder to release a frame buffer.
-//
-// Returns 0 on success, -1 on failure.
-typedef int (*Libgav1ReleaseFrameBufferCallback)(
-    void* private_data, Libgav1FrameBuffer* frame_buffer);
-
-#if defined(__cplusplus)
-}  // extern "C"
-
-// Declare type aliases for C++.
-namespace libgav1 {
-
-using FrameBuffer = Libgav1FrameBuffer;
-using GetFrameBufferCallback = Libgav1GetFrameBufferCallback;
-using ReleaseFrameBufferCallback = Libgav1ReleaseFrameBufferCallback;
-
-}  // namespace libgav1
-#endif  // defined(__cplusplus)
-
-#endif  // LIBGAV1_SRC_FRAME_BUFFER_H_
diff --git a/libgav1/src/frame_buffer_utils.h b/libgav1/src/frame_buffer_utils.h
new file mode 100644
index 0000000..d41437e
--- /dev/null
+++ b/libgav1/src/frame_buffer_utils.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_FRAME_BUFFER_UTILS_H_
+#define LIBGAV1_SRC_FRAME_BUFFER_UTILS_H_
+
+#include <cassert>
+#include <cstdint>
+
+#include "src/gav1/decoder_buffer.h"
+
+namespace libgav1 {
+
+// The following table is from Section 6.4.2 of the spec.
+//
+// subsampling_x  subsampling_y  mono_chrome  Description
+// -----------------------------------------------------------
+// 0              0              0            YUV 4:4:4
+// 1              0              0            YUV 4:2:2
+// 1              1              0            YUV 4:2:0
+// 1              1              1            Monochrome 4:0:0
+
+inline Libgav1ImageFormat ComposeImageFormat(bool is_monochrome,
+                                             int8_t subsampling_x,
+                                             int8_t subsampling_y) {
+  Libgav1ImageFormat image_format;
+  if (subsampling_x == 0) {
+    assert(subsampling_y == 0 && !is_monochrome);
+    image_format = kLibgav1ImageFormatYuv444;
+  } else if (subsampling_y == 0) {
+    assert(!is_monochrome);
+    image_format = kLibgav1ImageFormatYuv422;
+  } else if (!is_monochrome) {
+    image_format = kLibgav1ImageFormatYuv420;
+  } else {
+    image_format = kLibgav1ImageFormatMonochrome400;
+  }
+  return image_format;
+}
+
+inline void DecomposeImageFormat(Libgav1ImageFormat image_format,
+                                 bool* is_monochrome, int8_t* subsampling_x,
+                                 int8_t* subsampling_y) {
+  *is_monochrome = false;
+  *subsampling_x = 1;
+  *subsampling_y = 1;
+  switch (image_format) {
+    case kLibgav1ImageFormatYuv420:
+      break;
+    case kLibgav1ImageFormatYuv422:
+      *subsampling_y = 0;
+      break;
+    case kLibgav1ImageFormatYuv444:
+      *subsampling_x = *subsampling_y = 0;
+      break;
+    default:
+      assert(image_format == kLibgav1ImageFormatMonochrome400);
+      *is_monochrome = true;
+      break;
+  }
+}
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_FRAME_BUFFER_UTILS_H_
diff --git a/libgav1/src/frame_scratch_buffer.h b/libgav1/src/frame_scratch_buffer.h
new file mode 100644
index 0000000..1d6a1f4
--- /dev/null
+++ b/libgav1/src/frame_scratch_buffer.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_FRAME_SCRATCH_BUFFER_H_
+#define LIBGAV1_SRC_FRAME_SCRATCH_BUFFER_H_
+
+#include <condition_variable>  // NOLINT (unapproved c++11 header)
+#include <cstdint>
+#include <memory>
+#include <mutex>  // NOLINT (unapproved c++11 header)
+
+#include "src/loop_restoration_info.h"
+#include "src/residual_buffer_pool.h"
+#include "src/symbol_decoder_context.h"
+#include "src/threading_strategy.h"
+#include "src/tile_scratch_buffer.h"
+#include "src/utils/array_2d.h"
+#include "src/utils/block_parameters_holder.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+#include "src/utils/dynamic_buffer.h"
+#include "src/utils/memory.h"
+#include "src/utils/stack.h"
+#include "src/utils/types.h"
+#include "src/yuv_buffer.h"
+
+namespace libgav1 {
+
+// Buffer used to store the unfiltered pixels that are necessary for decoding
+// the next superblock row (for the intra prediction process).
+using IntraPredictionBuffer =
+    std::array<AlignedDynamicBuffer<uint8_t, kMaxAlignment>, kMaxPlanes>;
+
+// Buffer to facilitate decoding a frame. This struct is used only within
+// DecoderImpl::DecodeTiles().
+struct FrameScratchBuffer {
+  LoopRestorationInfo loop_restoration_info;
+  Array2D<int16_t> cdef_index;
+  Array2D<TransformSize> inter_transform_sizes;
+  BlockParametersHolder block_parameters_holder;
+  TemporalMotionField motion_field;
+  SymbolDecoderContext symbol_decoder_context;
+  std::unique_ptr<ResidualBufferPool> residual_buffer_pool;
+  // threaded_window_buffer will be subdivided by PostFilter into windows of
+  // width 512 pixels. Each row in the window is filtered by a worker thread.
+  // To avoid false sharing, each 512-pixel row processed by one thread should
+  // not share a cache line with a row processed by another thread. So we align
+  // threaded_window_buffer to the cache line size. In addition, it is faster to
+  // memcpy from an aligned buffer.
+  AlignedDynamicBuffer<uint8_t, kCacheLineSize> threaded_window_buffer;
+  // Buffer used to temporarily store the input row for applying SuperRes.
+  AlignedDynamicBuffer<uint8_t, 16> superres_line_buffer;
+  // Buffer used to store the deblocked pixels that are necessary for loop
+  // restoration. This buffer will store 4 rows for every 64x64 block (4 rows
+  // for every 32x32 for chroma with subsampling). The indices of the rows that
+  // are stored are specified in |kDeblockedRowsForLoopRestoration|.
+  YuvBuffer deblock_buffer;
+  // The size of this dynamic buffer is |tile_rows|.
+  DynamicBuffer<IntraPredictionBuffer> intra_prediction_buffers;
+  TileScratchBufferPool tile_scratch_buffer_pool;
+  ThreadingStrategy threading_strategy;
+  std::mutex superblock_row_mutex;
+  // The size of this buffer is the number of superblock rows.
+  // |superblock_row_progress[i]| is incremented whenever a tile finishes
+  // decoding superblock row at index i. If the count reaches tile_columns, then
+  // |superblock_row_progress_condvar[i]| is notified.
+  DynamicBuffer<int> superblock_row_progress
+      LIBGAV1_GUARDED_BY(superblock_row_mutex);
+  // The size of this buffer is the number of superblock rows. Used to wait for
+  // |superblock_row_progress[i]| to reach tile_columns.
+  DynamicBuffer<std::condition_variable> superblock_row_progress_condvar;
+  // Used to signal tile decoding failure in the combined multithreading mode.
+  bool tile_decoding_failed LIBGAV1_GUARDED_BY(superblock_row_mutex);
+};
+
+class FrameScratchBufferPool {
+ public:
+  std::unique_ptr<FrameScratchBuffer> Get() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (!buffers_.Empty()) {
+      return buffers_.Pop();
+    }
+    lock.unlock();
+    std::unique_ptr<FrameScratchBuffer> scratch_buffer(new (std::nothrow)
+                                                           FrameScratchBuffer);
+    return scratch_buffer;
+  }
+
+  void Release(std::unique_ptr<FrameScratchBuffer> scratch_buffer) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    buffers_.Push(std::move(scratch_buffer));
+  }
+
+ private:
+  std::mutex mutex_;
+  Stack<std::unique_ptr<FrameScratchBuffer>, kMaxThreads> buffers_
+      LIBGAV1_GUARDED_BY(mutex_);
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_FRAME_SCRATCH_BUFFER_H_
diff --git a/libgav1/src/gav1/decoder.h b/libgav1/src/gav1/decoder.h
new file mode 100644
index 0000000..da08da9
--- /dev/null
+++ b/libgav1/src/gav1/decoder.h
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_GAV1_DECODER_H_
+#define LIBGAV1_SRC_GAV1_DECODER_H_
+
+#if defined(__cplusplus)
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#else
+#include <stddef.h>
+#include <stdint.h>
+#endif  // defined(__cplusplus)
+
+// IWYU pragma: begin_exports
+#include "gav1/decoder_buffer.h"
+#include "gav1/decoder_settings.h"
+#include "gav1/frame_buffer.h"
+#include "gav1/status_code.h"
+#include "gav1/symbol_visibility.h"
+#include "gav1/version.h"
+// IWYU pragma: end_exports
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+struct Libgav1Decoder;
+typedef struct Libgav1Decoder Libgav1Decoder;
+
+LIBGAV1_PUBLIC Libgav1StatusCode Libgav1DecoderCreate(
+    const Libgav1DecoderSettings* settings, Libgav1Decoder** decoder_out);
+
+LIBGAV1_PUBLIC void Libgav1DecoderDestroy(Libgav1Decoder* decoder);
+
+LIBGAV1_PUBLIC Libgav1StatusCode Libgav1DecoderEnqueueFrame(
+    Libgav1Decoder* decoder, const uint8_t* data, size_t size,
+    int64_t user_private_data, void* buffer_private_data);
+
+LIBGAV1_PUBLIC Libgav1StatusCode Libgav1DecoderDequeueFrame(
+    Libgav1Decoder* decoder, const Libgav1DecoderBuffer** out_ptr);
+
+LIBGAV1_PUBLIC Libgav1StatusCode
+Libgav1DecoderSignalEOS(Libgav1Decoder* decoder);
+
+LIBGAV1_PUBLIC int Libgav1DecoderGetMaxBitdepth(void);
+
+#if defined(__cplusplus)
+}  // extern "C"
+
+namespace libgav1 {
+
+// Forward declaration.
+class DecoderImpl;
+
+class LIBGAV1_PUBLIC Decoder {
+ public:
+  Decoder();
+  ~Decoder();
+
+  // Init must be called exactly once per instance. Subsequent calls will do
+  // nothing. If |settings| is nullptr, the decoder will be initialized with
+  // default settings. Returns kStatusOk on success, an error status otherwise.
+  StatusCode Init(const DecoderSettings* settings);
+
+  // Enqueues a compressed frame to be decoded.
+  //
+  // This function returns:
+  //   * kStatusOk on success
+  //   * kStatusTryAgain if the decoder queue is full
+  //   * an error status otherwise.
+  //
+  // |user_private_data| may be used to associate application specific private
+  // data with the compressed frame. It will be copied to the user_private_data
+  // field of the DecoderBuffer returned by the corresponding |DequeueFrame()|
+  // call.
+  //
+  // NOTE: |EnqueueFrame()| does not copy the data. Therefore, after a
+  // successful |EnqueueFrame()| call, the caller must keep the |data| buffer
+  // alive until:
+  // 1) If |settings_.release_input_buffer| is not nullptr, then |data| buffer
+  // must be kept alive until release_input_buffer is called with the
+  // |buffer_private_data| passed into this EnqueueFrame call.
+  // 2) If |settings_.release_input_buffer| is nullptr, then |data| buffer must
+  // be kept alive until the corresponding DequeueFrame() call is completed.
+  //
+  // If the call to |EnqueueFrame()| is not successful, then libgav1 will not
+  // hold any references to the |data| buffer. |settings_.release_input_buffer|
+  // callback will not be called in that case.
+  StatusCode EnqueueFrame(const uint8_t* data, size_t size,
+                          int64_t user_private_data, void* buffer_private_data);
+
+  // Dequeues a decompressed frame. If there are enqueued compressed frames,
+  // decodes one and sets |*out_ptr| to the last displayable frame in the
+  // compressed frame. If there are no displayable frames available, sets
+  // |*out_ptr| to nullptr.
+  //
+  // Returns kStatusOk on success. Returns kStatusNothingToDequeue if there are
+  // no enqueued frames (in this case out_ptr will always be set to nullptr).
+  // Returns one of the other error statuses if there is an error.
+  //
+  // If |settings_.blocking_dequeue| is false and the decoder is operating in
+  // frame parallel mode (|settings_.frame_parallel| is true and the video
+  // stream passes the decoder's heuristics for enabling frame parallel mode),
+  // then this call will return kStatusTryAgain if an enqueued frame is not yet
+  // decoded (it is a non blocking call in this case). In all other cases, this
+  // call will block until an enqueued frame has been decoded.
+  StatusCode DequeueFrame(const DecoderBuffer** out_ptr);
+
+  // Signals the end of stream.
+  //
+  // In non-frame-parallel mode, this function will release all the frames held
+  // by the decoder. If the frame buffers were allocated by libgav1, then the
+  // pointer obtained by the prior DequeueFrame call will no longer be valid. If
+  // the frame buffers were allocated by the application, then any references
+  // that libgav1 is holding on to will be released.
+  //
+  // Once this function returns successfully, the decoder state will be reset
+  // and the decoder is ready to start decoding a new coded video sequence.
+  StatusCode SignalEOS();
+
+  // Returns the maximum bitdepth that is supported by this decoder.
+  static int GetMaxBitdepth();
+
+ private:
+  DecoderSettings settings_;
+  // The object is initialized if and only if impl_ != nullptr.
+  std::unique_ptr<DecoderImpl> impl_;
+};
+
+}  // namespace libgav1
+#endif  // defined(__cplusplus)
+
+#endif  // LIBGAV1_SRC_GAV1_DECODER_H_
diff --git a/libgav1/src/gav1/decoder_buffer.h b/libgav1/src/gav1/decoder_buffer.h
new file mode 100644
index 0000000..37bcb29
--- /dev/null
+++ b/libgav1/src/gav1/decoder_buffer.h
@@ -0,0 +1,279 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_GAV1_DECODER_BUFFER_H_
+#define LIBGAV1_SRC_GAV1_DECODER_BUFFER_H_
+
+#if defined(__cplusplus)
+#include <cstdint>
+#else
+#include <stdint.h>
+#endif  // defined(__cplusplus)
+
+#include "gav1/symbol_visibility.h"
+
+// All the declarations in this file are part of the public ABI.
+
+// The documentation for the enum values in this file can be found in Section
+// 6.4.2 of the AV1 spec.
+
+typedef enum Libgav1ChromaSamplePosition {
+  kLibgav1ChromaSamplePositionUnknown,
+  kLibgav1ChromaSamplePositionVertical,
+  kLibgav1ChromaSamplePositionColocated,
+  kLibgav1ChromaSamplePositionReserved
+} Libgav1ChromaSamplePosition;
+
+typedef enum Libgav1ImageFormat {
+  kLibgav1ImageFormatYuv420,
+  kLibgav1ImageFormatYuv422,
+  kLibgav1ImageFormatYuv444,
+  kLibgav1ImageFormatMonochrome400
+} Libgav1ImageFormat;
+
+typedef enum Libgav1ColorPrimary {
+  // 0 is reserved.
+  kLibgav1ColorPrimaryBt709 = 1,
+  kLibgav1ColorPrimaryUnspecified,
+  // 3 is reserved.
+  kLibgav1ColorPrimaryBt470M = 4,
+  kLibgav1ColorPrimaryBt470Bg,
+  kLibgav1ColorPrimaryBt601,
+  kLibgav1ColorPrimarySmpte240,
+  kLibgav1ColorPrimaryGenericFilm,
+  kLibgav1ColorPrimaryBt2020,
+  kLibgav1ColorPrimaryXyz,
+  kLibgav1ColorPrimarySmpte431,
+  kLibgav1ColorPrimarySmpte432,
+  // 13-21 are reserved.
+  kLibgav1ColorPrimaryEbu3213 = 22,
+  // 23-254 are reserved.
+  kLibgav1MaxColorPrimaries = 255
+} Libgav1ColorPrimary;
+
+typedef enum Libgav1TransferCharacteristics {
+  // 0 is reserved.
+  kLibgav1TransferCharacteristicsBt709 = 1,
+  kLibgav1TransferCharacteristicsUnspecified,
+  // 3 is reserved.
+  kLibgav1TransferCharacteristicsBt470M = 4,
+  kLibgav1TransferCharacteristicsBt470Bg,
+  kLibgav1TransferCharacteristicsBt601,
+  kLibgav1TransferCharacteristicsSmpte240,
+  kLibgav1TransferCharacteristicsLinear,
+  kLibgav1TransferCharacteristicsLog100,
+  kLibgav1TransferCharacteristicsLog100Sqrt10,
+  kLibgav1TransferCharacteristicsIec61966,
+  kLibgav1TransferCharacteristicsBt1361,
+  kLibgav1TransferCharacteristicsSrgb,
+  kLibgav1TransferCharacteristicsBt2020TenBit,
+  kLibgav1TransferCharacteristicsBt2020TwelveBit,
+  kLibgav1TransferCharacteristicsSmpte2084,
+  kLibgav1TransferCharacteristicsSmpte428,
+  kLibgav1TransferCharacteristicsHlg,
+  // 19-254 are reserved.
+  kLibgav1MaxTransferCharacteristics = 255
+} Libgav1TransferCharacteristics;
+
+typedef enum Libgav1MatrixCoefficients {
+  kLibgav1MatrixCoefficientsIdentity,
+  kLibgav1MatrixCoefficientsBt709,
+  kLibgav1MatrixCoefficientsUnspecified,
+  // 3 is reserved.
+  kLibgav1MatrixCoefficientsFcc = 4,
+  kLibgav1MatrixCoefficientsBt470BG,
+  kLibgav1MatrixCoefficientsBt601,
+  kLibgav1MatrixCoefficientsSmpte240,
+  kLibgav1MatrixCoefficientsSmpteYcgco,
+  kLibgav1MatrixCoefficientsBt2020Ncl,
+  kLibgav1MatrixCoefficientsBt2020Cl,
+  kLibgav1MatrixCoefficientsSmpte2085,
+  kLibgav1MatrixCoefficientsChromatNcl,
+  kLibgav1MatrixCoefficientsChromatCl,
+  kLibgav1MatrixCoefficientsIctcp,
+  // 15-254 are reserved.
+  kLibgav1MaxMatrixCoefficients = 255
+} Libgav1MatrixCoefficients;
+
+typedef enum Libgav1ColorRange {
+  // The color ranges are scaled by value << (bitdepth - 8) for 10 and 12bit
+  // streams.
+  kLibgav1ColorRangeStudio,  // Y [16..235], UV [16..240]
+  kLibgav1ColorRangeFull     // YUV/RGB [0..255]
+} Libgav1ColorRange;
+
+typedef struct Libgav1DecoderBuffer {
+#if defined(__cplusplus)
+  LIBGAV1_PUBLIC int NumPlanes() const {
+    return (image_format == kLibgav1ImageFormatMonochrome400) ? 1 : 3;
+  }
+#endif  // defined(__cplusplus)
+
+  Libgav1ChromaSamplePosition chroma_sample_position;
+  Libgav1ImageFormat image_format;
+  Libgav1ColorRange color_range;
+  Libgav1ColorPrimary color_primary;
+  Libgav1TransferCharacteristics transfer_characteristics;
+  Libgav1MatrixCoefficients matrix_coefficients;
+
+  // Image storage dimensions.
+  // NOTE: These fields are named w and h in vpx_image_t and aom_image_t.
+  // uint32_t width;  // Stored image width.
+  // uint32_t height;  // Stored image height.
+  int bitdepth;  // Stored image bitdepth.
+
+  // Image display dimensions.
+  // NOTES:
+  // 1. These fields are named d_w and d_h in vpx_image_t and aom_image_t.
+  // 2. libvpx and libaom clients use d_w and d_h much more often than w and h.
+  // 3. These fields can just be stored for the Y plane and the clients can
+  //    calculate the values for the U and V planes if the image format or
+  //    subsampling is exposed.
+  int displayed_width[3];   // Displayed image width.
+  int displayed_height[3];  // Displayed image height.
+
+  int stride[3];
+  uint8_t* plane[3];
+
+  // Spatial id of this frame.
+  int spatial_id;
+  // Temporal id of this frame.
+  int temporal_id;
+
+  // The |user_private_data| argument passed to Decoder::EnqueueFrame().
+  int64_t user_private_data;
+  // The |private_data| field of FrameBuffer. Set by the get frame buffer
+  // callback when it allocates a frame buffer.
+  void* buffer_private_data;
+} Libgav1DecoderBuffer;
+
+#if defined(__cplusplus)
+namespace libgav1 {
+
+using ChromaSamplePosition = Libgav1ChromaSamplePosition;
+constexpr ChromaSamplePosition kChromaSamplePositionUnknown =
+    kLibgav1ChromaSamplePositionUnknown;
+constexpr ChromaSamplePosition kChromaSamplePositionVertical =
+    kLibgav1ChromaSamplePositionVertical;
+constexpr ChromaSamplePosition kChromaSamplePositionColocated =
+    kLibgav1ChromaSamplePositionColocated;
+constexpr ChromaSamplePosition kChromaSamplePositionReserved =
+    kLibgav1ChromaSamplePositionReserved;
+
+using ImageFormat = Libgav1ImageFormat;
+constexpr ImageFormat kImageFormatYuv420 = kLibgav1ImageFormatYuv420;
+constexpr ImageFormat kImageFormatYuv422 = kLibgav1ImageFormatYuv422;
+constexpr ImageFormat kImageFormatYuv444 = kLibgav1ImageFormatYuv444;
+constexpr ImageFormat kImageFormatMonochrome400 =
+    kLibgav1ImageFormatMonochrome400;
+
+using ColorPrimary = Libgav1ColorPrimary;
+constexpr ColorPrimary kColorPrimaryBt709 = kLibgav1ColorPrimaryBt709;
+constexpr ColorPrimary kColorPrimaryUnspecified =
+    kLibgav1ColorPrimaryUnspecified;
+constexpr ColorPrimary kColorPrimaryBt470M = kLibgav1ColorPrimaryBt470M;
+constexpr ColorPrimary kColorPrimaryBt470Bg = kLibgav1ColorPrimaryBt470Bg;
+constexpr ColorPrimary kColorPrimaryBt601 = kLibgav1ColorPrimaryBt601;
+constexpr ColorPrimary kColorPrimarySmpte240 = kLibgav1ColorPrimarySmpte240;
+constexpr ColorPrimary kColorPrimaryGenericFilm =
+    kLibgav1ColorPrimaryGenericFilm;
+constexpr ColorPrimary kColorPrimaryBt2020 = kLibgav1ColorPrimaryBt2020;
+constexpr ColorPrimary kColorPrimaryXyz = kLibgav1ColorPrimaryXyz;
+constexpr ColorPrimary kColorPrimarySmpte431 = kLibgav1ColorPrimarySmpte431;
+constexpr ColorPrimary kColorPrimarySmpte432 = kLibgav1ColorPrimarySmpte432;
+constexpr ColorPrimary kColorPrimaryEbu3213 = kLibgav1ColorPrimaryEbu3213;
+constexpr ColorPrimary kMaxColorPrimaries = kLibgav1MaxColorPrimaries;
+
+using TransferCharacteristics = Libgav1TransferCharacteristics;
+constexpr TransferCharacteristics kTransferCharacteristicsBt709 =
+    kLibgav1TransferCharacteristicsBt709;
+constexpr TransferCharacteristics kTransferCharacteristicsUnspecified =
+    kLibgav1TransferCharacteristicsUnspecified;
+constexpr TransferCharacteristics kTransferCharacteristicsBt470M =
+    kLibgav1TransferCharacteristicsBt470M;
+constexpr TransferCharacteristics kTransferCharacteristicsBt470Bg =
+    kLibgav1TransferCharacteristicsBt470Bg;
+constexpr TransferCharacteristics kTransferCharacteristicsBt601 =
+    kLibgav1TransferCharacteristicsBt601;
+constexpr TransferCharacteristics kTransferCharacteristicsSmpte240 =
+    kLibgav1TransferCharacteristicsSmpte240;
+constexpr TransferCharacteristics kTransferCharacteristicsLinear =
+    kLibgav1TransferCharacteristicsLinear;
+constexpr TransferCharacteristics kTransferCharacteristicsLog100 =
+    kLibgav1TransferCharacteristicsLog100;
+constexpr TransferCharacteristics kTransferCharacteristicsLog100Sqrt10 =
+    kLibgav1TransferCharacteristicsLog100Sqrt10;
+constexpr TransferCharacteristics kTransferCharacteristicsIec61966 =
+    kLibgav1TransferCharacteristicsIec61966;
+constexpr TransferCharacteristics kTransferCharacteristicsBt1361 =
+    kLibgav1TransferCharacteristicsBt1361;
+constexpr TransferCharacteristics kTransferCharacteristicsSrgb =
+    kLibgav1TransferCharacteristicsSrgb;
+constexpr TransferCharacteristics kTransferCharacteristicsBt2020TenBit =
+    kLibgav1TransferCharacteristicsBt2020TenBit;
+constexpr TransferCharacteristics kTransferCharacteristicsBt2020TwelveBit =
+    kLibgav1TransferCharacteristicsBt2020TwelveBit;
+constexpr TransferCharacteristics kTransferCharacteristicsSmpte2084 =
+    kLibgav1TransferCharacteristicsSmpte2084;
+constexpr TransferCharacteristics kTransferCharacteristicsSmpte428 =
+    kLibgav1TransferCharacteristicsSmpte428;
+constexpr TransferCharacteristics kTransferCharacteristicsHlg =
+    kLibgav1TransferCharacteristicsHlg;
+constexpr TransferCharacteristics kMaxTransferCharacteristics =
+    kLibgav1MaxTransferCharacteristics;
+
+using MatrixCoefficients = Libgav1MatrixCoefficients;
+constexpr MatrixCoefficients kMatrixCoefficientsIdentity =
+    kLibgav1MatrixCoefficientsIdentity;
+constexpr MatrixCoefficients kMatrixCoefficientsBt709 =
+    kLibgav1MatrixCoefficientsBt709;
+constexpr MatrixCoefficients kMatrixCoefficientsUnspecified =
+    kLibgav1MatrixCoefficientsUnspecified;
+constexpr MatrixCoefficients kMatrixCoefficientsFcc =
+    kLibgav1MatrixCoefficientsFcc;
+constexpr MatrixCoefficients kMatrixCoefficientsBt470BG =
+    kLibgav1MatrixCoefficientsBt470BG;
+constexpr MatrixCoefficients kMatrixCoefficientsBt601 =
+    kLibgav1MatrixCoefficientsBt601;
+constexpr MatrixCoefficients kMatrixCoefficientsSmpte240 =
+    kLibgav1MatrixCoefficientsSmpte240;
+constexpr MatrixCoefficients kMatrixCoefficientsSmpteYcgco =
+    kLibgav1MatrixCoefficientsSmpteYcgco;
+constexpr MatrixCoefficients kMatrixCoefficientsBt2020Ncl =
+    kLibgav1MatrixCoefficientsBt2020Ncl;
+constexpr MatrixCoefficients kMatrixCoefficientsBt2020Cl =
+    kLibgav1MatrixCoefficientsBt2020Cl;
+constexpr MatrixCoefficients kMatrixCoefficientsSmpte2085 =
+    kLibgav1MatrixCoefficientsSmpte2085;
+constexpr MatrixCoefficients kMatrixCoefficientsChromatNcl =
+    kLibgav1MatrixCoefficientsChromatNcl;
+constexpr MatrixCoefficients kMatrixCoefficientsChromatCl =
+    kLibgav1MatrixCoefficientsChromatCl;
+constexpr MatrixCoefficients kMatrixCoefficientsIctcp =
+    kLibgav1MatrixCoefficientsIctcp;
+constexpr MatrixCoefficients kMaxMatrixCoefficients =
+    kLibgav1MaxMatrixCoefficients;
+
+using ColorRange = Libgav1ColorRange;
+constexpr ColorRange kColorRangeStudio = kLibgav1ColorRangeStudio;
+constexpr ColorRange kColorRangeFull = kLibgav1ColorRangeFull;
+
+using DecoderBuffer = Libgav1DecoderBuffer;
+
+}  // namespace libgav1
+#endif  // defined(__cplusplus)
+
+#endif  // LIBGAV1_SRC_GAV1_DECODER_BUFFER_H_
diff --git a/libgav1/src/gav1/decoder_settings.h b/libgav1/src/gav1/decoder_settings.h
new file mode 100644
index 0000000..ab22a4d
--- /dev/null
+++ b/libgav1/src/gav1/decoder_settings.h
@@ -0,0 +1,144 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_GAV1_DECODER_SETTINGS_H_
+#define LIBGAV1_SRC_GAV1_DECODER_SETTINGS_H_
+
+#if defined(__cplusplus)
+#include <cstdint>
+#else
+#include <stdint.h>
+#endif  // defined(__cplusplus)
+
+#include "gav1/frame_buffer.h"
+#include "gav1/symbol_visibility.h"
+
+// All the declarations in this file are part of the public ABI.
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+// This callback is invoked by the decoder when it is done using an input frame
+// buffer. When frame_parallel is set to true, this callback must not be
+// nullptr. Otherwise, this callback is optional.
+//
+// |buffer_private_data| is the value passed in the EnqueueFrame() call.
+typedef void (*Libgav1ReleaseInputBufferCallback)(void* callback_private_data,
+                                                  void* buffer_private_data);
+
+typedef struct Libgav1DecoderSettings {
+  // Number of threads to use when decoding. Must be greater than 0. The library
+  // will create at most |threads| new threads. Defaults to 1 (no new threads
+  // will be created).
+  int threads;
+  // A boolean. Indicate to the decoder that frame parallel decoding is allowed.
+  // Note that this is just a request and the decoder will decide the number of
+  // frames to be decoded in parallel based on the video stream being decoded.
+  int frame_parallel;
+  // A boolean. In frame parallel mode, should Libgav1DecoderDequeueFrame wait
+  // until a enqueued frame is available for dequeueing.
+  //
+  // If frame_parallel is 0, this setting is ignored.
+  int blocking_dequeue;
+  // Called when the first sequence header or a sequence header with a
+  // different frame size (which includes bitdepth, monochrome, subsampling_x,
+  // subsampling_y, maximum frame width, or maximum frame height) is received.
+  Libgav1FrameBufferSizeChangedCallback on_frame_buffer_size_changed;
+  // Get frame buffer callback.
+  Libgav1GetFrameBufferCallback get_frame_buffer;
+  // Release frame buffer callback.
+  Libgav1ReleaseFrameBufferCallback release_frame_buffer;
+  // Release input frame buffer callback.
+  Libgav1ReleaseInputBufferCallback release_input_buffer;
+  // Passed as the private_data argument to the callbacks.
+  void* callback_private_data;
+  // A boolean. If set to 1, the decoder will output all the spatial and
+  // temporal layers.
+  int output_all_layers;
+  // Index of the operating point to decode.
+  int operating_point;
+  // Mask indicating the post processing filters that need to be applied to the
+  // reconstructed frame. Note this is an advanced setting and does not
+  // typically need to be changed.
+  // From LSB:
+  //   Bit 0: Loop filter (deblocking filter).
+  //   Bit 1: Cdef.
+  //   Bit 2: SuperRes.
+  //   Bit 3: Loop restoration.
+  //   Bit 4: Film grain synthesis.
+  //   All the bits other than the last 5 are ignored.
+  uint8_t post_filter_mask;
+} Libgav1DecoderSettings;
+
+LIBGAV1_PUBLIC void Libgav1DecoderSettingsInitDefault(
+    Libgav1DecoderSettings* settings);
+
+#if defined(__cplusplus)
+}  // extern "C"
+
+namespace libgav1 {
+
+using ReleaseInputBufferCallback = Libgav1ReleaseInputBufferCallback;
+
+// Applications must populate this structure before creating a decoder instance.
+struct DecoderSettings {
+  // Number of threads to use when decoding. Must be greater than 0. The library
+  // will create at most |threads| new threads. Defaults to 1 (no new threads
+  // will be created).
+  int threads = 1;
+  // Indicate to the decoder that frame parallel decoding is allowed. Note that
+  // this is just a request and the decoder will decide the number of frames to
+  // be decoded in parallel based on the video stream being decoded.
+  bool frame_parallel = false;
+  // In frame parallel mode, should DequeueFrame wait until a enqueued frame is
+  // available for dequeueing.
+  //
+  // If frame_parallel is false, this setting is ignored.
+  bool blocking_dequeue = false;
+  // Called when the first sequence header or a sequence header with a
+  // different frame size (which includes bitdepth, monochrome, subsampling_x,
+  // subsampling_y, maximum frame width, or maximum frame height) is received.
+  FrameBufferSizeChangedCallback on_frame_buffer_size_changed = nullptr;
+  // Get frame buffer callback.
+  GetFrameBufferCallback get_frame_buffer = nullptr;
+  // Release frame buffer callback.
+  ReleaseFrameBufferCallback release_frame_buffer = nullptr;
+  // Release input frame buffer callback.
+  ReleaseInputBufferCallback release_input_buffer = nullptr;
+  // Passed as the private_data argument to the callbacks.
+  void* callback_private_data = nullptr;
+  // If set to true, the decoder will output all the spatial and temporal
+  // layers.
+  bool output_all_layers = false;
+  // Index of the operating point to decode.
+  int operating_point = 0;
+  // Mask indicating the post processing filters that need to be applied to the
+  // reconstructed frame. Note this is an advanced setting and does not
+  // typically need to be changed.
+  // From LSB:
+  //   Bit 0: Loop filter (deblocking filter).
+  //   Bit 1: Cdef.
+  //   Bit 2: SuperRes.
+  //   Bit 3: Loop restoration.
+  //   Bit 4: Film grain synthesis.
+  //   All the bits other than the last 5 are ignored.
+  uint8_t post_filter_mask = 0x1f;
+};
+
+}  // namespace libgav1
+#endif  // defined(__cplusplus)
+#endif  // LIBGAV1_SRC_GAV1_DECODER_SETTINGS_H_
diff --git a/libgav1/src/gav1/frame_buffer.h b/libgav1/src/gav1/frame_buffer.h
new file mode 100644
index 0000000..8132b61
--- /dev/null
+++ b/libgav1/src/gav1/frame_buffer.h
@@ -0,0 +1,177 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_GAV1_FRAME_BUFFER_H_
+#define LIBGAV1_SRC_GAV1_FRAME_BUFFER_H_
+
+// All the declarations in this file are part of the public ABI. This file may
+// be included by both C and C++ files.
+
+#if defined(__cplusplus)
+#include <cstddef>
+#include <cstdint>
+#else
+#include <stddef.h>
+#include <stdint.h>
+#endif  // defined(__cplusplus)
+
+#include "gav1/decoder_buffer.h"
+#include "gav1/status_code.h"
+#include "gav1/symbol_visibility.h"
+
+// The callback functions use the C linkage conventions.
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+// This structure represents an allocated frame buffer.
+typedef struct Libgav1FrameBuffer {
+  // In the |plane| and |stride| arrays, the elements at indexes 0, 1, and 2
+  // are for the Y, U, and V planes, respectively.
+  uint8_t* plane[3];   // Pointers to the frame (excluding the borders) in the
+                       // data buffers.
+  int stride[3];       // Row strides in bytes.
+  void* private_data;  // Frame buffer's private data. Available for use by the
+                       // release frame buffer callback. Also copied to the
+                       // |buffer_private_data| field of DecoderBuffer for use
+                       // by the consumer of a DecoderBuffer.
+} Libgav1FrameBuffer;
+
+// This callback is invoked by the decoder to provide information on the
+// subsequent frames in the video, until the next invocation of this callback
+// or the end of the video.
+//
+// |width| and |height| are the maximum frame width and height in pixels.
+// |left_border|, |right_border|, |top_border|, and |bottom_border| are the
+// maximum left, right, top, and bottom border sizes in pixels.
+// |stride_alignment| specifies the alignment of the row stride in bytes.
+//
+// Returns kLibgav1StatusOk on success, an error status on failure.
+//
+// NOTE: This callback may be omitted if the information is not useful to the
+// application.
+typedef Libgav1StatusCode (*Libgav1FrameBufferSizeChangedCallback)(
+    void* callback_private_data, int bitdepth, Libgav1ImageFormat image_format,
+    int width, int height, int left_border, int right_border, int top_border,
+    int bottom_border, int stride_alignment);
+
+// This callback is invoked by the decoder to allocate a frame buffer, which
+// consists of three data buffers, for the Y, U, and V planes, respectively.
+//
+// The callback must set |frame_buffer->plane[i]| to point to the data buffers
+// of the planes, and set |frame_buffer->stride[i]| to the row strides of the
+// planes. If |image_format| is kLibgav1ImageFormatMonochrome400, the callback
+// should set |frame_buffer->plane[1]| and |frame_buffer->plane[2]| to a null
+// pointer and set |frame_buffer->stride[1]| and |frame_buffer->stride[2]| to
+// 0. The callback may set |frame_buffer->private_data| to a value that will
+// be useful to the release frame buffer callback and the consumer of a
+// DecoderBuffer.
+//
+// Returns kLibgav1StatusOk on success, an error status on failure.
+
+// |width| and |height| are the frame width and height in pixels.
+// |left_border|, |right_border|, |top_border|, and |bottom_border| are the
+// left, right, top, and bottom border sizes in pixels. |stride_alignment|
+// specifies the alignment of the row stride in bytes.
+typedef Libgav1StatusCode (*Libgav1GetFrameBufferCallback)(
+    void* callback_private_data, int bitdepth, Libgav1ImageFormat image_format,
+    int width, int height, int left_border, int right_border, int top_border,
+    int bottom_border, int stride_alignment, Libgav1FrameBuffer* frame_buffer);
+
+// After a frame buffer is allocated, the decoder starts to write decoded video
+// to the frame buffer. When the frame buffer is ready for consumption, it is
+// made available to the application in a Decoder::DequeueFrame() call.
+// Afterwards, the decoder may continue to use the frame buffer in read-only
+// mode. When the decoder is finished using the frame buffer, it notifies the
+// application by calling the Libgav1ReleaseFrameBufferCallback.
+
+// This callback is invoked by the decoder to release a frame buffer.
+typedef void (*Libgav1ReleaseFrameBufferCallback)(void* callback_private_data,
+                                                  void* buffer_private_data);
+
+// Libgav1ComputeFrameBufferInfo() and Libgav1SetFrameBuffer() are intended to
+// help clients implement frame buffer callbacks using memory buffers. First,
+// call Libgav1ComputeFrameBufferInfo(). If it succeeds, allocate y_buffer of
+// size info.y_buffer_size and allocate u_buffer and v_buffer, both of size
+// info.uv_buffer_size. Finally, pass y_buffer, u_buffer, v_buffer, and
+// buffer_private_data to Libgav1SetFrameBuffer().
+
+// This structure contains information useful for allocating memory for a frame
+// buffer.
+typedef struct Libgav1FrameBufferInfo {
+  size_t y_buffer_size;   // Size in bytes of the Y buffer.
+  size_t uv_buffer_size;  // Size in bytes of the U or V buffer.
+
+  // The following fields are consumed by Libgav1SetFrameBuffer(). Do not use
+  // them directly.
+  int y_stride;            // Row stride in bytes of the Y buffer.
+  int uv_stride;           // Row stride in bytes of the U or V buffer.
+  size_t y_plane_offset;   // Offset in bytes of the frame (excluding the
+                           // borders) in the Y buffer.
+  size_t uv_plane_offset;  // Offset in bytes of the frame (excluding the
+                           // borders) in the U or V buffer.
+  int stride_alignment;    // The stride_alignment argument passed to
+                           // Libgav1ComputeFrameBufferInfo().
+} Libgav1FrameBufferInfo;
+
+// Computes the information useful for allocating memory for a frame buffer.
+// On success, stores the output in |info|.
+LIBGAV1_PUBLIC Libgav1StatusCode Libgav1ComputeFrameBufferInfo(
+    int bitdepth, Libgav1ImageFormat image_format, int width, int height,
+    int left_border, int right_border, int top_border, int bottom_border,
+    int stride_alignment, Libgav1FrameBufferInfo* info);
+
+// Sets the |frame_buffer| struct.
+LIBGAV1_PUBLIC Libgav1StatusCode Libgav1SetFrameBuffer(
+    const Libgav1FrameBufferInfo* info, uint8_t* y_buffer, uint8_t* u_buffer,
+    uint8_t* v_buffer, void* buffer_private_data,
+    Libgav1FrameBuffer* frame_buffer);
+
+#if defined(__cplusplus)
+}  // extern "C"
+
+// Declare type aliases for C++.
+namespace libgav1 {
+
+using FrameBuffer = Libgav1FrameBuffer;
+using FrameBufferSizeChangedCallback = Libgav1FrameBufferSizeChangedCallback;
+using GetFrameBufferCallback = Libgav1GetFrameBufferCallback;
+using ReleaseFrameBufferCallback = Libgav1ReleaseFrameBufferCallback;
+using FrameBufferInfo = Libgav1FrameBufferInfo;
+
+inline StatusCode ComputeFrameBufferInfo(int bitdepth, ImageFormat image_format,
+                                         int width, int height, int left_border,
+                                         int right_border, int top_border,
+                                         int bottom_border,
+                                         int stride_alignment,
+                                         FrameBufferInfo* info) {
+  return Libgav1ComputeFrameBufferInfo(bitdepth, image_format, width, height,
+                                       left_border, right_border, top_border,
+                                       bottom_border, stride_alignment, info);
+}
+
+inline StatusCode SetFrameBuffer(const FrameBufferInfo* info, uint8_t* y_buffer,
+                                 uint8_t* u_buffer, uint8_t* v_buffer,
+                                 void* buffer_private_data,
+                                 FrameBuffer* frame_buffer) {
+  return Libgav1SetFrameBuffer(info, y_buffer, u_buffer, v_buffer,
+                               buffer_private_data, frame_buffer);
+}
+
+}  // namespace libgav1
+#endif  // defined(__cplusplus)
+
+#endif  // LIBGAV1_SRC_GAV1_FRAME_BUFFER_H_
diff --git a/libgav1/src/gav1/status_code.h b/libgav1/src/gav1/status_code.h
new file mode 100644
index 0000000..d7476ca
--- /dev/null
+++ b/libgav1/src/gav1/status_code.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_GAV1_STATUS_CODE_H_
+#define LIBGAV1_SRC_GAV1_STATUS_CODE_H_
+
+#include "gav1/symbol_visibility.h"
+
+// All the declarations in this file are part of the public ABI. This file may
+// be included by both C and C++ files.
+
+// The Libgav1StatusCode enum type: A libgav1 function may return
+// Libgav1StatusCode to indicate success or the reason for failure.
+typedef enum {
+  // Success.
+  kLibgav1StatusOk = 0,
+
+  // An unknown error. Used as the default error status if error detail is not
+  // available.
+  kLibgav1StatusUnknownError = -1,
+
+  // An invalid function argument.
+  kLibgav1StatusInvalidArgument = -2,
+
+  // Memory allocation failure.
+  kLibgav1StatusOutOfMemory = -3,
+
+  // Ran out of a resource (other than memory).
+  kLibgav1StatusResourceExhausted = -4,
+
+  // The object is not initialized.
+  kLibgav1StatusNotInitialized = -5,
+
+  // An operation that can only be performed once has already been performed.
+  kLibgav1StatusAlready = -6,
+
+  // Not implemented, or not supported.
+  kLibgav1StatusUnimplemented = -7,
+
+  // An internal error in libgav1. Usually this indicates a programming error.
+  kLibgav1StatusInternalError = -8,
+
+  // The bitstream is not encoded correctly or violates a bitstream conformance
+  // requirement.
+  kLibgav1StatusBitstreamError = -9,
+
+  // The operation is not allowed at the moment. This is not a fatal error. Try
+  // again later.
+  kLibgav1StatusTryAgain = -10,
+
+  // Used only by DequeueFrame(). There are no enqueued frames, so there is
+  // nothing to dequeue. This is not a fatal error. Try enqueuing a frame before
+  // trying to dequeue again.
+  kLibgav1StatusNothingToDequeue = -11,
+
+  // An extra enumerator to prevent people from writing code that fails to
+  // compile when a new status code is added.
+  //
+  // Do not reference this enumerator. In particular, if you write code that
+  // switches on Libgav1StatusCode, add a default: case instead of a case that
+  // mentions this enumerator.
+  //
+  // Do not depend on the value (currently -1000) listed here. It may change in
+  // the future.
+  kLibgav1StatusReservedForFutureExpansionUseDefaultInSwitchInstead_ = -1000
+} Libgav1StatusCode;
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+// Returns a human readable error string in en-US for the status code |status|.
+// Always returns a valid (non-NULL) string.
+LIBGAV1_PUBLIC const char* Libgav1GetErrorString(Libgav1StatusCode status);
+
+#if defined(__cplusplus)
+}  // extern "C"
+
+namespace libgav1 {
+
+// Declare type aliases for C++.
+using StatusCode = Libgav1StatusCode;
+constexpr StatusCode kStatusOk = kLibgav1StatusOk;
+constexpr StatusCode kStatusUnknownError = kLibgav1StatusUnknownError;
+constexpr StatusCode kStatusInvalidArgument = kLibgav1StatusInvalidArgument;
+constexpr StatusCode kStatusOutOfMemory = kLibgav1StatusOutOfMemory;
+constexpr StatusCode kStatusResourceExhausted = kLibgav1StatusResourceExhausted;
+constexpr StatusCode kStatusNotInitialized = kLibgav1StatusNotInitialized;
+constexpr StatusCode kStatusAlready = kLibgav1StatusAlready;
+constexpr StatusCode kStatusUnimplemented = kLibgav1StatusUnimplemented;
+constexpr StatusCode kStatusInternalError = kLibgav1StatusInternalError;
+constexpr StatusCode kStatusBitstreamError = kLibgav1StatusBitstreamError;
+constexpr StatusCode kStatusTryAgain = kLibgav1StatusTryAgain;
+constexpr StatusCode kStatusNothingToDequeue = kLibgav1StatusNothingToDequeue;
+
+// Returns a human readable error string in en-US for the status code |status|.
+// Always returns a valid (non-NULL) string.
+inline const char* GetErrorString(StatusCode status) {
+  return Libgav1GetErrorString(status);
+}
+
+}  // namespace libgav1
+#endif  // defined(__cplusplus)
+
+#endif  // LIBGAV1_SRC_GAV1_STATUS_CODE_H_
diff --git a/libgav1/src/symbol_visibility.h b/libgav1/src/gav1/symbol_visibility.h
similarity index 94%
rename from libgav1/src/symbol_visibility.h
rename to libgav1/src/gav1/symbol_visibility.h
index c9ed53b..ad7498c 100644
--- a/libgav1/src/symbol_visibility.h
+++ b/libgav1/src/gav1/symbol_visibility.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef LIBGAV1_SRC_SYMBOL_VISIBILITY_H_
-#define LIBGAV1_SRC_SYMBOL_VISIBILITY_H_
+#ifndef LIBGAV1_SRC_GAV1_SYMBOL_VISIBILITY_H_
+#define LIBGAV1_SRC_GAV1_SYMBOL_VISIBILITY_H_
 
 // This module defines the LIBGAV1_PUBLIC macro. LIBGAV1_PUBLIC, when combined
 // with the flags -fvisibility=hidden and -fvisibility-inlines-hidden, restricts
@@ -85,4 +85,4 @@
 #endif  // defined(_WIN32)
 #endif  // defined(LIBGAV1_PUBLIC)
 
-#endif  // LIBGAV1_SRC_SYMBOL_VISIBILITY_H_
+#endif  // LIBGAV1_SRC_GAV1_SYMBOL_VISIBILITY_H_
diff --git a/libgav1/src/gav1/version.h b/libgav1/src/gav1/version.h
new file mode 100644
index 0000000..e78e9a7
--- /dev/null
+++ b/libgav1/src/gav1/version.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_GAV1_VERSION_H_
+#define LIBGAV1_SRC_GAV1_VERSION_H_
+
+#include "gav1/symbol_visibility.h"
+
+// This library follows the principles described by Semantic Versioning
+// (https://semver.org).
+
+#define LIBGAV1_MAJOR_VERSION 0
+#define LIBGAV1_MINOR_VERSION 16
+#define LIBGAV1_PATCH_VERSION 0
+
+#define LIBGAV1_VERSION                                           \
+  ((LIBGAV1_MAJOR_VERSION << 16) | (LIBGAV1_MINOR_VERSION << 8) | \
+   LIBGAV1_PATCH_VERSION)
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+// Returns the library's version number, packed in an int using 8 bits for
+// each of major/minor/patch. e.g, 1.2.3 is 0x010203.
+LIBGAV1_PUBLIC int Libgav1GetVersion(void);
+
+// Returns the library's version number as a string in the format
+// 'MAJOR.MINOR.PATCH'. Always returns a valid (non-NULL) string.
+LIBGAV1_PUBLIC const char* Libgav1GetVersionString(void);
+
+// Returns the build configuration used to produce the library. Always returns
+// a valid (non-NULL) string.
+LIBGAV1_PUBLIC const char* Libgav1GetBuildConfiguration(void);
+
+#if defined(__cplusplus)
+}  // extern "C"
+
+namespace libgav1 {
+
+// Returns the library's version number, packed in an int using 8 bits for
+// each of major/minor/patch. e.g, 1.2.3 is 0x010203.
+inline int GetVersion() { return Libgav1GetVersion(); }
+
+// Returns the library's version number as a string in the format
+// 'MAJOR.MINOR.PATCH'. Always returns a valid (non-NULL) string.
+inline const char* GetVersionString() { return Libgav1GetVersionString(); }
+
+// Returns the build configuration used to produce the library. Always returns
+// a valid (non-NULL) string.
+inline const char* GetBuildConfiguration() {
+  return Libgav1GetBuildConfiguration();
+}
+
+}  // namespace libgav1
+#endif  // defined(__cplusplus)
+
+#endif  // LIBGAV1_SRC_GAV1_VERSION_H_
diff --git a/libgav1/src/inter_intra_masks.inc b/libgav1/src/inter_intra_masks.inc
new file mode 100644
index 0000000..2c15f9c
--- /dev/null
+++ b/libgav1/src/inter_intra_masks.inc
@@ -0,0 +1,581 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file is just a convenience to separate out all the inter intra masks
+// from the code where it is used.
+
+// The tables in this file are computed based on section 7.11.3.13 in the spec.
+
+constexpr uint8_t kInterIntraMaskDc[] = {
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
+    32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32};
+
+constexpr uint8_t kInterIntraMaskVertical4x4[] = {
+    60, 60, 60, 60, 19, 19, 19, 19, 6, 6, 6, 6, 2, 2, 2, 2};
+constexpr uint8_t kInterIntraMaskVertical4x8[] = {
+    60, 60, 60, 60, 34, 34, 34, 34, 19, 19, 19, 19, 11, 11, 11, 11,
+    6,  6,  6,  6,  4,  4,  4,  4,  2,  2,  2,  2,  1,  1,  1,  1};
+constexpr uint8_t kInterIntraMaskVertical8x4[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 34, 34, 34, 34, 34, 34, 34, 34,
+    19, 19, 19, 19, 19, 19, 19, 19, 11, 11, 11, 11, 11, 11, 11, 11};
+constexpr uint8_t kInterIntraMaskVertical8x8[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 34, 34, 34, 34, 34, 34, 34, 34,
+    19, 19, 19, 19, 19, 19, 19, 19, 11, 11, 11, 11, 11, 11, 11, 11,
+    6,  6,  6,  6,  6,  6,  6,  6,  4,  4,  4,  4,  4,  4,  4,  4,
+    2,  2,  2,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  1,  1,  1};
+constexpr uint8_t kInterIntraMaskVertical8x16[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 45, 45, 45, 45, 45, 45, 45, 45, 34, 34, 34,
+    34, 34, 34, 34, 34, 26, 26, 26, 26, 26, 26, 26, 26, 19, 19, 19, 19, 19, 19,
+    19, 19, 15, 15, 15, 15, 15, 15, 15, 15, 11, 11, 11, 11, 11, 11, 11, 11, 8,
+    8,  8,  8,  8,  8,  8,  8,  6,  6,  6,  6,  6,  6,  6,  6,  5,  5,  5,  5,
+    5,  5,  5,  5,  4,  4,  4,  4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  3,  3,
+    3,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1};
+constexpr uint8_t kInterIntraMaskVertical16x8[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 45, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 34, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 26, 26, 26, 26, 26, 26, 26, 26, 26,
+    26, 26, 26, 26, 26, 26, 26, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 19, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+    15, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 8,  8,
+    8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8};
+constexpr uint8_t kInterIntraMaskVertical16x16[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 45, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 34, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 26, 26, 26, 26, 26, 26, 26, 26, 26,
+    26, 26, 26, 26, 26, 26, 26, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 19, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+    15, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 8,  8,
+    8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  6,  6,  6,  6,  6,
+    6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  5,  5,  5,
+    5,  5,  5,  5,  5,  5,  5,  5,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
+    4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
+    3,  3,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1};
+constexpr uint8_t kInterIntraMaskVertical16x32[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 52, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 45, 45, 45, 45, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 39, 39, 39, 39, 39, 39, 39, 39, 39,
+    39, 39, 39, 39, 39, 39, 39, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
+    30, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 22, 22,
+    22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 19, 19, 19, 19, 19,
+    19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 17, 17, 17, 17, 17, 17, 17, 17,
+    17, 17, 17, 17, 17, 17, 17, 17, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+    15, 15, 15, 15, 15, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+    13, 13, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10,
+    10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 8,  8,  8,  8,
+    8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  7,  7,  7,  7,  7,  7,  7,
+    7,  7,  7,  7,  7,  7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
+    6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
+    6,  6,  6,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
+    4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
+    4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  3,
+    3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
+    3,  3,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1};
+constexpr uint8_t kInterIntraMaskVertical32x16[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 52, 52, 52, 52, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,
+    45, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39,
+    39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
+    30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
+    30, 30, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
+    26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 22, 22, 22, 22,
+    22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
+    22, 22, 22, 22, 22, 22, 22, 22, 22, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
+    17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 15, 15, 15,
+    15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+    15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+    13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+    13, 13, 13, 13, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
+    11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10,
+    10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
+    10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 8,  8,  8,  8,  8,  8,  8,  8,
+    8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
+    8,  8,  8,  8,  8,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
+    7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7};
+constexpr uint8_t kInterIntraMaskVertical32x32[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 52, 52, 52, 52, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,
+    45, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39,
+    39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
+    30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
+    30, 30, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
+    26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 22, 22, 22, 22,
+    22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
+    22, 22, 22, 22, 22, 22, 22, 22, 22, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
+    17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 15, 15, 15,
+    15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+    15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+    13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+    13, 13, 13, 13, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
+    11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10,
+    10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
+    10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 8,  8,  8,  8,  8,  8,  8,  8,
+    8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
+    8,  8,  8,  8,  8,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
+    7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  6,
+    6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
+    6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
+    6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
+    6,  6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
+    5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
+    4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
+    4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
+    4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
+    4,  4,  4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
+    3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
+    3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
+    3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1};
+
+constexpr uint8_t kInterIntraMaskHorizontal4x4[] = {60, 19, 6, 2, 60, 19, 6, 2,
+                                                    60, 19, 6, 2, 60, 19, 6, 2};
+constexpr uint8_t kInterIntraMaskHorizontal4x8[] = {
+    60, 34, 19, 11, 60, 34, 19, 11, 60, 34, 19, 11, 60, 34, 19, 11,
+    60, 34, 19, 11, 60, 34, 19, 11, 60, 34, 19, 11, 60, 34, 19, 11};
+constexpr uint8_t kInterIntraMaskHorizontal8x4[] = {
+    60, 34, 19, 11, 6, 4, 2, 1, 60, 34, 19, 11, 6, 4, 2, 1,
+    60, 34, 19, 11, 6, 4, 2, 1, 60, 34, 19, 11, 6, 4, 2, 1};
+constexpr uint8_t kInterIntraMaskHorizontal8x8[] = {
+    60, 34, 19, 11, 6, 4, 2, 1, 60, 34, 19, 11, 6, 4, 2, 1,
+    60, 34, 19, 11, 6, 4, 2, 1, 60, 34, 19, 11, 6, 4, 2, 1,
+    60, 34, 19, 11, 6, 4, 2, 1, 60, 34, 19, 11, 6, 4, 2, 1,
+    60, 34, 19, 11, 6, 4, 2, 1, 60, 34, 19, 11, 6, 4, 2, 1};
+constexpr uint8_t kInterIntraMaskHorizontal8x16[] = {
+    60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34,
+    26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15,
+    11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60,
+    45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26,
+    19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11,
+    8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45,
+    34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8};
+constexpr uint8_t kInterIntraMaskHorizontal16x8[] = {
+    60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34,
+    26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15,
+    11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11, 8,  6,
+    5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,
+    2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,
+    1,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45,
+    34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1};
+constexpr uint8_t kInterIntraMaskHorizontal16x16[] = {
+    60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34,
+    26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15,
+    11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11, 8,  6,
+    5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,
+    2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,
+    1,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45,
+    34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19,
+    15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11, 8,
+    6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,
+    3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,
+    1,  1,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60,
+    45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26,
+    19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11,
+    8,  6,  5,  4,  3,  2,  2,  1,  1};
+constexpr uint8_t kInterIntraMaskHorizontal16x32[] = {
+    60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45,
+    39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30,
+    26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19,
+    17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13,
+    11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,
+    7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52,
+    45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34,
+    30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22,
+    19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15,
+    13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10,
+    8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60,
+    52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39,
+    34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26,
+    22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17,
+    15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11,
+    10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,
+    60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45,
+    39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30,
+    26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19,
+    17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13,
+    11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,
+    7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52,
+    45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34,
+    30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22,
+    19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15,
+    13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10,
+    8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7};
+constexpr uint8_t kInterIntraMaskHorizontal32x16[] = {
+    60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,
+    4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30,
+    26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,
+    2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13,
+    11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,
+    1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,
+    5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34,
+    30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,
+    2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15,
+    13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,
+    1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,
+    6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39,
+    34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,
+    2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17,
+    15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,
+    1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,
+    6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45,
+    39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,
+    3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19,
+    17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,
+    1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,
+    7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52,
+    45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,
+    3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22,
+    19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,
+    1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10,
+    8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1};
+constexpr uint8_t kInterIntraMaskHorizontal32x32[] = {
+    60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,
+    4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30,
+    26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,
+    2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13,
+    11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,
+    1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,
+    5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34,
+    30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,
+    2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15,
+    13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,
+    1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,
+    6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39,
+    34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,
+    2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17,
+    15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,
+    1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,
+    6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45,
+    39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,
+    3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19,
+    17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,
+    1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,
+    7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52,
+    45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,
+    3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22,
+    19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,
+    1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10,
+    8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60,
+    52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,
+    4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26,
+    22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,
+    2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11,
+    10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,
+    60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,
+    4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30,
+    26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,
+    2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13,
+    11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,
+    1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,
+    5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34,
+    30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,
+    2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15,
+    13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,
+    1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,
+    6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39,
+    34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,
+    2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17,
+    15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,
+    1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,
+    6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45,
+    39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,
+    3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19,
+    17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,
+    1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,
+    7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1};
+
+constexpr uint8_t kInterIntraMaskSmooth4x4[] = {60, 60, 60, 60, 60, 19, 19, 19,
+                                                60, 19, 6,  6,  60, 19, 6,  2};
+constexpr uint8_t kInterIntraMaskSmooth4x8[] = {
+    60, 60, 60, 60, 60, 34, 34, 34, 60, 34, 19, 19, 60, 34, 19, 11,
+    60, 34, 19, 11, 60, 34, 19, 11, 60, 34, 19, 11, 60, 34, 19, 11};
+constexpr uint8_t kInterIntraMaskSmooth8x4[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 34, 34, 34, 34, 34, 34, 34,
+    60, 34, 19, 19, 19, 19, 19, 19, 60, 34, 19, 11, 11, 11, 11, 11};
+constexpr uint8_t kInterIntraMaskSmooth8x8[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 34, 34, 34, 34, 34, 34, 34,
+    60, 34, 19, 19, 19, 19, 19, 19, 60, 34, 19, 11, 11, 11, 11, 11,
+    60, 34, 19, 11, 6,  6,  6,  6,  60, 34, 19, 11, 6,  4,  4,  4,
+    60, 34, 19, 11, 6,  4,  2,  2,  60, 34, 19, 11, 6,  4,  2,  1};
+constexpr uint8_t kInterIntraMaskSmooth8x16[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 45, 45, 45, 45, 45, 45, 45, 60, 45, 34,
+    34, 34, 34, 34, 34, 60, 45, 34, 26, 26, 26, 26, 26, 60, 45, 34, 26, 19, 19,
+    19, 19, 60, 45, 34, 26, 19, 15, 15, 15, 60, 45, 34, 26, 19, 15, 11, 11, 60,
+    45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26,
+    19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11,
+    8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8,  60, 45,
+    34, 26, 19, 15, 11, 8,  60, 45, 34, 26, 19, 15, 11, 8};
+constexpr uint8_t kInterIntraMaskSmooth16x8[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 60, 45, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 60, 45, 34, 26, 26, 26, 26, 26, 26,
+    26, 26, 26, 26, 26, 26, 26, 60, 45, 34, 26, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 19, 60, 45, 34, 26, 19, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+    15, 60, 45, 34, 26, 19, 15, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 60, 45,
+    34, 26, 19, 15, 11, 8,  8,  8,  8,  8,  8,  8,  8,  8};
+constexpr uint8_t kInterIntraMaskSmooth16x16[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 60, 45, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 60, 45, 34, 26, 26, 26, 26, 26, 26,
+    26, 26, 26, 26, 26, 26, 26, 60, 45, 34, 26, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 19, 60, 45, 34, 26, 19, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+    15, 60, 45, 34, 26, 19, 15, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 60, 45,
+    34, 26, 19, 15, 11, 8,  8,  8,  8,  8,  8,  8,  8,  8,  60, 45, 34, 26, 19,
+    15, 11, 8,  6,  6,  6,  6,  6,  6,  6,  6,  60, 45, 34, 26, 19, 15, 11, 8,
+    6,  5,  5,  5,  5,  5,  5,  5,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,
+    4,  4,  4,  4,  4,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  3,  3,
+    3,  3,  60, 45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  2,  2,  60,
+    45, 34, 26, 19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  2,  2,  60, 45, 34, 26,
+    19, 15, 11, 8,  6,  5,  4,  3,  2,  2,  1,  1,  60, 45, 34, 26, 19, 15, 11,
+    8,  6,  5,  4,  3,  2,  2,  1,  1};
+constexpr uint8_t kInterIntraMaskSmooth16x32[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 60, 52, 45, 45, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 60, 52, 45, 39, 39, 39, 39, 39, 39,
+    39, 39, 39, 39, 39, 39, 39, 60, 52, 45, 39, 34, 34, 34, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 60, 52, 45, 39, 34, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
+    30, 60, 52, 45, 39, 34, 30, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 60, 52,
+    45, 39, 34, 30, 26, 22, 22, 22, 22, 22, 22, 22, 22, 22, 60, 52, 45, 39, 34,
+    30, 26, 22, 19, 19, 19, 19, 19, 19, 19, 19, 60, 52, 45, 39, 34, 30, 26, 22,
+    19, 17, 17, 17, 17, 17, 17, 17, 60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15,
+    15, 15, 15, 15, 15, 60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 13, 13,
+    13, 13, 60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 11, 11, 11, 60,
+    52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 10, 10, 60, 52, 45, 39,
+    34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  8,  60, 52, 45, 39, 34, 30, 26,
+    22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17,
+    15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11,
+    10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,
+    60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45,
+    39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30,
+    26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19,
+    17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13,
+    11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,
+    7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52,
+    45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34,
+    30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22,
+    19, 17, 15, 13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15,
+    13, 11, 10, 8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10,
+    8,  7,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7};
+constexpr uint8_t kInterIntraMaskSmooth32x16[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 52, 52, 52, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 60, 52, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,
+    45, 60, 52, 45, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39,
+    39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 60, 52, 45, 39, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 60, 52, 45, 39, 34, 30, 30, 30, 30, 30, 30,
+    30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
+    30, 30, 60, 52, 45, 39, 34, 30, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
+    26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 60, 52, 45, 39,
+    34, 30, 26, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
+    22, 22, 22, 22, 22, 22, 22, 22, 22, 60, 52, 45, 39, 34, 30, 26, 22, 19, 19,
+    19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 17, 17, 17, 17, 17, 17,
+    17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 60, 52, 45,
+    39, 34, 30, 26, 22, 19, 17, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+    15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 60, 52, 45, 39, 34, 30, 26, 22, 19,
+    17, 15, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+    13, 13, 13, 13, 60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 11, 11,
+    11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 60, 52,
+    45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 10, 10, 10, 10, 10, 10, 10,
+    10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 60, 52, 45, 39, 34, 30, 26, 22,
+    19, 17, 15, 13, 11, 10, 8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
+    8,  8,  8,  8,  8,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10,
+    8,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7};
+constexpr uint8_t kInterIntraMaskSmooth32x32[] = {
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,
+    60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 52, 52, 52, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52,
+    52, 52, 52, 52, 52, 52, 52, 60, 52, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,
+    45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,
+    45, 60, 52, 45, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39,
+    39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 60, 52, 45, 39, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34,
+    34, 34, 34, 34, 34, 34, 34, 34, 60, 52, 45, 39, 34, 30, 30, 30, 30, 30, 30,
+    30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
+    30, 30, 60, 52, 45, 39, 34, 30, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
+    26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 60, 52, 45, 39,
+    34, 30, 26, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
+    22, 22, 22, 22, 22, 22, 22, 22, 22, 60, 52, 45, 39, 34, 30, 26, 22, 19, 19,
+    19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
+    19, 19, 19, 60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 17, 17, 17, 17, 17, 17,
+    17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 60, 52, 45,
+    39, 34, 30, 26, 22, 19, 17, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+    15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 60, 52, 45, 39, 34, 30, 26, 22, 19,
+    17, 15, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+    13, 13, 13, 13, 60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 11, 11,
+    11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 60, 52,
+    45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 10, 10, 10, 10, 10, 10, 10,
+    10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 60, 52, 45, 39, 34, 30, 26, 22,
+    19, 17, 15, 13, 11, 10, 8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
+    8,  8,  8,  8,  8,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10,
+    8,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  60,
+    52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  6,  6,
+    6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  60, 52, 45, 39, 34, 30, 26,
+    22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
+    6,  6,  6,  6,  6,  6,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11,
+    10, 8,  7,  6,  6,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
+    60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,
+    4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  60, 52, 45, 39, 34, 30,
+    26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  4,  4,  4,  4,
+    4,  4,  4,  4,  4,  4,  4,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13,
+    11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
+    3,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,
+    5,  4,  4,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  60, 52, 45, 39, 34,
+    30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,
+    2,  2,  2,  2,  2,  2,  2,  2,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15,
+    13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  2,  2,  2,
+    2,  2,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,
+    6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  2,  2,  2,  2,  2,  60, 52, 45, 39,
+    34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,
+    2,  2,  2,  2,  2,  2,  2,  2,  2,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17,
+    15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,
+    1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,
+    6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45,
+    39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,
+    3,  2,  2,  2,  2,  1,  1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19,
+    17, 15, 13, 11, 10, 8,  7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,
+    1,  1,  1,  1,  60, 52, 45, 39, 34, 30, 26, 22, 19, 17, 15, 13, 11, 10, 8,
+    7,  6,  6,  5,  4,  4,  3,  3,  2,  2,  2,  2,  1,  1,  1,  1,  1};
+
+// For each 2D array within this array, the indices are mapped as follows: 0, 1,
+// 2 and 3 in each dimension maps to prediction dimension 4, 8, 16 and 32
+// respectively. For example, the entry in [1][2] corresponds to a prediction
+// size of 8x16 (width == 8 and height == 16).
+const uint8_t* kInterIntraMasks[kNumInterIntraModes][4][4] = {
+    // kInterIntraModeDc. This is a special case where all the non-nullptr
+    // entries point to kInterIntraMaskDc (all entries of the array are 32). The
+    // width can be set according to the prediction size to achieve the desired
+    // result.
+    {{kInterIntraMaskDc, kInterIntraMaskDc, nullptr, nullptr},
+     {kInterIntraMaskDc, kInterIntraMaskDc, kInterIntraMaskDc, nullptr},
+     {nullptr, kInterIntraMaskDc, kInterIntraMaskDc, kInterIntraMaskDc},
+     {nullptr, nullptr, kInterIntraMaskDc, kInterIntraMaskDc}},
+    // kInterIntraModeVertical
+    {{kInterIntraMaskVertical4x4, kInterIntraMaskVertical4x8, nullptr, nullptr},
+     {kInterIntraMaskVertical8x4, kInterIntraMaskVertical8x8,
+      kInterIntraMaskVertical8x16, nullptr},
+     {nullptr, kInterIntraMaskVertical16x8, kInterIntraMaskVertical16x16,
+      kInterIntraMaskVertical16x32},
+     {nullptr, nullptr, kInterIntraMaskVertical32x16,
+      kInterIntraMaskVertical32x32}},
+    // kInterIntraModeHorizontal
+    {{kInterIntraMaskHorizontal4x4, kInterIntraMaskHorizontal4x8, nullptr,
+      nullptr},
+     {kInterIntraMaskHorizontal8x4, kInterIntraMaskHorizontal8x8,
+      kInterIntraMaskHorizontal8x16, nullptr},
+     {nullptr, kInterIntraMaskHorizontal16x8, kInterIntraMaskHorizontal16x16,
+      kInterIntraMaskHorizontal16x32},
+     {nullptr, nullptr, kInterIntraMaskHorizontal32x16,
+      kInterIntraMaskHorizontal32x32}},
+    // kInterIntraModeSmooth
+    {{kInterIntraMaskSmooth4x4, kInterIntraMaskSmooth4x8, nullptr, nullptr},
+     {kInterIntraMaskSmooth8x4, kInterIntraMaskSmooth8x8,
+      kInterIntraMaskSmooth8x16, nullptr},
+     {nullptr, kInterIntraMaskSmooth16x8, kInterIntraMaskSmooth16x16,
+      kInterIntraMaskSmooth16x32},
+     {nullptr, nullptr, kInterIntraMaskSmooth32x16,
+      kInterIntraMaskSmooth32x32}}};
diff --git a/libgav1/src/internal_frame_buffer_list.cc b/libgav1/src/internal_frame_buffer_list.cc
index d1ea30f..e2d2273 100644
--- a/libgav1/src/internal_frame_buffer_list.cc
+++ b/libgav1/src/internal_frame_buffer_list.cc
@@ -14,88 +14,109 @@
 
 #include "src/internal_frame_buffer_list.h"
 
+#include <cassert>
 #include <cstdint>
 #include <memory>
 #include <new>
 #include <utility>
 
-namespace libgav1 {
+#include "src/utils/common.h"
 
+namespace libgav1 {
 extern "C" {
 
-int GetInternalFrameBuffer(void* private_data, size_t y_plane_min_size,
-                           size_t uv_plane_min_size,
-                           FrameBuffer* frame_buffer) {
-  auto* buffer_list = static_cast<InternalFrameBufferList*>(private_data);
-  // buffer_list is a null pointer if the InternalFrameBufferList::Create()
-  // call fails. For simplicity, we handle the unlikely failure of
-  // InternalFrameBufferList::Create() here, rather than at the call sites.
-  if (buffer_list == nullptr) return -1;
-  return buffer_list->GetFrameBuffer(y_plane_min_size, uv_plane_min_size,
-                                     frame_buffer);
+Libgav1StatusCode OnInternalFrameBufferSizeChanged(
+    void* callback_private_data, int bitdepth, Libgav1ImageFormat image_format,
+    int width, int height, int left_border, int right_border, int top_border,
+    int bottom_border, int stride_alignment) {
+  auto* buffer_list =
+      static_cast<InternalFrameBufferList*>(callback_private_data);
+  return buffer_list->OnFrameBufferSizeChanged(
+      bitdepth, image_format, width, height, left_border, right_border,
+      top_border, bottom_border, stride_alignment);
 }
 
-int ReleaseInternalFrameBuffer(void* private_data, FrameBuffer* frame_buffer) {
-  auto* buffer_list = static_cast<InternalFrameBufferList*>(private_data);
-  return buffer_list->ReleaseFrameBuffer(frame_buffer);
+Libgav1StatusCode GetInternalFrameBuffer(
+    void* callback_private_data, int bitdepth, Libgav1ImageFormat image_format,
+    int width, int height, int left_border, int right_border, int top_border,
+    int bottom_border, int stride_alignment, Libgav1FrameBuffer* frame_buffer) {
+  auto* buffer_list =
+      static_cast<InternalFrameBufferList*>(callback_private_data);
+  return buffer_list->GetFrameBuffer(
+      bitdepth, image_format, width, height, left_border, right_border,
+      top_border, bottom_border, stride_alignment, frame_buffer);
+}
+
+void ReleaseInternalFrameBuffer(void* callback_private_data,
+                                void* buffer_private_data) {
+  auto* buffer_list =
+      static_cast<InternalFrameBufferList*>(callback_private_data);
+  buffer_list->ReleaseFrameBuffer(buffer_private_data);
 }
 
 }  // extern "C"
 
-// static
-std::unique_ptr<InternalFrameBufferList> InternalFrameBufferList::Create(
-    size_t num_buffers) {
-  std::unique_ptr<InternalFrameBufferList> buffer_list;
-  std::unique_ptr<Buffer[]> buffers(new (std::nothrow) Buffer[num_buffers]);
-  if (buffers != nullptr) {
-    buffer_list.reset(new (std::nothrow) InternalFrameBufferList(
-        std::move(buffers), num_buffers));
-  }
-  return buffer_list;
+StatusCode InternalFrameBufferList::OnFrameBufferSizeChanged(
+    int /*bitdepth*/, Libgav1ImageFormat /*image_format*/, int /*width*/,
+    int /*height*/, int /*left_border*/, int /*right_border*/,
+    int /*top_border*/, int /*bottom_border*/, int /*stride_alignment*/) {
+  return kStatusOk;
 }
 
-InternalFrameBufferList::InternalFrameBufferList(
-    std::unique_ptr<Buffer[]> buffers, size_t num_buffers)
-    : buffers_(std::move(buffers)), num_buffers_(num_buffers) {}
+StatusCode InternalFrameBufferList::GetFrameBuffer(
+    int bitdepth, Libgav1ImageFormat image_format, int width, int height,
+    int left_border, int right_border, int top_border, int bottom_border,
+    int stride_alignment, Libgav1FrameBuffer* frame_buffer) {
+  FrameBufferInfo info;
+  StatusCode status = ComputeFrameBufferInfo(
+      bitdepth, image_format, width, height, left_border, right_border,
+      top_border, bottom_border, stride_alignment, &info);
+  if (status != kStatusOk) return status;
 
-int InternalFrameBufferList::GetFrameBuffer(size_t y_plane_min_size,
-                                            size_t uv_plane_min_size,
-                                            FrameBuffer* frame_buffer) {
-  if (uv_plane_min_size > SIZE_MAX / 2 ||
-      y_plane_min_size > SIZE_MAX - 2 * uv_plane_min_size) {
-    return -1;
+  if (info.uv_buffer_size > SIZE_MAX / 2 ||
+      info.y_buffer_size > SIZE_MAX - 2 * info.uv_buffer_size) {
+    return kStatusInvalidArgument;
   }
-  const size_t min_size = y_plane_min_size + 2 * uv_plane_min_size;
+  const size_t min_size = info.y_buffer_size + 2 * info.uv_buffer_size;
 
-  size_t i;
-  for (i = 0; i < num_buffers_; ++i) {
-    if (!buffers_[i].in_use) break;
+  Buffer* buffer = nullptr;
+  for (auto& buffer_ptr : buffers_) {
+    if (!buffer_ptr->in_use) {
+      buffer = buffer_ptr.get();
+      break;
+    }
   }
-  if (i == num_buffers_) return -1;
+  if (buffer == nullptr) {
+    std::unique_ptr<Buffer> new_buffer(new (std::nothrow) Buffer);
+    if (new_buffer == nullptr || !buffers_.push_back(std::move(new_buffer))) {
+      return kStatusOutOfMemory;
+    }
+    buffer = buffers_.back().get();
+  }
 
-  if (buffers_[i].size < min_size) {
+  if (buffer->size < min_size) {
     std::unique_ptr<uint8_t[], MallocDeleter> new_data(
         static_cast<uint8_t*>(malloc(min_size)));
-    if (new_data == nullptr) return -1;
-    buffers_[i].data = std::move(new_data);
-    buffers_[i].size = min_size;
+    if (new_data == nullptr) return kStatusOutOfMemory;
+    buffer->data = std::move(new_data);
+    buffer->size = min_size;
   }
 
-  frame_buffer->data[0] = buffers_[i].data.get();
-  frame_buffer->size[0] = y_plane_min_size;
-  frame_buffer->data[1] = frame_buffer->data[0] + y_plane_min_size;
-  frame_buffer->size[1] = uv_plane_min_size;
-  frame_buffer->data[2] = frame_buffer->data[1] + uv_plane_min_size;
-  frame_buffer->size[2] = uv_plane_min_size;
-  frame_buffer->private_data = &buffers_[i];
-  buffers_[i].in_use = true;
-  return 0;
+  uint8_t* const y_buffer = buffer->data.get();
+  uint8_t* const u_buffer =
+      (info.uv_buffer_size == 0) ? nullptr : y_buffer + info.y_buffer_size;
+  uint8_t* const v_buffer =
+      (info.uv_buffer_size == 0) ? nullptr : u_buffer + info.uv_buffer_size;
+  status = Libgav1SetFrameBuffer(&info, y_buffer, u_buffer, v_buffer, buffer,
+                                 frame_buffer);
+  if (status != kStatusOk) return status;
+  buffer->in_use = true;
+  return kStatusOk;
 }
 
-int InternalFrameBufferList::ReleaseFrameBuffer(FrameBuffer* frame_buffer) {
-  auto* const buffer = static_cast<Buffer*>(frame_buffer->private_data);
+void InternalFrameBufferList::ReleaseFrameBuffer(void* buffer_private_data) {
+  auto* const buffer = static_cast<Buffer*>(buffer_private_data);
   buffer->in_use = false;
-  return 0;
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/internal_frame_buffer_list.h b/libgav1/src/internal_frame_buffer_list.h
index 8498d35..1c50b48 100644
--- a/libgav1/src/internal_frame_buffer_list.h
+++ b/libgav1/src/internal_frame_buffer_list.h
@@ -21,22 +21,28 @@
 #include <cstdint>
 #include <memory>
 
-#include "src/frame_buffer.h"
+#include "src/gav1/frame_buffer.h"
 #include "src/utils/memory.h"
+#include "src/utils/vector.h"
 
 namespace libgav1 {
 
-extern "C" int GetInternalFrameBuffer(void* private_data,
-                                      size_t y_plane_min_size,
-                                      size_t uv_plane_min_size,
-                                      FrameBuffer* frame_buffer);
+extern "C" Libgav1StatusCode OnInternalFrameBufferSizeChanged(
+    void* callback_private_data, int bitdepth, Libgav1ImageFormat image_format,
+    int width, int height, int left_border, int right_border, int top_border,
+    int bottom_border, int stride_alignment);
 
-extern "C" int ReleaseInternalFrameBuffer(void* private_data,
-                                          FrameBuffer* frame_buffer);
+extern "C" Libgav1StatusCode GetInternalFrameBuffer(
+    void* callback_private_data, int bitdepth, Libgav1ImageFormat image_format,
+    int width, int height, int left_border, int right_border, int top_border,
+    int bottom_border, int stride_alignment, Libgav1FrameBuffer* frame_buffer);
+
+extern "C" void ReleaseInternalFrameBuffer(void* callback_private_data,
+                                           void* buffer_private_data);
 
 class InternalFrameBufferList : public Allocable {
  public:
-  static std::unique_ptr<InternalFrameBufferList> Create(size_t num_buffers);
+  InternalFrameBufferList() = default;
 
   // Not copyable or movable.
   InternalFrameBufferList(const InternalFrameBufferList&) = delete;
@@ -44,9 +50,21 @@
 
   ~InternalFrameBufferList() = default;
 
-  int GetFrameBuffer(size_t y_plane_min_size, size_t uv_plane_min_size,
-                     FrameBuffer* frame_buffer);
-  int ReleaseFrameBuffer(FrameBuffer* frame_buffer);
+  Libgav1StatusCode OnFrameBufferSizeChanged(int bitdepth,
+                                             Libgav1ImageFormat image_format,
+                                             int width, int height,
+                                             int left_border, int right_border,
+                                             int top_border, int bottom_border,
+                                             int stride_alignment);
+
+  Libgav1StatusCode GetFrameBuffer(int bitdepth,
+                                   Libgav1ImageFormat image_format, int width,
+                                   int height, int left_border,
+                                   int right_border, int top_border,
+                                   int bottom_border, int stride_alignment,
+                                   Libgav1FrameBuffer* frame_buffer);
+
+  void ReleaseFrameBuffer(void* buffer_private_data);
 
  private:
   struct Buffer : public Allocable {
@@ -55,11 +73,7 @@
     bool in_use = false;
   };
 
-  InternalFrameBufferList(std::unique_ptr<Buffer[]> buffers,
-                          size_t num_buffers);
-
-  const std::unique_ptr<Buffer[]> buffers_;
-  const size_t num_buffers_;
+  Vector<std::unique_ptr<Buffer>> buffers_;
 };
 
 }  // namespace libgav1
diff --git a/libgav1/src/libgav1_decoder.cmake b/libgav1/src/libgav1_decoder.cmake
new file mode 100644
index 0000000..b97d09d
--- /dev/null
+++ b/libgav1/src/libgav1_decoder.cmake
@@ -0,0 +1,157 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_SRC_LIBGAV1_DECODER_CMAKE_)
+  return()
+endif() # LIBGAV1_SRC_LIBGAV1_DECODER_CMAKE_
+set(LIBGAV1_SRC_LIBGAV1_DECODER_CMAKE_ 1)
+
+list(APPEND libgav1_decoder_sources
+            "${libgav1_source}/buffer_pool.cc"
+            "${libgav1_source}/buffer_pool.h"
+            "${libgav1_source}/decoder_impl.cc"
+            "${libgav1_source}/decoder_impl.h"
+            "${libgav1_source}/decoder_state.h"
+            "${libgav1_source}/tile_scratch_buffer.cc"
+            "${libgav1_source}/tile_scratch_buffer.h"
+            "${libgav1_source}/film_grain.cc"
+            "${libgav1_source}/film_grain.h"
+            "${libgav1_source}/frame_buffer.cc"
+            "${libgav1_source}/frame_buffer_utils.h"
+            "${libgav1_source}/frame_scratch_buffer.h"
+            "${libgav1_source}/inter_intra_masks.inc"
+            "${libgav1_source}/internal_frame_buffer_list.cc"
+            "${libgav1_source}/internal_frame_buffer_list.h"
+            "${libgav1_source}/loop_restoration_info.cc"
+            "${libgav1_source}/loop_restoration_info.h"
+            "${libgav1_source}/motion_vector.cc"
+            "${libgav1_source}/motion_vector.h"
+            "${libgav1_source}/obu_parser.cc"
+            "${libgav1_source}/obu_parser.h"
+            "${libgav1_source}/post_filter/cdef.cc"
+            "${libgav1_source}/post_filter/deblock.cc"
+            "${libgav1_source}/post_filter/deblock_thresholds.inc"
+            "${libgav1_source}/post_filter/loop_restoration.cc"
+            "${libgav1_source}/post_filter/post_filter.cc"
+            "${libgav1_source}/post_filter/super_res.cc"
+            "${libgav1_source}/post_filter.h"
+            "${libgav1_source}/prediction_mask.cc"
+            "${libgav1_source}/prediction_mask.h"
+            "${libgav1_source}/quantizer.cc"
+            "${libgav1_source}/quantizer.h"
+            "${libgav1_source}/quantizer_tables.inc"
+            "${libgav1_source}/reconstruction.cc"
+            "${libgav1_source}/reconstruction.h"
+            "${libgav1_source}/residual_buffer_pool.cc"
+            "${libgav1_source}/residual_buffer_pool.h"
+            "${libgav1_source}/scan_tables.inc"
+            "${libgav1_source}/symbol_decoder_context.cc"
+            "${libgav1_source}/symbol_decoder_context.h"
+            "${libgav1_source}/symbol_decoder_context_cdfs.inc"
+            "${libgav1_source}/threading_strategy.cc"
+            "${libgav1_source}/threading_strategy.h"
+            "${libgav1_source}/tile.h"
+            "${libgav1_source}/tile/bitstream/mode_info.cc"
+            "${libgav1_source}/tile/bitstream/palette.cc"
+            "${libgav1_source}/tile/bitstream/partition.cc"
+            "${libgav1_source}/tile/bitstream/transform_size.cc"
+            "${libgav1_source}/tile/prediction.cc"
+            "${libgav1_source}/tile/tile.cc"
+            "${libgav1_source}/warp_prediction.cc"
+            "${libgav1_source}/warp_prediction.h"
+            "${libgav1_source}/yuv_buffer.cc"
+            "${libgav1_source}/yuv_buffer.h")
+
+list(APPEND libgav1_api_includes "${libgav1_source}/gav1/decoder.h"
+            "${libgav1_source}/gav1/decoder_buffer.h"
+            "${libgav1_source}/gav1/decoder_settings.h"
+            "${libgav1_source}/gav1/frame_buffer.h"
+            "${libgav1_source}/gav1/status_code.h"
+            "${libgav1_source}/gav1/symbol_visibility.h"
+            "${libgav1_source}/gav1/version.h")
+
+list(APPEND libgav1_api_sources "${libgav1_source}/decoder.cc"
+            "${libgav1_source}/decoder_settings.cc"
+            "${libgav1_source}/status_code.cc"
+            "${libgav1_source}/version.cc"
+            ${libgav1_api_includes})
+
+macro(libgav1_add_decoder_targets)
+  if(BUILD_SHARED_LIBS)
+    if(MSVC OR WIN32)
+      # In order to produce a DLL and import library the Windows tools require
+      # that the exported symbols are part of the DLL target. The unfortunate
+      # side effect of this is that a single configuration cannot output both
+      # the static library and the DLL: This results in an either/or situation.
+      # Windows users of the libgav1 build can have a DLL and an import library,
+      # or they can have a static library; they cannot have both from a single
+      # configuration of the build.
+      list(APPEND libgav1_shared_lib_sources ${libgav1_api_sources})
+      list(APPEND libgav1_static_lib_sources ${libgav1_api_includes})
+    else()
+      list(APPEND libgav1_shared_lib_sources ${libgav1_api_includes})
+      list(APPEND libgav1_static_lib_sources ${libgav1_api_sources})
+    endif()
+  else()
+    list(APPEND libgav1_static_lib_sources ${libgav1_api_sources})
+  endif()
+
+  if(NOT ANDROID)
+    list(APPEND libgav1_absl_deps absl::base absl::synchronization)
+  endif()
+
+  libgav1_add_library(NAME libgav1_decoder TYPE OBJECT SOURCES
+                      ${libgav1_decoder_sources} DEFINES ${libgav1_defines}
+                      INCLUDES ${libgav1_include_paths})
+
+  libgav1_add_library(NAME
+                      libgav1_static
+                      OUTPUT_NAME
+                      libgav1
+                      TYPE
+                      STATIC
+                      SOURCES
+                      ${libgav1_static_lib_sources}
+                      DEFINES
+                      ${libgav1_defines}
+                      INCLUDES
+                      ${libgav1_include_paths}
+                      LIB_DEPS
+                      ${libgav1_absl_deps}
+                      OBJLIB_DEPS
+                      libgav1_dsp
+                      libgav1_decoder
+                      libgav1_utils
+                      PUBLIC_INCLUDES
+                      ${libgav1_source})
+
+  if(BUILD_SHARED_LIBS)
+    libgav1_add_library(NAME
+                        libgav1_shared
+                        OUTPUT_NAME
+                        libgav1
+                        TYPE
+                        SHARED
+                        SOURCES
+                        ${libgav1_shared_lib_sources}
+                        DEFINES
+                        ${libgav1_defines}
+                        INCLUDES
+                        ${libgav1_include_paths}
+                        LIB_DEPS
+                        libgav1_static
+                        PUBLIC_INCLUDES
+                        ${libgav1_source})
+  endif()
+endmacro()
diff --git a/libgav1/src/loop_filter_mask.cc b/libgav1/src/loop_filter_mask.cc
deleted file mode 100644
index 86b454e..0000000
--- a/libgav1/src/loop_filter_mask.cc
+++ /dev/null
@@ -1,205 +0,0 @@
-// Copyright 2019 The libgav1 Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/loop_filter_mask.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <cstring>
-#include <memory>
-#include <new>
-
-#include "src/utils/array_2d.h"
-
-namespace libgav1 {
-
-// static.
-constexpr BitMaskSet LoopFilterMask::kPredictionModeDeltasMask;
-
-bool LoopFilterMask::Reset(int width, int height) {
-  num_64x64_blocks_per_row_ = DivideBy64(width + 63);
-  num_64x64_blocks_per_column_ = DivideBy64(height + 63);
-  const int num_64x64_blocks =
-      num_64x64_blocks_per_row_ * num_64x64_blocks_per_column_;
-  if (num_64x64_blocks_ == -1 || num_64x64_blocks_ < num_64x64_blocks) {
-    // Note that this need not be zero initialized here since we zero
-    // initialize the required number of entries in the loop that follows.
-    loop_filter_masks_.reset(new (std::nothrow)
-                                 Data[num_64x64_blocks]);  // NOLINT.
-    if (loop_filter_masks_ == nullptr) {
-      return false;
-    }
-  }
-  for (int i = 0; i < num_64x64_blocks; ++i) {
-    memset(&loop_filter_masks_[i], 0, sizeof(loop_filter_masks_[i]));
-  }
-  num_64x64_blocks_ = num_64x64_blocks;
-  return true;
-}
-
-void LoopFilterMask::Build(
-    const ObuSequenceHeader& sequence_header,
-    const ObuFrameHeader& frame_header, int tile_group_start,
-    int tile_group_end, const BlockParametersHolder& block_parameters_holder,
-    const Array2D<TransformSize>& inter_transform_sizes) {
-  for (int tile_number = tile_group_start; tile_number <= tile_group_end;
-       ++tile_number) {
-    const int row = tile_number / frame_header.tile_info.tile_columns;
-    const int column = tile_number % frame_header.tile_info.tile_columns;
-    const int row4x4_start = frame_header.tile_info.tile_row_start[row];
-    const int row4x4_end = frame_header.tile_info.tile_row_start[row + 1];
-    const int column4x4_start =
-        frame_header.tile_info.tile_column_start[column];
-    const int column4x4_end =
-        frame_header.tile_info.tile_column_start[column + 1];
-
-    const int num_planes = sequence_header.color_config.is_monochrome
-                               ? kMaxPlanesMonochrome
-                               : kMaxPlanes;
-    for (int plane = kPlaneY; plane < num_planes; ++plane) {
-      // For U and V planes, do not build bit masks if level == 0.
-      if (plane > kPlaneY && frame_header.loop_filter.level[plane + 1] == 0) {
-        continue;
-      }
-      const int8_t subsampling_x =
-          (plane == kPlaneY) ? 0 : sequence_header.color_config.subsampling_x;
-      const int8_t subsampling_y =
-          (plane == kPlaneY) ? 0 : sequence_header.color_config.subsampling_y;
-      const int vertical_step = 1 << subsampling_y;
-      const int horizontal_step = 1 << subsampling_x;
-
-      // Build bit masks for vertical edges (except the frame boundary).
-      if (column4x4_start != 0) {
-        const int plane_height =
-            RightShiftWithRounding(frame_header.height, subsampling_y);
-        const int row4x4_limit =
-            std::min(row4x4_end, DivideBy4(plane_height + 3) << subsampling_y);
-        const int vertical_level_index =
-            kDeblockFilterLevelIndex[plane][kLoopFilterTypeVertical];
-        for (int row4x4 = GetDeblockPosition(row4x4_start, subsampling_y);
-             row4x4 < row4x4_limit; row4x4 += vertical_step) {
-          const int column4x4 =
-              GetDeblockPosition(column4x4_start, subsampling_x);
-          const BlockParameters& bp =
-              *block_parameters_holder.Find(row4x4, column4x4);
-          const uint8_t vertical_level =
-              bp.deblock_filter_level[vertical_level_index];
-          const BlockParameters& bp_left = *block_parameters_holder.Find(
-              row4x4, column4x4 - horizontal_step);
-          const uint8_t left_level =
-              bp_left.deblock_filter_level[vertical_level_index];
-          const int unit_id = DivideBy16(row4x4) * num_64x64_blocks_per_row_ +
-                              DivideBy16(column4x4);
-          const int row = row4x4 % kNum4x4InLoopFilterMaskUnit;
-          const int column = column4x4 % kNum4x4InLoopFilterMaskUnit;
-          const int shift = LoopFilterMask::GetShift(row, column);
-          const int index = LoopFilterMask::GetIndex(row);
-          const auto mask = static_cast<uint64_t>(1) << shift;
-          // Tile boundary must be coding block boundary. So we don't have to
-          // check (!left_skip || !skip || is_vertical_border).
-          if (vertical_level != 0 || left_level != 0) {
-            assert(inter_transform_sizes[row4x4] != nullptr);
-            const TransformSize tx_size =
-                (plane == kPlaneY) ? inter_transform_sizes[row4x4][column4x4]
-                                   : bp.uv_transform_size;
-            const TransformSize left_tx_size =
-                (plane == kPlaneY)
-                    ? inter_transform_sizes[row4x4][column4x4 - horizontal_step]
-                    : bp_left.uv_transform_size;
-            const LoopFilterTransformSizeId transform_size_id =
-                GetTransformSizeIdWidth(tx_size, left_tx_size);
-            SetLeft(mask, unit_id, plane, transform_size_id, index);
-            const uint8_t current_level =
-                (vertical_level == 0) ? left_level : vertical_level;
-            SetLevel(current_level, unit_id, plane, kLoopFilterTypeVertical,
-                     LoopFilterMask::GetLevelOffset(row, column));
-          }
-        }
-      }
-
-      // Build bit masks for horizontal edges (except the frame boundary).
-      if (row4x4_start != 0) {
-        const int plane_width =
-            RightShiftWithRounding(frame_header.width, subsampling_x);
-        const int column4x4_limit = std::min(
-            column4x4_end, DivideBy4(plane_width + 3) << subsampling_y);
-        const int horizontal_level_index =
-            kDeblockFilterLevelIndex[plane][kLoopFilterTypeHorizontal];
-        for (int column4x4 = GetDeblockPosition(column4x4_start, subsampling_x);
-             column4x4 < column4x4_limit; column4x4 += horizontal_step) {
-          const int row4x4 = GetDeblockPosition(row4x4_start, subsampling_y);
-          const BlockParameters& bp =
-              *block_parameters_holder.Find(row4x4, column4x4);
-          const uint8_t horizontal_level =
-              bp.deblock_filter_level[horizontal_level_index];
-          const BlockParameters& bp_top =
-              *block_parameters_holder.Find(row4x4 - vertical_step, column4x4);
-          const uint8_t top_level =
-              bp_top.deblock_filter_level[horizontal_level_index];
-          const int unit_id = DivideBy16(row4x4) * num_64x64_blocks_per_row_ +
-                              DivideBy16(column4x4);
-          const int row = row4x4 % kNum4x4InLoopFilterMaskUnit;
-          const int column = column4x4 % kNum4x4InLoopFilterMaskUnit;
-          const int shift = LoopFilterMask::GetShift(row, column);
-          const int index = LoopFilterMask::GetIndex(row);
-          const auto mask = static_cast<uint64_t>(1) << shift;
-          // Tile boundary must be coding block boundary. So we don't have to
-          // check (!top_skip || !skip || is_horizontal_border).
-          if (horizontal_level != 0 || top_level != 0) {
-            assert(inter_transform_sizes[row4x4] != nullptr);
-            const TransformSize tx_size =
-                (plane == kPlaneY) ? inter_transform_sizes[row4x4][column4x4]
-                                   : bp.uv_transform_size;
-            const TransformSize top_tx_size =
-                (plane == kPlaneY)
-                    ? inter_transform_sizes[row4x4 - vertical_step][column4x4]
-                    : bp_top.uv_transform_size;
-            const LoopFilterTransformSizeId transform_size_id =
-                static_cast<LoopFilterTransformSizeId>(
-                    std::min({kTransformHeightLog2[tx_size] - 2,
-                              kTransformHeightLog2[top_tx_size] - 2, 2}));
-            SetTop(mask, unit_id, plane, transform_size_id, index);
-            const uint8_t current_level =
-                (horizontal_level == 0) ? top_level : horizontal_level;
-            SetLevel(current_level, unit_id, plane, kLoopFilterTypeHorizontal,
-                     LoopFilterMask::GetLevelOffset(row, column));
-          }
-        }
-      }
-    }
-  }
-  assert(IsValid());
-}
-
-bool LoopFilterMask::IsValid() const {
-  for (int mask_id = 0; mask_id < num_64x64_blocks_; ++mask_id) {
-    for (int plane = 0; plane < kMaxPlanes; ++plane) {
-      for (int i = 0; i < kNumLoopFilterTransformSizeIds; ++i) {
-        for (int j = i + 1; j < kNumLoopFilterTransformSizeIds; ++j) {
-          for (int k = 0; k < kNumLoopFilterMasks; ++k) {
-            if ((loop_filter_masks_[mask_id].left[plane][i][k] &
-                 loop_filter_masks_[mask_id].left[plane][j][k]) != 0 ||
-                (loop_filter_masks_[mask_id].top[plane][i][k] &
-                 loop_filter_masks_[mask_id].top[plane][j][k]) != 0) {
-              return false;
-            }
-          }
-        }
-      }
-    }
-  }
-  return true;
-}
-
-}  // namespace libgav1
diff --git a/libgav1/src/loop_filter_mask.h b/libgav1/src/loop_filter_mask.h
deleted file mode 100644
index 314f020..0000000
--- a/libgav1/src/loop_filter_mask.h
+++ /dev/null
@@ -1,189 +0,0 @@
-/*
- * Copyright 2019 The libgav1 Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBGAV1_SRC_LOOP_FILTER_MASK_H_
-#define LIBGAV1_SRC_LOOP_FILTER_MASK_H_
-
-#include <array>
-#include <cassert>
-#include <cstdint>
-#include <memory>
-
-#include "src/dsp/constants.h"
-#include "src/dsp/dsp.h"
-#include "src/obu_parser.h"
-#include "src/utils/array_2d.h"
-#include "src/utils/bit_mask_set.h"
-#include "src/utils/block_parameters_holder.h"
-#include "src/utils/common.h"
-#include "src/utils/constants.h"
-#include "src/utils/segmentation.h"
-#include "src/utils/types.h"
-
-namespace libgav1 {
-
-class LoopFilterMask {
- public:
-  // This structure holds loop filter bit masks for a 64x64 block.
-  // 64x64 block contains kNum4x4In64x64 = (64x64 / (4x4) = 256)
-  // 4x4 blocks. It requires kNumLoopFilterMasks = 4 uint64_t to represent them.
-  struct Data : public Allocable {
-    uint8_t level[kMaxPlanes][kNumLoopFilterTypes][kNum4x4In64x64];
-    uint64_t left[kMaxPlanes][kNumLoopFilterTransformSizeIds]
-                 [kNumLoopFilterMasks];
-    uint64_t top[kMaxPlanes][kNumLoopFilterTransformSizeIds]
-                [kNumLoopFilterMasks];
-  };
-
-  LoopFilterMask() = default;
-
-  // Loop filter mask is built and used for each superblock individually.
-  // Thus not copyable/movable.
-  LoopFilterMask(const LoopFilterMask&) = delete;
-  LoopFilterMask& operator=(const LoopFilterMask&) = delete;
-  LoopFilterMask(LoopFilterMask&&) = delete;
-  LoopFilterMask& operator=(LoopFilterMask&&) = delete;
-
-  // Allocates the loop filter masks for the given |width| and
-  // |height| if necessary and zeros out the appropriate number of
-  // entries. Returns true on success.
-  bool Reset(int width, int height);
-
-  // Builds bit masks for tile boundaries.
-  // This function is called after the frame has been decoded so that
-  // information across tiles is available.
-  // Before this function call, bit masks of transform edges other than those
-  // on tile boundaries are built together with tile decoding, in
-  // Tile::BuildBitMask().
-  void Build(const ObuSequenceHeader& sequence_header,
-             const ObuFrameHeader& frame_header, int tile_group_start,
-             int tile_group_end,
-             const BlockParametersHolder& block_parameters_holder,
-             const Array2D<TransformSize>& inter_transform_sizes);
-
-  uint8_t GetLevel(int mask_id, int plane, LoopFilterType type,
-                   int offset) const {
-    return loop_filter_masks_[mask_id].level[plane][type][offset];
-  }
-
-  uint64_t GetLeft(int mask_id, int plane, LoopFilterTransformSizeId tx_size_id,
-                   int index) const {
-    return loop_filter_masks_[mask_id].left[plane][tx_size_id][index];
-  }
-
-  uint64_t GetTop(int mask_id, int plane, LoopFilterTransformSizeId tx_size_id,
-                  int index) const {
-    return loop_filter_masks_[mask_id].top[plane][tx_size_id][index];
-  }
-
-  int num_64x64_blocks_per_row() const { return num_64x64_blocks_per_row_; }
-
-  void SetLeft(uint64_t new_mask, int mask_id, int plane,
-               LoopFilterTransformSizeId transform_size_id, int index) {
-    loop_filter_masks_[mask_id].left[plane][transform_size_id][index] |=
-        new_mask;
-  }
-
-  void SetTop(uint64_t new_mask, int mask_id, int plane,
-              LoopFilterTransformSizeId transform_size_id, int index) {
-    loop_filter_masks_[mask_id].top[plane][transform_size_id][index] |=
-        new_mask;
-  }
-
-  void SetLevel(uint8_t level, int mask_id, int plane, LoopFilterType type,
-                int offset) {
-    loop_filter_masks_[mask_id].level[plane][type][offset] = level;
-  }
-
-  static int GetIndex(int row4x4) { return row4x4 >> 2; }
-
-  static int GetShift(int row4x4, int column4x4) {
-    return ((row4x4 & 3) << 4) | column4x4;
-  }
-
-  static int GetLevelOffset(int row4x4, int column4x4) {
-    assert(row4x4 < 16);
-    assert(column4x4 < 16);
-    return (row4x4 << 4) | column4x4;
-  }
-
-  static constexpr int GetModeId(PredictionMode mode) {
-    return static_cast<int>(kPredictionModeDeltasMask.Contains(mode));
-  }
-
-  // 7.14.5.
-  static void ComputeDeblockFilterLevels(
-      const ObuFrameHeader& frame_header, int segment_id, int level_index,
-      const int8_t delta_lf[kFrameLfCount],
-      uint8_t deblock_filter_levels[kNumReferenceFrameTypes][2]) {
-    const int delta = delta_lf[frame_header.delta_lf.multi ? level_index : 0];
-    uint8_t level = Clip3(frame_header.loop_filter.level[level_index] + delta,
-                          0, kMaxLoopFilterValue);
-    const auto feature = static_cast<SegmentFeature>(
-        kSegmentFeatureLoopFilterYVertical + level_index);
-    level = Clip3(
-        level + frame_header.segmentation.feature_data[segment_id][feature], 0,
-        kMaxLoopFilterValue);
-    if (!frame_header.loop_filter.delta_enabled) {
-      static_assert(sizeof(deblock_filter_levels[0][0]) == 1, "");
-      memset(deblock_filter_levels, level, kNumReferenceFrameTypes * 2);
-      return;
-    }
-    assert(frame_header.loop_filter.delta_enabled);
-    const int shift = level >> 5;
-    deblock_filter_levels[kReferenceFrameIntra][0] = Clip3(
-        level +
-            LeftShift(frame_header.loop_filter.ref_deltas[kReferenceFrameIntra],
-                      shift),
-        0, kMaxLoopFilterValue);
-    // deblock_filter_levels[kReferenceFrameIntra][1] is never used. So it does
-    // not have to be populated.
-    for (int reference_frame = kReferenceFrameIntra + 1;
-         reference_frame < kNumReferenceFrameTypes; ++reference_frame) {
-      for (int mode_id = 0; mode_id < 2; ++mode_id) {
-        deblock_filter_levels[reference_frame][mode_id] = Clip3(
-            level +
-                LeftShift(frame_header.loop_filter.ref_deltas[reference_frame] +
-                              frame_header.loop_filter.mode_deltas[mode_id],
-                          shift),
-            0, kMaxLoopFilterValue);
-      }
-    }
-  }
-
- private:
-  std::unique_ptr<Data[]> loop_filter_masks_;
-  int num_64x64_blocks_ = -1;
-  int num_64x64_blocks_per_row_;
-  int num_64x64_blocks_per_column_;
-
-  // Mask used to determine the index for mode_deltas lookup.
-  static constexpr BitMaskSet kPredictionModeDeltasMask{
-      BitMaskSet(kPredictionModeNearestMv, kPredictionModeNearMv,
-                 kPredictionModeNewMv, kPredictionModeNearestNearestMv,
-                 kPredictionModeNearNearMv, kPredictionModeNearestNewMv,
-                 kPredictionModeNewNearestMv, kPredictionModeNearNewMv,
-                 kPredictionModeNewNearMv, kPredictionModeNewNewMv)};
-
-  // Validates that the loop filter masks at different transform sizes are
-  // mutually exclusive. Only used in an assert. This function will not be used
-  // in optimized builds.
-  bool IsValid() const;
-};
-
-}  // namespace libgav1
-
-#endif  // LIBGAV1_SRC_LOOP_FILTER_MASK_H_
diff --git a/libgav1/src/loop_restoration_info.cc b/libgav1/src/loop_restoration_info.cc
index 23fa1b4..3830836 100644
--- a/libgav1/src/loop_restoration_info.cc
+++ b/libgav1/src/loop_restoration_info.cc
@@ -26,6 +26,7 @@
 #include "src/utils/logging.h"
 
 namespace libgav1 {
+namespace {
 
 // Controls how self guided deltas are read.
 constexpr int kSgrProjReadControl = 4;
@@ -37,40 +38,63 @@
     kLoopRestorationTypeNone, kLoopRestorationTypeWiener,
     kLoopRestorationTypeSgrProj};
 
-bool LoopRestorationInfo::Allocate() {
-  const int num_planes = is_monochrome_ ? kMaxPlanesMonochrome : kMaxPlanes;
+inline int CountLeadingZeroCoefficients(const int16_t* const filter) {
+  int number_zero_coefficients = 0;
+  if (filter[0] == 0) {
+    number_zero_coefficients++;
+    if (filter[1] == 0) {
+      number_zero_coefficients++;
+      if (filter[2] == 0) {
+        number_zero_coefficients++;
+      }
+    }
+  }
+  return number_zero_coefficients;
+}
+
+}  // namespace
+
+bool LoopRestorationInfo::Reset(const LoopRestoration* const loop_restoration,
+                                uint32_t width, uint32_t height,
+                                int8_t subsampling_x, int8_t subsampling_y,
+                                bool is_monochrome) {
+  loop_restoration_ = loop_restoration;
+  subsampling_x_ = subsampling_x;
+  subsampling_y_ = subsampling_y;
+
+  const int num_planes = is_monochrome ? kMaxPlanesMonochrome : kMaxPlanes;
   int total_num_units = 0;
   for (int plane = kPlaneY; plane < num_planes; ++plane) {
-    if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
+    if (loop_restoration_->type[plane] == kLoopRestorationTypeNone) {
       plane_needs_filtering_[plane] = false;
       continue;
     }
     plane_needs_filtering_[plane] = true;
-    const int width = (plane == kPlaneY)
-                          ? width_
-                          : RightShiftWithRounding(width_, subsampling_x_);
-    const int height = (plane == kPlaneY)
-                           ? height_
-                           : RightShiftWithRounding(height_, subsampling_y_);
-    num_horizontal_units_[plane] =
-        std::max(1, (width + DivideBy2(loop_restoration_.unit_size[plane])) /
-                        loop_restoration_.unit_size[plane]);
-    num_vertical_units_[plane] =
-        std::max(1, (height + DivideBy2(loop_restoration_.unit_size[plane])) /
-                        loop_restoration_.unit_size[plane]);
+    const int plane_width = (plane == kPlaneY)
+                                ? width
+                                : RightShiftWithRounding(width, subsampling_x_);
+    const int plane_height =
+        (plane == kPlaneY) ? height
+                           : RightShiftWithRounding(height, subsampling_y_);
+    num_horizontal_units_[plane] = std::max(
+        1, (plane_width + DivideBy2(loop_restoration_->unit_size[plane])) /
+               loop_restoration_->unit_size[plane]);
+    num_vertical_units_[plane] = std::max(
+        1, (plane_height + DivideBy2(loop_restoration_->unit_size[plane])) /
+               loop_restoration_->unit_size[plane]);
     num_units_[plane] =
         num_horizontal_units_[plane] * num_vertical_units_[plane];
     total_num_units += num_units_[plane];
   }
   // Allocate the RestorationUnitInfo arrays for all planes in a single heap
   // allocation and divide up the buffer into arrays of the right sizes.
-  loop_restoration_info_buffer_.reset(new (std::nothrow)
-                                          RestorationUnitInfo[total_num_units]);
-  if (loop_restoration_info_buffer_ == nullptr) return false;
+  if (!loop_restoration_info_buffer_.Resize(total_num_units)) {
+    return false;
+  }
   RestorationUnitInfo* loop_restoration_info =
       loop_restoration_info_buffer_.get();
   for (int plane = kPlaneY; plane < num_planes; ++plane) {
-    if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
+    if (loop_restoration_->type[plane] == kLoopRestorationTypeNone) {
       continue;
     }
     loop_restoration_info_[plane] = loop_restoration_info;
@@ -87,15 +111,15 @@
   if (!plane_needs_filtering_[plane]) return false;
   const int denominator_column =
       is_superres_scaled
-          ? loop_restoration_.unit_size[plane] * kSuperResScaleNumerator
-          : loop_restoration_.unit_size[plane];
+          ? loop_restoration_->unit_size[plane] * kSuperResScaleNumerator
+          : loop_restoration_->unit_size[plane];
   const int numerator_column =
       is_superres_scaled ? superres_scale_denominator : 1;
   const int pixel_column_start =
       RowOrColumn4x4ToPixel(column4x4, plane, subsampling_x_);
   const int pixel_column_end = RowOrColumn4x4ToPixel(
       column4x4 + kNum4x4BlocksWide[block_size], plane, subsampling_x_);
-  const int unit_row = loop_restoration_.unit_size[plane];
+  const int unit_row = loop_restoration_->unit_size[plane];
   const int pixel_row_start =
       RowOrColumn4x4ToPixel(row4x4, plane, subsampling_y_);
   const int pixel_row_end = RowOrColumn4x4ToPixel(
@@ -120,15 +144,15 @@
     int unit_id,
     std::array<RestorationUnitInfo, kMaxPlanes>* const reference_unit_info) {
   LoopRestorationType unit_restoration_type = kLoopRestorationTypeNone;
-  if (loop_restoration_.type[plane] == kLoopRestorationTypeSwitchable) {
+  if (loop_restoration_->type[plane] == kLoopRestorationTypeSwitchable) {
     unit_restoration_type = kBitstreamRestorationTypeMap
         [reader->ReadSymbol<kRestorationTypeSymbolCount>(
             symbol_decoder_context->restoration_type_cdf)];
-  } else if (loop_restoration_.type[plane] == kLoopRestorationTypeWiener) {
+  } else if (loop_restoration_->type[plane] == kLoopRestorationTypeWiener) {
     const bool use_wiener =
         reader->ReadSymbol(symbol_decoder_context->use_wiener_cdf);
     if (use_wiener) unit_restoration_type = kLoopRestorationTypeWiener;
-  } else if (loop_restoration_.type[plane] == kLoopRestorationTypeSgrProj) {
+  } else if (loop_restoration_->type[plane] == kLoopRestorationTypeSgrProj) {
     const bool use_sgrproj =
         reader->ReadSymbol(symbol_decoder_context->use_sgrproj_cdf);
     if (use_sgrproj) unit_restoration_type = kLoopRestorationTypeSgrProj;
@@ -149,6 +173,7 @@
     if (plane != kPlaneY) {
       loop_restoration_info_[plane][unit_id].wiener_info.filter[i][0] = 0;
     }
+    int sum = 0;
     for (int j = static_cast<int>(plane != kPlaneY); j < kNumWienerCoefficients;
          ++j) {
       const int8_t wiener_min = kWienerTapsMin[j];
@@ -167,7 +192,14 @@
       }
       loop_restoration_info_[plane][unit_id].wiener_info.filter[i][j] = value;
       (*reference_unit_info)[plane].wiener_info.filter[i][j] = value;
+      sum += value;
     }
+    loop_restoration_info_[plane][unit_id].wiener_info.filter[i][3] =
+        128 - 2 * sum;
+    loop_restoration_info_[plane][unit_id]
+        .wiener_info.number_leading_zero_coefficients[i] =
+        CountLeadingZeroCoefficients(
+            loop_restoration_info_[plane][unit_id].wiener_info.filter[i]);
   }
 }
 
@@ -194,13 +226,16 @@
         return;
       }
     } else {
-      multiplier = 0;
-      if (i == 1) {
-        multiplier =
-            Clip3((1 << kSgrProjPrecisionBits) -
-                      (*reference_unit_info)[plane].sgr_proj_info.multiplier[0],
-                  multiplier_min, multiplier_max);
-      }
+      // The range of (*reference_unit_info)[plane].sgr_proj_info.multiplier[0]
+      // from DecodeSignedSubexpWithReference() is [-96, 31], the default is
+      // -32, making Clip3(128 - 31, -32, 95) unnecessary.
+      static constexpr int kMultiplier[2] = {0, 95};
+      multiplier = kMultiplier[i];
+      assert(
+          i == 0 ||
+          Clip3((1 << kSgrProjPrecisionBits) -
+                    (*reference_unit_info)[plane].sgr_proj_info.multiplier[0],
+                multiplier_min, multiplier_max) == kMultiplier[1]);
     }
     loop_restoration_info_[plane][unit_id].sgr_proj_info.multiplier[i] =
         multiplier;
diff --git a/libgav1/src/loop_restoration_info.h b/libgav1/src/loop_restoration_info.h
index c3c3e97..f174b89 100644
--- a/libgav1/src/loop_restoration_info.h
+++ b/libgav1/src/loop_restoration_info.h
@@ -25,7 +25,9 @@
 #include "src/dsp/common.h"
 #include "src/symbol_decoder_context.h"
 #include "src/utils/constants.h"
+#include "src/utils/dynamic_buffer.h"
 #include "src/utils/entropy_decoder.h"
+#include "src/utils/types.h"
 
 namespace libgav1 {
 
@@ -38,15 +40,7 @@
 
 class LoopRestorationInfo {
  public:
-  LoopRestorationInfo(const LoopRestoration& loop_restoration, uint32_t width,
-                      uint32_t height, int8_t subsampling_x,
-                      int8_t subsampling_y, bool is_monochrome)
-      : loop_restoration_(loop_restoration),
-        width_(width),
-        height_(height),
-        subsampling_x_(subsampling_x),
-        subsampling_y_(subsampling_y),
-        is_monochrome_(is_monochrome) {}
+  LoopRestorationInfo() = default;
 
   // Non copyable/movable.
   LoopRestorationInfo(const LoopRestorationInfo&) = delete;
@@ -54,7 +48,9 @@
   LoopRestorationInfo(LoopRestorationInfo&&) = delete;
   LoopRestorationInfo& operator=(LoopRestorationInfo&&) = delete;
 
-  bool Allocate();
+  bool Reset(const LoopRestoration* loop_restoration, uint32_t width,
+             uint32_t height, int8_t subsampling_x, int8_t subsampling_y,
+             bool is_monochrome);
   // Populates the |unit_info| for the super block at |row4x4|, |column4x4|.
   // Returns true on success, false otherwise.
   bool PopulateUnitInfoForSuperBlock(Plane plane, BlockSize block_size,
@@ -75,8 +71,9 @@
       std::array<RestorationUnitInfo, kMaxPlanes>* reference_unit_info);
 
   // Getters.
-  RestorationUnitInfo loop_restoration_info(Plane plane, int unit_id) const {
-    return loop_restoration_info_[plane][unit_id];
+  const RestorationUnitInfo* loop_restoration_info(Plane plane,
+                                                   int unit_id) const {
+    return &loop_restoration_info_[plane][unit_id];
   }
 
   int num_horizontal_units(Plane plane) const {
@@ -92,14 +89,11 @@
   // points to an array of num_units_[plane] elements.
   RestorationUnitInfo* loop_restoration_info_[kMaxPlanes];
   // Owns the memory that loop_restoration_info_[plane] points to.
-  std::unique_ptr<RestorationUnitInfo[]> loop_restoration_info_buffer_;
+  DynamicBuffer<RestorationUnitInfo> loop_restoration_info_buffer_;
   bool plane_needs_filtering_[kMaxPlanes];
-  const LoopRestoration& loop_restoration_;
-  uint32_t width_;
-  uint32_t height_;
+  const LoopRestoration* loop_restoration_;
   int8_t subsampling_x_;
   int8_t subsampling_y_;
-  bool is_monochrome_;
   int num_horizontal_units_[kMaxPlanes];
   int num_vertical_units_[kMaxPlanes];
   int num_units_[kMaxPlanes];
diff --git a/libgav1/src/motion_vector.cc b/libgav1/src/motion_vector.cc
index ff82b43..8223f3d 100644
--- a/libgav1/src/motion_vector.cc
+++ b/libgav1/src/motion_vector.cc
@@ -20,109 +20,73 @@
 #include <cstdlib>
 #include <memory>
 
+#include "src/dsp/dsp.h"
 #include "src/utils/bit_mask_set.h"
 #include "src/utils/common.h"
+#include "src/utils/constants.h"
 #include "src/utils/logging.h"
 
 namespace libgav1 {
 namespace {
 
-constexpr int kMvBorder = 128;
-constexpr int kProjectionMvClamp = 16383;
-constexpr int kProjectionMvMaxVerticalOffset = 0;
-constexpr int kProjectionMvMaxHorizontalOffset = 8;
-constexpr int kInvalidMvValue = -32768;
-
-// Applies the sign of |sign_value| to |value| (and does so without a branch).
-int ApplySign(int value, int sign_value) {
-  static_assert(sizeof(int) == 4, "");
-  // The next three lines are the branch free equivalent of:
-  // return (sign_value > 0) ? value : -value;
-  const int a = sign_value >> 31;
-  const int b = value ^ a;
-  return b - a;
-}
+// Entry at index i is computed as:
+// Clip3(std::max(kBlockWidthPixels[i], kBlockHeightPixels[i], 16, 112)).
+constexpr int kWarpValidThreshold[kMaxBlockSizes] = {
+    16, 16, 16, 16, 16, 16, 32, 16, 16,  16,  32,
+    64, 32, 32, 32, 64, 64, 64, 64, 112, 112, 112};
 
 // 7.10.2.10.
-void LowerMvPrecision(const Tile::Block& block, int* const mv) {
-  assert(mv != nullptr);
-  if (block.tile.frame_header().allow_high_precision_mv) return;
-  for (int i = 0; i < 2; ++i) {
-    if (block.tile.frame_header().force_integer_mv != 0) {
-      mv[i] = ApplySign(MultiplyBy8(DivideBy8(std::abs(mv[i]) + 3)), mv[i]);
-    } else {
-      if ((mv[i] & 1) != 0) {
-        // The next line is equivalent to:
-        // if (mv[i] > 0) { --mv[i]; } else { ++mv[i]; }
-        mv[i] += ApplySign(-1, mv[i]);
-      }
+void LowerMvPrecision(const ObuFrameHeader& frame_header,
+                      MotionVector* const mvs) {
+  if (frame_header.allow_high_precision_mv) return;
+  if (frame_header.force_integer_mv != 0) {
+    for (auto& mv : mvs->mv) {
+      // The next line is equivalent to:
+      // const int value = (std::abs(static_cast<int>(mv)) + 3) & ~7;
+      // const int sign = mv >> 15;
+      // mv = ApplySign(value, sign);
+      mv = (mv + 3 - (mv >> 15)) & ~7;
     }
-  }
-}
-
-constexpr int16_t kDivisionLookup[32] = {
-    0,    16384, 8192, 5461, 4096, 3276, 2730, 2340, 2048, 1820, 1638,
-    1489, 1365,  1260, 1170, 1092, 1024, 963,  910,  862,  819,  780,
-    744,  712,   682,  655,  630,  606,  585,  564,  546,  528};
-
-// 7.9.3.
-void GetMvProjection(const MotionVector& mv, int numerator, int denominator,
-                     MotionVector* const projection_mv) {
-  denominator = std::min(denominator, static_cast<int>(kMaxFrameDistance));
-  numerator = Clip3(numerator, -kMaxFrameDistance, kMaxFrameDistance);
-  for (int i = 0; i < 2; ++i) {
-    projection_mv->mv[i] =
-        Clip3(RightShiftWithRoundingSigned(
-                  mv.mv[i] * numerator * kDivisionLookup[denominator], 14),
-              -kProjectionMvClamp, kProjectionMvClamp);
-  }
-}
-
-// 7.9.3. (without the Clip3).
-void GetMvProjectionNoClamp(const MotionVector& mv, int numerator,
-                            int denominator,
-                            MotionVector* const projection_mv) {
-  assert(std::abs(numerator) <= kMaxFrameDistance);
-  assert(denominator <= kMaxFrameDistance);
-  for (int i = 0; i < 2; ++i) {
-    projection_mv->mv[i] = RightShiftWithRoundingSigned(
-        mv.mv[i] * numerator * kDivisionLookup[denominator], 14);
+  } else {
+    for (auto& mv : mvs->mv) {
+      // The next line is equivalent to:
+      // if ((mv & 1) != 0) mv += (mv > 0) ? -1 : 1;
+      mv = (mv - (mv >> 15)) & ~1;
+    }
   }
 }
 
 // 7.10.2.1.
 void SetupGlobalMv(const Tile::Block& block, int index,
                    MotionVector* const mv) {
-  const auto& bp = block.parameters();
+  const BlockParameters& bp = *block.bp;
+  const ObuFrameHeader& frame_header = block.tile.frame_header();
   ReferenceFrameType reference_type = bp.reference_frame[index];
-  const auto& gm = block.tile.frame_header().global_motion[reference_type];
+  const auto& gm = frame_header.global_motion[reference_type];
   GlobalMotionTransformationType global_motion_type =
       (reference_type != kReferenceFrameIntra)
           ? gm.type
           : kNumGlobalMotionTransformationTypes;
   if (reference_type == kReferenceFrameIntra ||
       global_motion_type == kGlobalMotionTransformationTypeIdentity) {
-    mv->mv[MotionVector::kRow] = 0;
-    mv->mv[MotionVector::kColumn] = 0;
+    mv->mv32 = 0;
     return;
   }
   if (global_motion_type == kGlobalMotionTransformationTypeTranslation) {
     for (int i = 0; i < 2; ++i) {
       mv->mv[i] = gm.params[i] >> (kWarpedModelPrecisionBits - 3);
     }
-    LowerMvPrecision(block, mv->mv);
+    LowerMvPrecision(frame_header, mv);
     return;
   }
-  const int x = MultiplyBy4(block.column4x4) +
-                DivideBy2(kBlockWidthPixels[block.size]) - 1;
-  const int y =
-      MultiplyBy4(block.row4x4) + DivideBy2(kBlockHeightPixels[block.size]) - 1;
+  const int x = MultiplyBy4(block.column4x4) + DivideBy2(block.width) - 1;
+  const int y = MultiplyBy4(block.row4x4) + DivideBy2(block.height) - 1;
   const int xc = (gm.params[2] - (1 << kWarpedModelPrecisionBits)) * x +
                  gm.params[3] * y + gm.params[0];
   const int yc = gm.params[4] * x +
                  (gm.params[5] - (1 << kWarpedModelPrecisionBits)) * y +
                  gm.params[1];
-  if (block.tile.frame_header().allow_high_precision_mv) {
+  if (frame_header.allow_high_precision_mv) {
     mv->mv[MotionVector::kRow] =
         RightShiftWithRoundingSigned(yc, kWarpedModelPrecisionBits - 3);
     mv->mv[MotionVector::kColumn] =
@@ -132,7 +96,7 @@
         RightShiftWithRoundingSigned(yc, kWarpedModelPrecisionBits - 2));
     mv->mv[MotionVector::kColumn] = MultiplyBy2(
         RightShiftWithRoundingSigned(xc, kWarpedModelPrecisionBits - 2));
-    LowerMvPrecision(block, mv->mv);
+    LowerMvPrecision(frame_header, mv);
   }
 }
 
@@ -144,255 +108,277 @@
                                               kPredictionModeNewNearestMv);
 
 // 7.10.2.8.
-void SearchStackSingle(const Tile::Block& block, int row, int column, int index,
-                       int weight, MotionVector global_mv_candidate[2],
-                       bool* const found_new_mv, bool* const found_match,
-                       int* const num_mv_found,
-                       CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
-  const auto& bp = block.tile.Parameters(row, column);
-  const PredictionMode candidate_mode = bp.y_mode;
-  const BlockSize candidate_size = bp.size;
+void SearchStack(const Tile::Block& block, const BlockParameters& mv_bp,
+                 int index, int weight, bool* const found_new_mv,
+                 bool* const found_match, int* const num_mv_found) {
+  const BlockParameters& bp = *block.bp;
+  const std::array<GlobalMotion, kNumReferenceFrameTypes>& global_motion =
+      block.tile.frame_header().global_motion;
+  PredictionParameters& prediction_parameters = *bp.prediction_parameters;
   MotionVector candidate_mv;
-  const auto global_motion_type =
-      block.tile.frame_header()
-          .global_motion[block.parameters().reference_frame[0]]
-          .type;
-  if (IsGlobalMvBlock(candidate_mode, global_motion_type, candidate_size)) {
-    candidate_mv = global_mv_candidate[0];
+  // LowerMvPrecision() is not necessary, since the values in
+  // |prediction_parameters.global_mv| and |mv_bp.mv| were generated by it.
+  const auto global_motion_type = global_motion[bp.reference_frame[0]].type;
+  if (IsGlobalMvBlock(mv_bp.is_global_mv_block, global_motion_type)) {
+    candidate_mv = prediction_parameters.global_mv[0];
   } else {
-    candidate_mv = bp.mv[index];
+    candidate_mv = mv_bp.mv.mv[index];
   }
-  LowerMvPrecision(block, candidate_mv.mv);
-  *found_new_mv |= kPredictionModeNewMvMask.Contains(candidate_mode);
+  *found_new_mv |= kPredictionModeNewMvMask.Contains(mv_bp.y_mode);
   *found_match = true;
-  for (int i = 0; i < *num_mv_found; ++i) {
-    if (ref_mv_stack[i].mv[0] == candidate_mv) {
-      ref_mv_stack[i].weight += weight;
-      return;
-    }
+  MotionVector* const ref_mv_stack = prediction_parameters.ref_mv_stack;
+  const auto result = std::find_if(ref_mv_stack, ref_mv_stack + *num_mv_found,
+                                   [&candidate_mv](const MotionVector& ref_mv) {
+                                     return ref_mv == candidate_mv;
+                                   });
+  if (result != ref_mv_stack + *num_mv_found) {
+    prediction_parameters.IncreaseWeight(std::distance(ref_mv_stack, result),
+                                         weight);
+    return;
   }
   if (*num_mv_found >= kMaxRefMvStackSize) return;
-  ref_mv_stack[*num_mv_found] = {{candidate_mv, {}}, weight};
+  ref_mv_stack[*num_mv_found] = candidate_mv;
+  prediction_parameters.SetWeightIndexStackEntry(*num_mv_found, weight);
   ++*num_mv_found;
 }
 
 // 7.10.2.9.
-void SearchStackCompound(
-    const Tile::Block& block, int row, int column, int weight,
-    MotionVector global_mv_candidate[2], bool* const found_new_mv,
-    bool* const found_match, int* const num_mv_found,
-    CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
-  const auto& bp = block.tile.Parameters(row, column);
-  const PredictionMode candidate_mode = bp.y_mode;
-  const BlockSize candidate_size = bp.size;
-  MotionVector candidate_mv[2];
+void CompoundSearchStack(const Tile::Block& block, const BlockParameters& mv_bp,
+                         int weight, bool* const found_new_mv,
+                         bool* const found_match, int* const num_mv_found) {
+  const BlockParameters& bp = *block.bp;
+  const std::array<GlobalMotion, kNumReferenceFrameTypes>& global_motion =
+      block.tile.frame_header().global_motion;
+  PredictionParameters& prediction_parameters = *bp.prediction_parameters;
+  // LowerMvPrecision() is not necessary, since the values in
+  // |prediction_parameters.global_mv| and |mv_bp.mv| were generated by it.
+  CompoundMotionVector candidate_mv = mv_bp.mv;
   for (int i = 0; i < 2; ++i) {
-    const auto global_motion_type =
-        block.tile.frame_header()
-            .global_motion[block.parameters().reference_frame[i]]
-            .type;
-    if (IsGlobalMvBlock(candidate_mode, global_motion_type, candidate_size)) {
-      candidate_mv[i] = global_mv_candidate[i];
-    } else {
-      candidate_mv[i] = bp.mv[i];
+    const auto global_motion_type = global_motion[bp.reference_frame[i]].type;
+    if (IsGlobalMvBlock(mv_bp.is_global_mv_block, global_motion_type)) {
+      candidate_mv.mv[i] = prediction_parameters.global_mv[i];
     }
-    LowerMvPrecision(block, candidate_mv[i].mv);
   }
-  *found_new_mv |= kPredictionModeNewMvMask.Contains(candidate_mode);
+  *found_new_mv |= kPredictionModeNewMvMask.Contains(mv_bp.y_mode);
   *found_match = true;
-  for (int i = 0; i < *num_mv_found; ++i) {
-    if (ref_mv_stack[i].mv[0] == candidate_mv[0] &&
-        ref_mv_stack[i].mv[1] == candidate_mv[1]) {
-      ref_mv_stack[i].weight += weight;
-      return;
-    }
+  CompoundMotionVector* const compound_ref_mv_stack =
+      prediction_parameters.compound_ref_mv_stack;
+  const auto result =
+      std::find_if(compound_ref_mv_stack, compound_ref_mv_stack + *num_mv_found,
+                   [&candidate_mv](const CompoundMotionVector& ref_mv) {
+                     return ref_mv == candidate_mv;
+                   });
+  if (result != compound_ref_mv_stack + *num_mv_found) {
+    prediction_parameters.IncreaseWeight(
+        std::distance(compound_ref_mv_stack, result), weight);
+    return;
   }
   if (*num_mv_found >= kMaxRefMvStackSize) return;
-  ref_mv_stack[*num_mv_found] = {{candidate_mv[0], candidate_mv[1]}, weight};
+  compound_ref_mv_stack[*num_mv_found] = candidate_mv;
+  prediction_parameters.SetWeightIndexStackEntry(*num_mv_found, weight);
   ++*num_mv_found;
 }
 
 // 7.10.2.7.
-void AddReferenceMvCandidate(
-    const Tile::Block& block, int row, int column, bool is_compound, int weight,
-    MotionVector global_mv[2], bool* const found_new_mv,
-    bool* const found_match, int* const num_mv_found,
-    CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
-  const auto& bp = block.tile.Parameters(row, column);
-  if (!bp.is_inter) return;
+void AddReferenceMvCandidate(const Tile::Block& block,
+                             const BlockParameters& mv_bp, bool is_compound,
+                             int weight, bool* const found_new_mv,
+                             bool* const found_match, int* const num_mv_found) {
+  if (!mv_bp.is_inter) return;
+  const BlockParameters& bp = *block.bp;
   if (is_compound) {
-    if (bp.reference_frame[0] == block.parameters().reference_frame[0] &&
-        bp.reference_frame[1] == block.parameters().reference_frame[1]) {
-      SearchStackCompound(block, row, column, weight, global_mv, found_new_mv,
-                          found_match, num_mv_found, ref_mv_stack);
+    if (mv_bp.reference_frame[0] == bp.reference_frame[0] &&
+        mv_bp.reference_frame[1] == bp.reference_frame[1]) {
+      CompoundSearchStack(block, mv_bp, weight, found_new_mv, found_match,
+                          num_mv_found);
     }
     return;
   }
   for (int i = 0; i < 2; ++i) {
-    if (bp.reference_frame[i] == block.parameters().reference_frame[0]) {
-      SearchStackSingle(block, row, column, i, weight, global_mv, found_new_mv,
-                        found_match, num_mv_found, ref_mv_stack);
+    if (mv_bp.reference_frame[i] == bp.reference_frame[0]) {
+      SearchStack(block, mv_bp, i, weight, found_new_mv, found_match,
+                  num_mv_found);
     }
   }
 }
 
+int GetMinimumStep(int block_width_or_height4x4, int delta_row_or_column) {
+  assert(delta_row_or_column < 0);
+  if (block_width_or_height4x4 >= 16) return 4;
+  if (delta_row_or_column < -1) return 2;
+  return 0;
+}
+
 // 7.10.2.2.
-void ScanRow(const Tile::Block& block, int delta_row, bool is_compound,
-             MotionVector global_mv[2], bool* const found_new_mv,
-             bool* const found_match, int* const num_mv_found,
-             CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
-  const int block_width4x4 = kNum4x4BlocksWide[block.size];
-  int delta_column = 0;
-  if (std::abs(delta_row) > 1) {
-    delta_row += block.row4x4 & 1;
-    delta_column = 1 - (block.column4x4 & 1);
-  }
-  const int end_mv_column =
-      block.column4x4 + delta_column +
-      std::min({block_width4x4,
-                block.tile.frame_header().columns4x4 - block.column4x4, 16});
+void ScanRow(const Tile::Block& block, int mv_column, int delta_row,
+             bool is_compound, bool* const found_new_mv,
+             bool* const found_match, int* const num_mv_found) {
   const int mv_row = block.row4x4 + delta_row;
-  for (int step, mv_column = block.column4x4 + delta_column;
-       mv_column < end_mv_column; mv_column += step) {
-    if (!block.tile.IsInside(mv_row, mv_column)) break;
-    step = std::min(
-        block_width4x4,
-        static_cast<int>(
-            kNum4x4BlocksWide[block.tile.Parameters(mv_row, mv_column).size]));
-    if (std::abs(delta_row) > 1) step = std::max(step, 2);
-    if (block_width4x4 >= 16) step = std::max(step, 4);
-    AddReferenceMvCandidate(block, mv_row, mv_column, is_compound,
-                            MultiplyBy2(step), global_mv, found_new_mv,
-                            found_match, num_mv_found, ref_mv_stack);
-  }
+  const Tile& tile = block.tile;
+  if (!tile.IsTopInside(mv_row + 1)) return;
+  const int width4x4 = block.width4x4;
+  const int min_step = GetMinimumStep(width4x4, delta_row);
+  BlockParameters** bps = tile.BlockParametersAddress(mv_row, mv_column);
+  BlockParameters** const end_bps =
+      bps + std::min({static_cast<int>(width4x4),
+                      tile.frame_header().columns4x4 - block.column4x4, 16});
+  do {
+    const BlockParameters& mv_bp = **bps;
+    const int step = std::max(
+        std::min(width4x4, static_cast<int>(kNum4x4BlocksWide[mv_bp.size])),
+        min_step);
+    AddReferenceMvCandidate(block, mv_bp, is_compound, MultiplyBy2(step),
+                            found_new_mv, found_match, num_mv_found);
+    bps += step;
+  } while (bps < end_bps);
 }
 
 // 7.10.2.3.
-void ScanColumn(const Tile::Block& block, int delta_column, bool is_compound,
-                MotionVector global_mv[2], bool* const found_new_mv,
-                bool* const found_match, int* const num_mv_found,
-                CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
-  const int block_height4x4 = kNum4x4BlocksHigh[block.size];
-  int delta_row = 0;
-  if (std::abs(delta_column) > 1) {
-    delta_row = 1 - (block.row4x4 & 1);
-    delta_column += block.column4x4 & 1;
-  }
-  const int end_mv_row =
-      block.row4x4 + delta_row +
-      std::min({block_height4x4,
-                block.tile.frame_header().rows4x4 - block.row4x4, 16});
+void ScanColumn(const Tile::Block& block, int mv_row, int delta_column,
+                bool is_compound, bool* const found_new_mv,
+                bool* const found_match, int* const num_mv_found) {
   const int mv_column = block.column4x4 + delta_column;
-  for (int step, mv_row = block.row4x4 + delta_row; mv_row < end_mv_row;
-       mv_row += step) {
-    if (!block.tile.IsInside(mv_row, mv_column)) break;
-    step = std::min(
-        block_height4x4,
-        static_cast<int>(
-            kNum4x4BlocksHigh[block.tile.Parameters(mv_row, mv_column).size]));
-    if (std::abs(delta_column) > 1) step = std::max(step, 2);
-    if (block_height4x4 >= 16) step = std::max(step, 4);
-    AddReferenceMvCandidate(block, mv_row, mv_column, is_compound,
-                            MultiplyBy2(step), global_mv, found_new_mv,
-                            found_match, num_mv_found, ref_mv_stack);
-  }
+  const Tile& tile = block.tile;
+  if (!tile.IsLeftInside(mv_column + 1)) return;
+  const int height4x4 = block.height4x4;
+  const int min_step = GetMinimumStep(height4x4, delta_column);
+  const ptrdiff_t stride = tile.BlockParametersStride();
+  BlockParameters** bps = tile.BlockParametersAddress(mv_row, mv_column);
+  BlockParameters** const end_bps =
+      bps + stride * std::min({static_cast<int>(height4x4),
+                               tile.frame_header().rows4x4 - block.row4x4, 16});
+  do {
+    const BlockParameters& mv_bp = **bps;
+    const int step = std::max(
+        std::min(height4x4, static_cast<int>(kNum4x4BlocksHigh[mv_bp.size])),
+        min_step);
+    AddReferenceMvCandidate(block, mv_bp, is_compound, MultiplyBy2(step),
+                            found_new_mv, found_match, num_mv_found);
+    bps += step * stride;
+  } while (bps < end_bps);
 }
 
 // 7.10.2.4.
 void ScanPoint(const Tile::Block& block, int delta_row, int delta_column,
-               bool is_compound, MotionVector global_mv[2],
-               bool* const found_new_mv, bool* const found_match,
-               int* const num_mv_found,
-               CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
+               bool is_compound, bool* const found_new_mv,
+               bool* const found_match, int* const num_mv_found) {
   const int mv_row = block.row4x4 + delta_row;
   const int mv_column = block.column4x4 + delta_column;
-  if (!block.tile.IsInside(mv_row, mv_column) ||
-      !block.tile.HasParameters(mv_row, mv_column) ||
-      block.tile.Parameters(mv_row, mv_column).reference_frame[0] ==
-          kReferenceFrameNone) {
+  const Tile& tile = block.tile;
+  if (!tile.IsInside(mv_row, mv_column) ||
+      !tile.HasParameters(mv_row, mv_column)) {
     return;
   }
-  AddReferenceMvCandidate(block, mv_row, mv_column, is_compound, 4, global_mv,
-                          found_new_mv, found_match, num_mv_found,
-                          ref_mv_stack);
+  const BlockParameters& mv_bp = tile.Parameters(mv_row, mv_column);
+  if (mv_bp.reference_frame[0] == kReferenceFrameNone) return;
+  AddReferenceMvCandidate(block, mv_bp, is_compound, 4, found_new_mv,
+                          found_match, num_mv_found);
 }
 
 // 7.10.2.6.
-//
-// The |zero_mv_context| output parameter may be null. If |zero_mv_context| is
-// not null, the function may set |*zero_mv_context|.
 void AddTemporalReferenceMvCandidate(
-    const Tile::Block& block, int delta_row, int delta_column, bool is_compound,
-    MotionVector global_mv[2],
-    const Array2D<TemporalMotionVector>& motion_field_mv,
+    const ObuFrameHeader& frame_header, const int reference_offsets[2],
+    const MotionVector* const temporal_mvs,
+    const int8_t* const temporal_reference_offsets, int count, bool is_compound,
     int* const zero_mv_context, int* const num_mv_found,
-    CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
-  const int mv_row = (block.row4x4 + delta_row) | 1;
-  const int mv_column = (block.column4x4 + delta_column) | 1;
-  if (!block.tile.IsInside(mv_row, mv_column)) return;
-  const int x8 = mv_column >> 1;
-  const int y8 = mv_row >> 1;
-  if (zero_mv_context != nullptr && delta_row == 0 && delta_column == 0) {
-    *zero_mv_context = 1;
-  }
-  const auto& bp = block.parameters();
-  const TemporalMotionVector& temporal_mv = motion_field_mv[y8][x8];
-  if (temporal_mv.mv.mv[0] == kInvalidMvValue) return;
+    PredictionParameters* const prediction_parameters) {
+  const int mv_projection_function_index =
+      frame_header.allow_high_precision_mv ? 2 : frame_header.force_integer_mv;
+  const MotionVector* const global_mv = prediction_parameters->global_mv;
   if (is_compound) {
-    MotionVector candidate_mv[2];
-    for (int i = 0; i < 2; ++i) {
-      const int reference_offset = GetRelativeDistance(
-          block.tile.frame_header().order_hint,
-          block.tile.current_frame().order_hint(bp.reference_frame[i]),
-          block.tile.sequence_header().enable_order_hint,
-          block.tile.sequence_header().order_hint_bits);
-      GetMvProjection(temporal_mv.mv, reference_offset,
-                      temporal_mv.reference_offset, &candidate_mv[i]);
-      LowerMvPrecision(block, candidate_mv[i].mv);
+    CompoundMotionVector candidate_mvs[kMaxTemporalMvCandidatesWithPadding];
+    const dsp::Dsp& dsp = *dsp::GetDspTable(8);
+    dsp.mv_projection_compound[mv_projection_function_index](
+        temporal_mvs, temporal_reference_offsets, reference_offsets, count,
+        candidate_mvs);
+    if (*zero_mv_context == -1) {
+      int max_difference =
+          std::max(std::abs(candidate_mvs[0].mv[0].mv[0] - global_mv[0].mv[0]),
+                   std::abs(candidate_mvs[0].mv[0].mv[1] - global_mv[0].mv[1]));
+      max_difference =
+          std::max(max_difference,
+                   std::abs(candidate_mvs[0].mv[1].mv[0] - global_mv[1].mv[0]));
+      max_difference =
+          std::max(max_difference,
+                   std::abs(candidate_mvs[0].mv[1].mv[1] - global_mv[1].mv[1]));
+      *zero_mv_context = static_cast<int>(max_difference >= 16);
     }
-    if (zero_mv_context != nullptr && delta_row == 0 && delta_column == 0) {
-      *zero_mv_context = static_cast<int>(
-          std::abs(candidate_mv[0].mv[0] - global_mv[0].mv[0]) >= 16 ||
-          std::abs(candidate_mv[0].mv[1] - global_mv[0].mv[1]) >= 16 ||
-          std::abs(candidate_mv[1].mv[0] - global_mv[1].mv[0]) >= 16 ||
-          std::abs(candidate_mv[1].mv[1] - global_mv[1].mv[1]) >= 16);
-    }
-    for (int i = 0; i < *num_mv_found; ++i) {
-      if (ref_mv_stack[i].mv[0] == candidate_mv[0] &&
-          ref_mv_stack[i].mv[1] == candidate_mv[1]) {
-        ref_mv_stack[i].weight += 2;
-        return;
+    CompoundMotionVector* const compound_ref_mv_stack =
+        prediction_parameters->compound_ref_mv_stack;
+    int index = 0;
+    do {
+      const CompoundMotionVector& candidate_mv = candidate_mvs[index];
+      const auto result = std::find_if(
+          compound_ref_mv_stack, compound_ref_mv_stack + *num_mv_found,
+          [&candidate_mv](const CompoundMotionVector& ref_mv) {
+            return ref_mv == candidate_mv;
+          });
+      if (result != compound_ref_mv_stack + *num_mv_found) {
+        prediction_parameters->IncreaseWeight(
+            std::distance(compound_ref_mv_stack, result), 2);
+        continue;
       }
+      if (*num_mv_found >= kMaxRefMvStackSize) continue;
+      compound_ref_mv_stack[*num_mv_found] = candidate_mv;
+      prediction_parameters->SetWeightIndexStackEntry(*num_mv_found, 2);
+      ++*num_mv_found;
+    } while (++index < count);
+    return;
+  }
+  MotionVector* const ref_mv_stack = prediction_parameters->ref_mv_stack;
+  if (reference_offsets[0] == 0) {
+    if (*zero_mv_context == -1) {
+      const int max_difference =
+          std::max(std::abs(global_mv[0].mv[0]), std::abs(global_mv[0].mv[1]));
+      *zero_mv_context = static_cast<int>(max_difference >= 16);
+    }
+    const MotionVector candidate_mv = {};
+    const auto result =
+        std::find_if(ref_mv_stack, ref_mv_stack + *num_mv_found,
+                     [&candidate_mv](const MotionVector& ref_mv) {
+                       return ref_mv == candidate_mv;
+                     });
+    if (result != ref_mv_stack + *num_mv_found) {
+      prediction_parameters->IncreaseWeight(std::distance(ref_mv_stack, result),
+                                            2 * count);
+      return;
     }
     if (*num_mv_found >= kMaxRefMvStackSize) return;
-    ref_mv_stack[*num_mv_found] = {{candidate_mv[0], candidate_mv[1]}, 2};
+    ref_mv_stack[*num_mv_found] = candidate_mv;
+    prediction_parameters->SetWeightIndexStackEntry(*num_mv_found, 2 * count);
     ++*num_mv_found;
     return;
   }
-  assert(!is_compound);
-  MotionVector candidate_mv;
-  const int reference_offset = GetRelativeDistance(
-      block.tile.frame_header().order_hint,
-      block.tile.current_frame().order_hint(bp.reference_frame[0]),
-      block.tile.sequence_header().enable_order_hint,
-      block.tile.sequence_header().order_hint_bits);
-  GetMvProjection(temporal_mv.mv, reference_offset,
-                  temporal_mv.reference_offset, &candidate_mv);
-  LowerMvPrecision(block, candidate_mv.mv);
-  if (zero_mv_context != nullptr && delta_row == 0 && delta_column == 0) {
-    *zero_mv_context = static_cast<int>(
-        std::abs(candidate_mv.mv[0] - global_mv[0].mv[0]) >= 16 ||
-        std::abs(candidate_mv.mv[1] - global_mv[0].mv[1]) >= 16);
+  alignas(kMaxAlignment)
+      MotionVector candidate_mvs[kMaxTemporalMvCandidatesWithPadding];
+  const dsp::Dsp& dsp = *dsp::GetDspTable(8);
+  dsp.mv_projection_single[mv_projection_function_index](
+      temporal_mvs, temporal_reference_offsets, reference_offsets[0], count,
+      candidate_mvs);
+  if (*zero_mv_context == -1) {
+    const int max_difference =
+        std::max(std::abs(candidate_mvs[0].mv[0] - global_mv[0].mv[0]),
+                 std::abs(candidate_mvs[0].mv[1] - global_mv[0].mv[1]));
+    *zero_mv_context = static_cast<int>(max_difference >= 16);
   }
-  for (int i = 0; i < *num_mv_found; ++i) {
-    if (ref_mv_stack[i].mv[0] == candidate_mv) {
-      ref_mv_stack[i].weight += 2;
-      return;
+  int index = 0;
+  do {
+    const MotionVector& candidate_mv = candidate_mvs[index];
+    const auto result =
+        std::find_if(ref_mv_stack, ref_mv_stack + *num_mv_found,
+                     [&candidate_mv](const MotionVector& ref_mv) {
+                       return ref_mv == candidate_mv;
+                     });
+    if (result != ref_mv_stack + *num_mv_found) {
+      prediction_parameters->IncreaseWeight(std::distance(ref_mv_stack, result),
+                                            2);
+      continue;
     }
-  }
-  if (*num_mv_found >= kMaxRefMvStackSize) return;
-  ref_mv_stack[*num_mv_found] = {{candidate_mv, {}}, 2};
-  ++*num_mv_found;
+    if (*num_mv_found >= kMaxRefMvStackSize) continue;
+    ref_mv_stack[*num_mv_found] = candidate_mv;
+    prediction_parameters->SetWeightIndexStackEntry(*num_mv_found, 2);
+    ++*num_mv_found;
+  } while (++index < count);
 }
 
 // Part of 7.10.2.5.
@@ -400,63 +386,145 @@
                                int delta_column) {
   const int row = (block.row4x4 & 15) + delta_row;
   const int column = (block.column4x4 & 15) + delta_column;
-  return row >= 0 && row < 16 && column >= 0 && column < 16;
+  // |block.height4x4| is at least 2 for all elements in |kTemporalScanMask|.
+  // So |row| are all non-negative.
+  assert(row >= 0);
+  return row < 16 && column >= 0 && column < 16;
 }
 
+constexpr BitMaskSet kTemporalScanMask(kBlock8x8, kBlock8x16, kBlock8x32,
+                                       kBlock16x8, kBlock16x16, kBlock16x32,
+                                       kBlock32x8, kBlock32x16, kBlock32x32);
+
 // 7.10.2.5.
 //
 // The |zero_mv_context| output parameter may be null. If |zero_mv_context| is
 // not null, the function may set |*zero_mv_context|.
 void TemporalScan(const Tile::Block& block, bool is_compound,
-                  MotionVector global_mv[2],
-                  const Array2D<TemporalMotionVector>& motion_field_mv,
-                  int* const zero_mv_context, int* const num_mv_found,
-                  CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
-  const int block_width4x4 = kNum4x4BlocksWide[block.size];
-  const int block_height4x4 = kNum4x4BlocksHigh[block.size];
-  const int step_w = (block_width4x4 >= 16) ? 4 : 2;
-  const int step_h = (block_height4x4 >= 16) ? 4 : 2;
-  for (int row = 0; row < std::min(block_height4x4, 16); row += step_h) {
-    for (int column = 0; column < std::min(block_width4x4, 16);
-         column += step_w) {
-      AddTemporalReferenceMvCandidate(
-          block, row, column, is_compound, global_mv, motion_field_mv,
-          zero_mv_context, num_mv_found, ref_mv_stack);
+                  int* const zero_mv_context, int* const num_mv_found) {
+  const int step_w = (block.width4x4 >= 16) ? 4 : 2;
+  const int step_h = (block.height4x4 >= 16) ? 4 : 2;
+  const int row_start = block.row4x4 | 1;
+  const int column_start = block.column4x4 | 1;
+  const int row_end =
+      row_start + std::min(static_cast<int>(block.height4x4), 16);
+  const int column_end =
+      column_start + std::min(static_cast<int>(block.width4x4), 16);
+  const Tile& tile = block.tile;
+  const TemporalMotionField& motion_field = tile.motion_field();
+  const int stride = motion_field.mv.columns();
+  const MotionVector* motion_field_mv = motion_field.mv[0];
+  const int8_t* motion_field_reference_offset =
+      motion_field.reference_offset[0];
+  alignas(kMaxAlignment)
+      MotionVector temporal_mvs[kMaxTemporalMvCandidatesWithPadding];
+  int8_t temporal_reference_offsets[kMaxTemporalMvCandidatesWithPadding];
+  int count = 0;
+  int offset = stride * (row_start >> 1);
+  int mv_row = row_start;
+  do {
+    int mv_column = column_start;
+    do {
+      // Both horizontal and vertical offsets are positive. Only bottom and
+      // right boundaries need to be checked.
+      if (tile.IsBottomRightInside(mv_row, mv_column)) {
+        const int x8 = mv_column >> 1;
+        const MotionVector temporal_mv = motion_field_mv[offset + x8];
+        if (temporal_mv.mv[0] == kInvalidMvValue) {
+          if (mv_row == row_start && mv_column == column_start) {
+            *zero_mv_context = 1;
+          }
+        } else {
+          temporal_mvs[count] = temporal_mv;
+          temporal_reference_offsets[count++] =
+              motion_field_reference_offset[offset + x8];
+        }
+      }
+      mv_column += step_w;
+    } while (mv_column < column_end);
+    offset += stride * step_h >> 1;
+    mv_row += step_h;
+  } while (mv_row < row_end);
+  if (kTemporalScanMask.Contains(block.size)) {
+    const int temporal_sample_positions[3][2] = {
+        {block.height4x4, -2},
+        {block.height4x4, block.width4x4},
+        {block.height4x4 - 2, block.width4x4}};
+    // Getting the address of an element in Array2D is slow. Precalculate the
+    // offsets.
+    int temporal_sample_offsets[3];
+    temporal_sample_offsets[0] = stride * ((row_start + block.height4x4) >> 1) +
+                                 ((column_start - 2) >> 1);
+    temporal_sample_offsets[1] =
+        temporal_sample_offsets[0] + ((block.width4x4 + 2) >> 1);
+    temporal_sample_offsets[2] = temporal_sample_offsets[1] - stride;
+    for (int i = 0; i < 3; i++) {
+      const int row = temporal_sample_positions[i][0];
+      const int column = temporal_sample_positions[i][1];
+      if (!IsWithinTheSame64x64Block(block, row, column)) continue;
+      const int mv_row = row_start + row;
+      const int mv_column = column_start + column;
+      // IsWithinTheSame64x64Block() guarantees the reference block is inside
+      // the top and left boundary.
+      if (!tile.IsBottomRightInside(mv_row, mv_column)) continue;
+      const MotionVector temporal_mv =
+          motion_field_mv[temporal_sample_offsets[i]];
+      if (temporal_mv.mv[0] != kInvalidMvValue) {
+        temporal_mvs[count] = temporal_mv;
+        temporal_reference_offsets[count++] =
+            motion_field_reference_offset[temporal_sample_offsets[i]];
+      }
     }
   }
-  if (block_height4x4 >= kNum4x4BlocksHigh[kBlock8x8] &&
-      block_height4x4 < kNum4x4BlocksHigh[kBlock64x64] &&
-      block_width4x4 >= kNum4x4BlocksWide[kBlock8x8] &&
-      block_width4x4 < kNum4x4BlocksWide[kBlock64x64]) {
-    const int temporal_sample_positions[3][2] = {
-        {block_height4x4, -2},
-        {block_height4x4, block_width4x4},
-        {block_height4x4 - 2, block_width4x4}};
-    for (const auto& temporal_sample_position : temporal_sample_positions) {
-      const int row = temporal_sample_position[0];
-      const int column = temporal_sample_position[1];
-      if (!IsWithinTheSame64x64Block(block, row, column)) continue;
-      AddTemporalReferenceMvCandidate(
-          block, row, column, is_compound, global_mv, motion_field_mv,
-          zero_mv_context, num_mv_found, ref_mv_stack);
+  if (count != 0) {
+    BlockParameters* const bp = block.bp;
+    int reference_offsets[2];
+    const int offset_0 = tile.current_frame()
+                             .reference_info()
+                             ->relative_distance_to[bp->reference_frame[0]];
+    reference_offsets[0] =
+        Clip3(offset_0, -kMaxFrameDistance, kMaxFrameDistance);
+    if (is_compound) {
+      const int offset_1 = tile.current_frame()
+                               .reference_info()
+                               ->relative_distance_to[bp->reference_frame[1]];
+      reference_offsets[1] =
+          Clip3(offset_1, -kMaxFrameDistance, kMaxFrameDistance);
+      // Pad so that SIMD implementations won't read uninitialized memory.
+      if ((count & 1) != 0) {
+        temporal_mvs[count].mv32 = 0;
+        temporal_reference_offsets[count] = 0;
+      }
+    } else {
+      // Pad so that SIMD implementations won't read uninitialized memory.
+      for (int i = count; i < ((count + 3) & ~3); ++i) {
+        temporal_mvs[i].mv32 = 0;
+        temporal_reference_offsets[i] = 0;
+      }
     }
+    AddTemporalReferenceMvCandidate(
+        tile.frame_header(), reference_offsets, temporal_mvs,
+        temporal_reference_offsets, count, is_compound, zero_mv_context,
+        num_mv_found, &(*bp->prediction_parameters));
   }
 }
 
 // Part of 7.10.2.13.
-void AddExtraCompoundMvCandidate(
-    const Tile::Block& block, int mv_row, int mv_column,
-    const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias,
-    int* const ref_id_count, MotionVector ref_id[2][2],
-    int* const ref_diff_count, MotionVector ref_diff[2][2]) {
+void AddExtraCompoundMvCandidate(const Tile::Block& block, int mv_row,
+                                 int mv_column, int* const ref_id_count,
+                                 MotionVector ref_id[2][2],
+                                 int* const ref_diff_count,
+                                 MotionVector ref_diff[2][2]) {
   const auto& bp = block.tile.Parameters(mv_row, mv_column);
+  const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias =
+      block.tile.reference_frame_sign_bias();
   for (int i = 0; i < 2; ++i) {
     const ReferenceFrameType candidate_reference_frame = bp.reference_frame[i];
     if (candidate_reference_frame <= kReferenceFrameIntra) continue;
     for (int j = 0; j < 2; ++j) {
-      MotionVector candidate_mv = bp.mv[i];
+      MotionVector candidate_mv = bp.mv.mv[i];
       const ReferenceFrameType block_reference_frame =
-          block.parameters().reference_frame[j];
+          block.bp->reference_frame[j];
       if (candidate_reference_frame == block_reference_frame &&
           ref_id_count[j] < 2) {
         ref_id[j][ref_id_count[j]] = candidate_mv;
@@ -475,142 +543,153 @@
 }
 
 // Part of 7.10.2.13.
-void AddExtraSingleMvCandidate(
-    const Tile::Block& block, int mv_row, int mv_column,
-    const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias,
-    int* const num_mv_found,
-    CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
+void AddExtraSingleMvCandidate(const Tile::Block& block, int mv_row,
+                               int mv_column, int* const num_mv_found) {
   const auto& bp = block.tile.Parameters(mv_row, mv_column);
-  const ReferenceFrameType block_reference_frame =
-      block.parameters().reference_frame[0];
+  const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias =
+      block.tile.reference_frame_sign_bias();
+  const ReferenceFrameType block_reference_frame = block.bp->reference_frame[0];
+  PredictionParameters& prediction_parameters =
+      *block.bp->prediction_parameters;
+  MotionVector* const ref_mv_stack = prediction_parameters.ref_mv_stack;
   for (int i = 0; i < 2; ++i) {
     const ReferenceFrameType candidate_reference_frame = bp.reference_frame[i];
     if (candidate_reference_frame <= kReferenceFrameIntra) continue;
-    MotionVector candidate_mv = bp.mv[i];
+    MotionVector candidate_mv = bp.mv.mv[i];
     if (reference_frame_sign_bias[candidate_reference_frame] !=
         reference_frame_sign_bias[block_reference_frame]) {
       candidate_mv.mv[0] *= -1;
       candidate_mv.mv[1] *= -1;
     }
-    int j = 0;
-    for (; j < *num_mv_found; ++j) {
-      if (candidate_mv == ref_mv_stack[j].mv[0]) {
-        break;
-      }
+    assert(*num_mv_found <= 2);
+    if ((*num_mv_found != 0 && ref_mv_stack[0] == candidate_mv) ||
+        (*num_mv_found == 2 && ref_mv_stack[1] == candidate_mv)) {
+      continue;
     }
-    if (j == *num_mv_found) {
-      ref_mv_stack[*num_mv_found] = {{candidate_mv, {}}, 2};
-      ++*num_mv_found;
-    }
+    ref_mv_stack[*num_mv_found] = candidate_mv;
+    prediction_parameters.SetWeightIndexStackEntry(*num_mv_found, 0);
+    ++*num_mv_found;
   }
 }
 
 // 7.10.2.12.
-void ExtraSearch(
-    const Tile::Block& block, bool is_compound, MotionVector global_mv[2],
-    const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias,
-    int* const num_mv_found,
-    CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
-  const int num4x4 =
-      std::min({static_cast<int>(kNum4x4BlocksWide[block.size]),
-                block.tile.frame_header().columns4x4 - block.column4x4,
-                static_cast<int>(kNum4x4BlocksHigh[block.size]),
-                block.tile.frame_header().rows4x4 - block.row4x4, 16});
+void ExtraSearch(const Tile::Block& block, bool is_compound,
+                 int* const num_mv_found) {
+  const Tile& tile = block.tile;
+  const int num4x4 = std::min({static_cast<int>(block.width4x4),
+                               tile.frame_header().columns4x4 - block.column4x4,
+                               static_cast<int>(block.height4x4),
+                               tile.frame_header().rows4x4 - block.row4x4, 16});
   int ref_id_count[2] = {};
   MotionVector ref_id[2][2] = {};
   int ref_diff_count[2] = {};
   MotionVector ref_diff[2][2] = {};
+  PredictionParameters& prediction_parameters =
+      *block.bp->prediction_parameters;
   for (int pass = 0; pass < 2 && *num_mv_found < 2; ++pass) {
     for (int i = 0; i < num4x4;) {
       const int mv_row = block.row4x4 + ((pass == 0) ? -1 : i);
       const int mv_column = block.column4x4 + ((pass == 0) ? i : -1);
-      if (!block.tile.IsInside(mv_row, mv_column)) break;
+      if (!tile.IsTopLeftInside(mv_row + 1, mv_column + 1)) break;
       if (is_compound) {
-        AddExtraCompoundMvCandidate(block, mv_row, mv_column,
-                                    reference_frame_sign_bias, ref_id_count,
+        AddExtraCompoundMvCandidate(block, mv_row, mv_column, ref_id_count,
                                     ref_id, ref_diff_count, ref_diff);
       } else {
-        AddExtraSingleMvCandidate(block, mv_row, mv_column,
-                                  reference_frame_sign_bias, num_mv_found,
-                                  ref_mv_stack);
+        AddExtraSingleMvCandidate(block, mv_row, mv_column, num_mv_found);
         if (*num_mv_found >= 2) break;
       }
-      const auto& bp = block.tile.Parameters(mv_row, mv_column);
+      const auto& bp = tile.Parameters(mv_row, mv_column);
       i +=
           (pass == 0) ? kNum4x4BlocksWide[bp.size] : kNum4x4BlocksHigh[bp.size];
     }
   }
   if (is_compound) {
     // Merge compound mode extra search into mv stack.
-    MotionVector combined_mvs[2][2] = {};
+    CompoundMotionVector* const compound_ref_mv_stack =
+        prediction_parameters.compound_ref_mv_stack;
+    CompoundMotionVector combined_mvs[2] = {};
     for (int i = 0; i < 2; ++i) {
       int count = 0;
       assert(ref_id_count[i] <= 2);
       for (int j = 0; j < ref_id_count[i]; ++j, ++count) {
-        combined_mvs[count][i] = ref_id[i][j];
+        combined_mvs[count].mv[i] = ref_id[i][j];
       }
       for (int j = 0; j < ref_diff_count[i] && count < 2; ++j, ++count) {
-        combined_mvs[count][i] = ref_diff[i][j];
+        combined_mvs[count].mv[i] = ref_diff[i][j];
       }
       for (; count < 2; ++count) {
-        combined_mvs[count][i] = global_mv[i];
+        combined_mvs[count].mv[i] = prediction_parameters.global_mv[i];
       }
     }
     if (*num_mv_found == 1) {
-      if (combined_mvs[0][0] == ref_mv_stack[0].mv[0] &&
-          combined_mvs[0][1] == ref_mv_stack[0].mv[1]) {
-        ref_mv_stack[1] = {{combined_mvs[1][0], combined_mvs[1][1]}, 2};
+      if (combined_mvs[0] == compound_ref_mv_stack[0]) {
+        compound_ref_mv_stack[1] = combined_mvs[1];
       } else {
-        ref_mv_stack[1] = {{combined_mvs[0][0], combined_mvs[0][1]}, 2};
+        compound_ref_mv_stack[1] = combined_mvs[0];
       }
-      ++*num_mv_found;
+      prediction_parameters.SetWeightIndexStackEntry(1, 0);
     } else {
       assert(*num_mv_found == 0);
-      *num_mv_found = 2;
       for (int i = 0; i < 2; ++i) {
-        ref_mv_stack[i] = {{combined_mvs[i][0], combined_mvs[i][1]}, 2};
+        compound_ref_mv_stack[i] = combined_mvs[i];
+        prediction_parameters.SetWeightIndexStackEntry(i, 0);
       }
     }
+    *num_mv_found = 2;
   } else {
     // single prediction mode
+    MotionVector* const ref_mv_stack = prediction_parameters.ref_mv_stack;
     for (int i = *num_mv_found; i < 2; ++i) {
-      ref_mv_stack[i].mv[0] = global_mv[0];
+      ref_mv_stack[i] = prediction_parameters.global_mv[0];
+      prediction_parameters.SetWeightIndexStackEntry(i, 0);
     }
   }
 }
 
-// 7.10.2.14 (part 1).
-void ClampMotionVectors(
-    const Tile::Block& block, bool is_compound, int num_mv_found,
-    CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize]) {
-  const int block_height4x4 = kNum4x4BlocksHigh[block.size];
-  const int block_width4x4 = kNum4x4BlocksWide[block.size];
-  const int row_border = kMvBorder + MultiplyBy32(block_height4x4);
-  const int column_border = kMvBorder + MultiplyBy32(block_width4x4);
-  for (int i = 0; i < num_mv_found; ++i) {
-    for (int mv_index = 0; mv_index < 1 + static_cast<int>(is_compound);
-         ++mv_index) {
-      // Clamp row (5.11.53).
-      const int macroblocks_to_top_edge = -MultiplyBy32(block.row4x4);
-      const int macroblocks_to_bottom_edge = MultiplyBy32(
-          block.tile.frame_header().rows4x4 - block_height4x4 - block.row4x4);
-      ref_mv_stack[i].mv[mv_index].mv[MotionVector::kRow] =
-          Clip3(ref_mv_stack[i].mv[mv_index].mv[MotionVector::kRow],
-                macroblocks_to_top_edge - row_border,
-                macroblocks_to_bottom_edge + row_border);
-      // Clamp column (5.11.54).
-      const int macroblocks_to_left_edge = -MultiplyBy32(block.column4x4);
-      const int macroblocks_to_right_edge =
-          MultiplyBy32(block.tile.frame_header().columns4x4 - block_width4x4 -
-                       block.column4x4);
-      ref_mv_stack[i].mv[mv_index].mv[MotionVector::kColumn] =
-          Clip3(ref_mv_stack[i].mv[mv_index].mv[MotionVector::kColumn],
-                macroblocks_to_left_edge - column_border,
-                macroblocks_to_right_edge + column_border);
-    }
+void DescendingOrderTwo(int* const a, int* const b) {
+  if (*a < *b) {
+    std::swap(*a, *b);
   }
 }
 
+// Comparator used for sorting candidate motion vectors in descending order of
+// their weights (as specified in 7.10.2.11).
+bool CompareCandidateMotionVectors(const int16_t& lhs, const int16_t& rhs) {
+  return lhs > rhs;
+}
+
+void SortWeightIndexStack(const int size, const int sort_to_n,
+                          int16_t* const weight_index_stack) {
+  if (size <= 1) return;
+  if (size <= 3) {
+    // Specialize small sort sizes to speed up.
+    int weight_index_0 = weight_index_stack[0];
+    int weight_index_1 = weight_index_stack[1];
+    DescendingOrderTwo(&weight_index_0, &weight_index_1);
+    if (size == 3) {
+      int weight_index_2 = weight_index_stack[2];
+      DescendingOrderTwo(&weight_index_1, &weight_index_2);
+      DescendingOrderTwo(&weight_index_0, &weight_index_1);
+      weight_index_stack[2] = weight_index_2;
+    }
+    weight_index_stack[0] = weight_index_0;
+    weight_index_stack[1] = weight_index_1;
+    return;
+  }
+  if (sort_to_n == 1) {
+    // std::max_element() is not efficient. Find the max element in a loop.
+    int16_t max_element = weight_index_stack[0];
+    int i = 1;
+    do {
+      max_element = std::max(max_element, weight_index_stack[i]);
+    } while (++i < size);
+    weight_index_stack[0] = max_element;
+    return;
+  }
+  std::partial_sort(&weight_index_stack[0], &weight_index_stack[sort_to_n],
+                    &weight_index_stack[size], CompareCandidateMotionVectors);
+}
+
 // 7.10.2.14 (part 2).
 void ComputeContexts(bool found_new_mv, int nearest_matches, int total_matches,
                      int* new_mv_context, int* reference_mv_context) {
@@ -637,30 +716,30 @@
   if (*num_samples_scanned >= kMaxLeastSquaresSamples) return;
   const int mv_row = block.row4x4 + delta_row;
   const int mv_column = block.column4x4 + delta_column;
-  if (!block.tile.IsInside(mv_row, mv_column) ||
-      !block.tile.HasParameters(mv_row, mv_column)) {
+  const Tile& tile = block.tile;
+  if (!tile.IsInside(mv_row, mv_column) ||
+      !tile.HasParameters(mv_row, mv_column)) {
     return;
   }
-  const BlockParameters& bp = block.tile.Parameters(mv_row, mv_column);
-  if (bp.reference_frame[0] != block.parameters().reference_frame[0] ||
-      bp.reference_frame[1] != kReferenceFrameNone) {
+  const BlockParameters& bp = *block.bp;
+  const BlockParameters& mv_bp = tile.Parameters(mv_row, mv_column);
+  if (mv_bp.reference_frame[0] != bp.reference_frame[0] ||
+      mv_bp.reference_frame[1] != kReferenceFrameNone) {
     return;
   }
   ++*num_samples_scanned;
-  const int candidate_height4x4 = kNum4x4BlocksHigh[bp.size];
+  const int candidate_height4x4 = kNum4x4BlocksHigh[mv_bp.size];
   const int candidate_row = mv_row & ~(candidate_height4x4 - 1);
-  const int candidate_width4x4 = kNum4x4BlocksWide[bp.size];
+  const int candidate_width4x4 = kNum4x4BlocksWide[mv_bp.size];
   const int candidate_column = mv_column & ~(candidate_width4x4 - 1);
   const BlockParameters& candidate_bp =
-      block.tile.Parameters(candidate_row, candidate_column);
-  const int threshold = Clip3(
-      std::max(kBlockWidthPixels[block.size], kBlockHeightPixels[block.size]),
-      16, 112);
+      tile.Parameters(candidate_row, candidate_column);
   const int mv_diff_row =
-      std::abs(candidate_bp.mv[0].mv[0] - block.parameters().mv[0].mv[0]);
+      std::abs(candidate_bp.mv.mv[0].mv[0] - bp.mv.mv[0].mv[0]);
   const int mv_diff_column =
-      std::abs(candidate_bp.mv[0].mv[1] - block.parameters().mv[0].mv[1]);
-  const bool is_valid = mv_diff_row + mv_diff_column <= threshold;
+      std::abs(candidate_bp.mv.mv[0].mv[1] - bp.mv.mv[0].mv[1]);
+  const bool is_valid =
+      mv_diff_row + mv_diff_column <= kWarpValidThreshold[block.size];
   if (!is_valid && *num_samples_scanned > 1) {
     return;
   }
@@ -671,243 +750,164 @@
   candidates[*num_warp_samples][0] = MultiplyBy8(mid_y);
   candidates[*num_warp_samples][1] = MultiplyBy8(mid_x);
   candidates[*num_warp_samples][2] =
-      MultiplyBy8(mid_y) + candidate_bp.mv[0].mv[0];
+      MultiplyBy8(mid_y) + candidate_bp.mv.mv[0].mv[0];
   candidates[*num_warp_samples][3] =
-      MultiplyBy8(mid_x) + candidate_bp.mv[0].mv[1];
+      MultiplyBy8(mid_x) + candidate_bp.mv.mv[0].mv[1];
   if (is_valid) ++*num_warp_samples;
 }
 
-// Comparator used for sorting candidate motion vectors in descending order of
-// their weights (as specified in 7.10.2.11).
-bool CompareCandidateMotionVectors(const CandidateMotionVector& lhs,
-                                   const CandidateMotionVector& rhs) {
-  return lhs.weight > rhs.weight;
-}
-
-// Part of 7.9.4.
-bool Project(int value, int delta, int dst_sign, int max_value, int max_offset,
-             int* const projected_value) {
-  const int base_value = value & ~7;
-  const int offset = (delta >= 0) ? DivideBy64(delta) : -DivideBy64(-delta);
-  value += ApplySign(offset, dst_sign);
-  if (value < 0 || value >= max_value || value < base_value - max_offset ||
-      value >= base_value + 8 + max_offset) {
-    return false;
-  }
-  *projected_value = value;
-  return true;
-}
-
-// 7.9.4.
-bool GetBlockPosition(int x8, int y8, int dst_sign, int rows4x4, int columns4x4,
-                      const MotionVector& projection_mv, int* const position_y8,
-                      int* const position_x8) {
-  return Project(y8, projection_mv.mv[0], dst_sign, DivideBy2(rows4x4),
-                 kProjectionMvMaxVerticalOffset, position_y8) &&
-         Project(x8, projection_mv.mv[1], dst_sign, DivideBy2(columns4x4),
-                 kProjectionMvMaxHorizontalOffset, position_x8);
-}
-
 // 7.9.2.
+// In the spec, |dst_sign| is either 1 or -1. Here we set |dst_sign| to either 0
+// or -1 so that it can be XORed and subtracted directly in ApplySign() and
+// corresponding SIMD implementations.
 bool MotionFieldProjection(
-    ReferenceFrameType source, int dst_sign,
-    const ObuSequenceHeader& sequence_header,
-    const ObuFrameHeader& frame_header, const RefCountedBuffer& current_frame,
+    const ObuFrameHeader& frame_header,
     const std::array<RefCountedBufferPtr, kNumReferenceFrameTypes>&
         reference_frames,
-    Array2D<TemporalMotionVector>* const motion_field_mv, int y8_start,
-    int y8_end, int x8_start, int x8_end) {
+    ReferenceFrameType source, int reference_to_current_with_sign, int dst_sign,
+    int y8_start, int y8_end, int x8_start, int x8_end,
+    TemporalMotionField* const motion_field) {
   const int source_index =
       frame_header.reference_frame_index[source - kReferenceFrameLast];
   auto* const source_frame = reference_frames[source_index].get();
   assert(source_frame != nullptr);
+  assert(dst_sign == 0 || dst_sign == -1);
   if (source_frame->rows4x4() != frame_header.rows4x4 ||
       source_frame->columns4x4() != frame_header.columns4x4 ||
       IsIntraFrame(source_frame->frame_type())) {
     return false;
   }
-  const int reference_to_current_with_sign =
-      GetRelativeDistance(
-          current_frame.order_hint(source), frame_header.order_hint,
-          sequence_header.enable_order_hint, sequence_header.order_hint_bits) *
-      dst_sign;
-  if (std::abs(reference_to_current_with_sign) > kMaxFrameDistance) return true;
-  // Index 0 of these two arrays are never used.
-  int reference_offsets[kNumReferenceFrameTypes];
-  bool skip_reference[kNumReferenceFrameTypes];
-  for (int source_reference_type = kReferenceFrameLast;
-       source_reference_type <= kNumInterReferenceFrameTypes;
-       ++source_reference_type) {
-    const int reference_offset = GetRelativeDistance(
-        current_frame.order_hint(source),
-        source_frame->order_hint(
-            static_cast<ReferenceFrameType>(source_reference_type)),
-        sequence_header.enable_order_hint, sequence_header.order_hint_bits);
-    skip_reference[source_reference_type] =
-        std::abs(reference_offset) > kMaxFrameDistance || reference_offset <= 0;
-    reference_offsets[source_reference_type] = reference_offset;
-  }
-  // The column range has to be offset by kProjectionMvMaxHorizontalOffset since
-  // coordinates in that range could end up being position_x8 because of
-  // projection.
-  const int adjusted_x8_start =
-      std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0);
-  const int adjusted_x8_end =
-      std::min(x8_end + kProjectionMvMaxHorizontalOffset,
-               DivideBy2(frame_header.columns4x4));
-  for (int y8 = y8_start; y8 < y8_end; ++y8) {
-    for (int x8 = adjusted_x8_start; x8 < adjusted_x8_end; ++x8) {
-      const ReferenceFrameType source_reference =
-          *source_frame->motion_field_reference_frame(y8, x8);
-      if (source_reference <= kReferenceFrameIntra ||
-          skip_reference[source_reference]) {
-        continue;
-      }
-      const int reference_offset = reference_offsets[source_reference];
-      const MotionVector& mv = *source_frame->motion_field_mv(y8, x8);
-      MotionVector projection_mv;
-      GetMvProjectionNoClamp(mv, reference_to_current_with_sign,
-                             reference_offset, &projection_mv);
-      int position_y8;
-      int position_x8;
-      if (!GetBlockPosition(x8, y8, dst_sign, frame_header.rows4x4,
-                            frame_header.columns4x4, projection_mv,
-                            &position_y8, &position_x8) ||
-          position_x8 < x8_start || position_x8 >= x8_end) {
-        // Do not update the motion vector if the block position is not valid or
-        // if position_x8 is outside the current range of x8_start and x8_end.
-        // Note that position_y8 will always be within the range of y8_start and
-        // y8_end.
-        continue;
-      }
-      TemporalMotionVector& temporal_mv =
-          (*motion_field_mv)[position_y8][position_x8];
-      temporal_mv.mv = mv;
-      temporal_mv.reference_offset = reference_offset;
-    }
-  }
+  assert(reference_to_current_with_sign >= -kMaxFrameDistance);
+  if (reference_to_current_with_sign > kMaxFrameDistance) return true;
+  const ReferenceInfo& reference_info = *source_frame->reference_info();
+  const dsp::Dsp& dsp = *dsp::GetDspTable(8);
+  dsp.motion_field_projection_kernel(
+      reference_info, reference_to_current_with_sign, dst_sign, y8_start,
+      y8_end, x8_start, x8_end, motion_field);
   return true;
 }
 
 }  // namespace
 
-void FindMvStack(
-    const Tile::Block& block, bool is_compound,
-    const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias,
-    const Array2D<TemporalMotionVector>& motion_field_mv,
-    CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize],
-    int* const num_mv_found, MvContexts* const contexts,
-    MotionVector global_mv[2]) {
-  const int block_width4x4 = kNum4x4BlocksWide[block.size];
-  const int block_height4x4 = kNum4x4BlocksHigh[block.size];
-  SetupGlobalMv(block, 0, &global_mv[0]);
-  if (is_compound) SetupGlobalMv(block, 1, &global_mv[1]);
+void FindMvStack(const Tile::Block& block, bool is_compound,
+                 MvContexts* const contexts) {
+  PredictionParameters& prediction_parameters =
+      *block.bp->prediction_parameters;
+  SetupGlobalMv(block, 0, &prediction_parameters.global_mv[0]);
+  if (is_compound) SetupGlobalMv(block, 1, &prediction_parameters.global_mv[1]);
   bool found_new_mv = false;
   bool found_row_match = false;
-  *num_mv_found = 0;
-  ScanRow(block, -1, is_compound, global_mv, &found_new_mv, &found_row_match,
-          num_mv_found, ref_mv_stack);
+  int num_mv_found = 0;
+  ScanRow(block, block.column4x4, -1, is_compound, &found_new_mv,
+          &found_row_match, &num_mv_found);
   bool found_column_match = false;
-  ScanColumn(block, -1, is_compound, global_mv, &found_new_mv,
-             &found_column_match, num_mv_found, ref_mv_stack);
-  if (std::max(block_width4x4, block_height4x4) <= 16) {
-    ScanPoint(block, -1, block_width4x4, is_compound, global_mv, &found_new_mv,
-              &found_row_match, num_mv_found, ref_mv_stack);
+  ScanColumn(block, block.row4x4, -1, is_compound, &found_new_mv,
+             &found_column_match, &num_mv_found);
+  if (std::max(block.width4x4, block.height4x4) <= 16) {
+    ScanPoint(block, -1, block.width4x4, is_compound, &found_new_mv,
+              &found_row_match, &num_mv_found);
   }
   const int nearest_matches =
       static_cast<int>(found_row_match) + static_cast<int>(found_column_match);
-  const int nearest_mv_count = *num_mv_found;
-  for (int i = 0; i < nearest_mv_count; ++i) {
-    ref_mv_stack[i].weight += kExtraWeightForNearestMvs;
-  }
-  if (contexts != nullptr) contexts->zero_mv = 0;
+  prediction_parameters.nearest_mv_count = num_mv_found;
   if (block.tile.frame_header().use_ref_frame_mvs) {
-    TemporalScan(block, is_compound, global_mv, motion_field_mv,
-                 (contexts != nullptr) ? &contexts->zero_mv : nullptr,
-                 num_mv_found, ref_mv_stack);
+    // Initialize to invalid value, and it will be set when temporal mv is zero.
+    contexts->zero_mv = -1;
+    TemporalScan(block, is_compound, &contexts->zero_mv, &num_mv_found);
+  } else {
+    contexts->zero_mv = 0;
   }
   bool dummy_bool = false;
-  ScanPoint(block, -1, -1, is_compound, global_mv, &dummy_bool,
-            &found_row_match, num_mv_found, ref_mv_stack);
-  const int deltas[2] = {-3, -5};
+  ScanPoint(block, -1, -1, is_compound, &dummy_bool, &found_row_match,
+            &num_mv_found);
+  static constexpr int deltas[2] = {-3, -5};
   for (int i = 0; i < 2; ++i) {
-    if (i == 0 || block_height4x4 > 1) {
-      ScanRow(block, deltas[i], is_compound, global_mv, &dummy_bool,
-              &found_row_match, num_mv_found, ref_mv_stack);
+    if (i == 0 || block.height4x4 > 1) {
+      ScanRow(block, block.column4x4 | 1, deltas[i] + (block.row4x4 & 1),
+              is_compound, &dummy_bool, &found_row_match, &num_mv_found);
     }
-    if (i == 0 || block_width4x4 > 1) {
-      ScanColumn(block, deltas[i], is_compound, global_mv, &dummy_bool,
-                 &found_column_match, num_mv_found, ref_mv_stack);
+    if (i == 0 || block.width4x4 > 1) {
+      ScanColumn(block, block.row4x4 | 1, deltas[i] + (block.column4x4 & 1),
+                 is_compound, &dummy_bool, &found_column_match, &num_mv_found);
     }
   }
-  std::stable_sort(&ref_mv_stack[0], &ref_mv_stack[nearest_mv_count],
-                   CompareCandidateMotionVectors);
-  std::stable_sort(&ref_mv_stack[nearest_mv_count],
-                   &ref_mv_stack[*num_mv_found], CompareCandidateMotionVectors);
-  if (*num_mv_found < 2) {
-    ExtraSearch(block, is_compound, global_mv, reference_frame_sign_bias,
-                num_mv_found, ref_mv_stack);
+  if (num_mv_found < 2) {
+    ExtraSearch(block, is_compound, &num_mv_found);
+  } else {
+    // The sort of |weight_index_stack| could be moved to Tile::AssignIntraMv()
+    // and Tile::AssignInterMv(), and only do a partial sort to the max index we
+    // need. However, the speed gain is trivial.
+    // For intra case, only the first 1 or 2 mvs in the stack will be used.
+    // For inter case, |prediction_parameters.ref_mv_index| is at most 3.
+    // We only need to do the partial sort up to the first 4 mvs.
+    SortWeightIndexStack(prediction_parameters.nearest_mv_count, 4,
+                         prediction_parameters.weight_index_stack);
+    // When there are 4 or more nearest mvs, the other mvs will not be used.
+    if (prediction_parameters.nearest_mv_count < 4) {
+      SortWeightIndexStack(
+          num_mv_found - prediction_parameters.nearest_mv_count,
+          4 - prediction_parameters.nearest_mv_count,
+          prediction_parameters.weight_index_stack +
+              prediction_parameters.nearest_mv_count);
+    }
   }
+  prediction_parameters.ref_mv_count = num_mv_found;
   const int total_matches =
       static_cast<int>(found_row_match) + static_cast<int>(found_column_match);
-  if (contexts != nullptr) {
-    ComputeContexts(found_new_mv, nearest_matches, total_matches,
-                    &contexts->new_mv, &contexts->reference_mv);
-  }
-  ClampMotionVectors(block, is_compound, *num_mv_found, ref_mv_stack);
+  ComputeContexts(found_new_mv, nearest_matches, total_matches,
+                  &contexts->new_mv, &contexts->reference_mv);
+  // The mv stack clamping process is in Tile::AssignIntraMv() and
+  // Tile::AssignInterMv(), and only up to two mvs are clamped.
 }
 
 void FindWarpSamples(const Tile::Block& block, int* const num_warp_samples,
                      int* const num_samples_scanned,
                      int candidates[kMaxLeastSquaresSamples][4]) {
+  const Tile& tile = block.tile;
   bool top_left = true;
   bool top_right = true;
-  const int block_width4x4 = kNum4x4BlocksWide[block.size];
-  const int block_height4x4 = kNum4x4BlocksHigh[block.size];
   int step = 1;
-  if (block.top_available) {
+  if (block.top_available[kPlaneY]) {
     BlockSize source_size =
-        block.tile.Parameters(block.row4x4 - 1, block.column4x4).size;
+        tile.Parameters(block.row4x4 - 1, block.column4x4).size;
     const int source_width4x4 = kNum4x4BlocksWide[source_size];
-    if (block_width4x4 <= source_width4x4) {
-      // The & here is equivalent to % since source_width4x4 is a power of
-      // two.
+    if (block.width4x4 <= source_width4x4) {
+      // The & here is equivalent to % since source_width4x4 is a power of two.
       const int column_offset = -(block.column4x4 & (source_width4x4 - 1));
       if (column_offset < 0) top_left = false;
-      if (column_offset + source_width4x4 > block_width4x4) top_right = false;
+      if (column_offset + source_width4x4 > block.width4x4) top_right = false;
       AddSample(block, -1, 0, num_warp_samples, num_samples_scanned,
                 candidates);
     } else {
       for (int i = 0;
-           i < std::min(block_width4x4,
-                        block.tile.frame_header().columns4x4 - block.column4x4);
+           i < std::min(static_cast<int>(block.width4x4),
+                        tile.frame_header().columns4x4 - block.column4x4);
            i += step) {
         source_size =
-            block.tile.Parameters(block.row4x4 - 1, block.column4x4 + i).size;
-        step = std::min(block_width4x4,
+            tile.Parameters(block.row4x4 - 1, block.column4x4 + i).size;
+        step = std::min(static_cast<int>(block.width4x4),
                         static_cast<int>(kNum4x4BlocksWide[source_size]));
         AddSample(block, -1, i, num_warp_samples, num_samples_scanned,
                   candidates);
       }
     }
   }
-  if (block.left_available) {
+  if (block.left_available[kPlaneY]) {
     BlockSize source_size =
-        block.tile.Parameters(block.row4x4, block.column4x4 - 1).size;
+        tile.Parameters(block.row4x4, block.column4x4 - 1).size;
     const int source_height4x4 = kNum4x4BlocksHigh[source_size];
-    if (block_height4x4 <= source_height4x4) {
+    if (block.height4x4 <= source_height4x4) {
       const int row_offset = -(block.row4x4 & (source_height4x4 - 1));
       if (row_offset < 0) top_left = false;
       AddSample(block, 0, -1, num_warp_samples, num_samples_scanned,
                 candidates);
     } else {
-      for (int i = 0;
-           i < std::min(block_height4x4,
-                        block.tile.frame_header().rows4x4 - block.row4x4);
+      for (int i = 0; i < std::min(static_cast<int>(block.height4x4),
+                                   tile.frame_header().rows4x4 - block.row4x4);
            i += step) {
         source_size =
-            block.tile.Parameters(block.row4x4 + i, block.column4x4 - 1).size;
-        step = std::min(block_height4x4,
+            tile.Parameters(block.row4x4 + i, block.column4x4 - 1).size;
+        step = std::min(static_cast<int>(block.height4x4),
                         static_cast<int>(kNum4x4BlocksHigh[source_size]));
         AddSample(block, i, -1, num_warp_samples, num_samples_scanned,
                   candidates);
@@ -917,78 +917,82 @@
   if (top_left) {
     AddSample(block, -1, -1, num_warp_samples, num_samples_scanned, candidates);
   }
-  if (top_right && std::max(block_width4x4, block_height4x4) <= 16) {
-    AddSample(block, -1, block_width4x4, num_warp_samples, num_samples_scanned,
+  if (top_right && block.size <= kBlock64x64) {
+    AddSample(block, -1, block.width4x4, num_warp_samples, num_samples_scanned,
               candidates);
   }
   if (*num_warp_samples == 0 && *num_samples_scanned > 0) *num_warp_samples = 1;
 }
 
 void SetupMotionField(
-    const ObuSequenceHeader& sequence_header,
     const ObuFrameHeader& frame_header, const RefCountedBuffer& current_frame,
     const std::array<RefCountedBufferPtr, kNumReferenceFrameTypes>&
         reference_frames,
-    Array2D<TemporalMotionVector>* const motion_field_mv, int row4x4_start,
-    int row4x4_end, int column4x4_start, int column4x4_end) {
+    int row4x4_start, int row4x4_end, int column4x4_start, int column4x4_end,
+    TemporalMotionField* const motion_field) {
   assert(frame_header.use_ref_frame_mvs);
-  assert(sequence_header.enable_order_hint);
   const int y8_start = DivideBy2(row4x4_start);
   const int y8_end = DivideBy2(std::min(row4x4_end, frame_header.rows4x4));
   const int x8_start = DivideBy2(column4x4_start);
   const int x8_end =
-      std::min(DivideBy2(column4x4_end), DivideBy2(frame_header.columns4x4));
-  for (int y8 = y8_start; y8 < y8_end; ++y8) {
-    for (int x8 = x8_start; x8 < x8_end; ++x8) {
-      (*motion_field_mv)[y8][x8].mv.mv[0] = kInvalidMvValue;
+      DivideBy2(std::min(column4x4_end, frame_header.columns4x4));
+  const int last_index = frame_header.reference_frame_index[0];
+  const ReferenceInfo& reference_info = *current_frame.reference_info();
+  if (!IsIntraFrame(reference_frames[last_index]->frame_type())) {
+    const int last_alternate_order_hint =
+        reference_frames[last_index]
+            ->reference_info()
+            ->order_hint[kReferenceFrameAlternate];
+    const int current_gold_order_hint =
+        reference_info.order_hint[kReferenceFrameGolden];
+    if (last_alternate_order_hint != current_gold_order_hint) {
+      const int reference_offset_last =
+          -reference_info.relative_distance_from[kReferenceFrameLast];
+      if (std::abs(reference_offset_last) <= kMaxFrameDistance) {
+        MotionFieldProjection(frame_header, reference_frames,
+                              kReferenceFrameLast, reference_offset_last, -1,
+                              y8_start, y8_end, x8_start, x8_end, motion_field);
+      }
     }
   }
-  const int current_gold_order_hint =
-      current_frame.order_hint(kReferenceFrameGolden);
-  const int last_index = frame_header.reference_frame_index[0];
-  const int last_alternate_order_hint =
-      reference_frames[last_index]->order_hint(kReferenceFrameAlternate);
-  if (last_alternate_order_hint != current_gold_order_hint) {
-    MotionFieldProjection(kReferenceFrameLast, -1, sequence_header,
-                          frame_header, current_frame, reference_frames,
-                          motion_field_mv, y8_start, y8_end, x8_start, x8_end);
-  }
   int ref_stamp = 1;
-  if (GetRelativeDistance(current_frame.order_hint(kReferenceFrameBackward),
-                          frame_header.order_hint,
-                          sequence_header.enable_order_hint,
-                          sequence_header.order_hint_bits) > 0 &&
-      MotionFieldProjection(kReferenceFrameBackward, 1, sequence_header,
-                            frame_header, current_frame, reference_frames,
-                            motion_field_mv, y8_start, y8_end, x8_start,
-                            x8_end)) {
+  const int reference_offset_backward =
+      reference_info.relative_distance_from[kReferenceFrameBackward];
+  if (reference_offset_backward > 0 &&
+      MotionFieldProjection(frame_header, reference_frames,
+                            kReferenceFrameBackward, reference_offset_backward,
+                            0, y8_start, y8_end, x8_start, x8_end,
+                            motion_field)) {
     --ref_stamp;
   }
-  if (GetRelativeDistance(current_frame.order_hint(kReferenceFrameAlternate2),
-                          frame_header.order_hint,
-                          sequence_header.enable_order_hint,
-                          sequence_header.order_hint_bits) > 0 &&
-      MotionFieldProjection(kReferenceFrameAlternate2, 1, sequence_header,
-                            frame_header, current_frame, reference_frames,
-                            motion_field_mv, y8_start, y8_end, x8_start,
-                            x8_end)) {
-    --ref_stamp;
-  }
-  if (ref_stamp >= 0 &&
-      GetRelativeDistance(current_frame.order_hint(kReferenceFrameAlternate),
-                          frame_header.order_hint,
-                          sequence_header.enable_order_hint,
-                          sequence_header.order_hint_bits) > 0 &&
-      MotionFieldProjection(kReferenceFrameAlternate, 1, sequence_header,
-                            frame_header, current_frame, reference_frames,
-                            motion_field_mv, y8_start, y8_end, x8_start,
-                            x8_end)) {
+  const int reference_offset_alternate2 =
+      reference_info.relative_distance_from[kReferenceFrameAlternate2];
+  if (reference_offset_alternate2 > 0 &&
+      MotionFieldProjection(frame_header, reference_frames,
+                            kReferenceFrameAlternate2,
+                            reference_offset_alternate2, 0, y8_start, y8_end,
+                            x8_start, x8_end, motion_field)) {
     --ref_stamp;
   }
   if (ref_stamp >= 0) {
-    MotionFieldProjection(kReferenceFrameLast2, -1, sequence_header,
-                          frame_header, current_frame, reference_frames,
-                          motion_field_mv, y8_start, y8_end, x8_start, x8_end);
+    const int reference_offset_alternate =
+        reference_info.relative_distance_from[kReferenceFrameAlternate];
+    if (reference_offset_alternate > 0 &&
+        MotionFieldProjection(frame_header, reference_frames,
+                              kReferenceFrameAlternate,
+                              reference_offset_alternate, 0, y8_start, y8_end,
+                              x8_start, x8_end, motion_field)) {
+      --ref_stamp;
+    }
+  }
+  if (ref_stamp >= 0) {
+    const int reference_offset_last2 =
+        -reference_info.relative_distance_from[kReferenceFrameLast2];
+    if (std::abs(reference_offset_last2) <= kMaxFrameDistance) {
+      MotionFieldProjection(frame_header, reference_frames,
+                            kReferenceFrameLast2, reference_offset_last2, -1,
+                            y8_start, y8_end, x8_start, x8_end, motion_field);
+    }
   }
 }
 
diff --git a/libgav1/src/motion_vector.h b/libgav1/src/motion_vector.h
index 1f01f6b..d739e80 100644
--- a/libgav1/src/motion_vector.h
+++ b/libgav1/src/motion_vector.h
@@ -30,24 +30,16 @@
 
 namespace libgav1 {
 
-inline bool IsGlobalMvBlock(PredictionMode mode,
-                            GlobalMotionTransformationType type,
-                            BlockSize size) {
-  return ((mode == kPredictionModeGlobalMv ||
-           mode == kPredictionModeGlobalGlobalMv) &&
-          type > kGlobalMotionTransformationTypeTranslation &&
-          std::min(kBlockWidthPixels[size], kBlockHeightPixels[size]) >= 8);
+constexpr bool IsGlobalMvBlock(bool is_global_mv_block,
+                               GlobalMotionTransformationType type) {
+  return is_global_mv_block &&
+         type > kGlobalMotionTransformationTypeTranslation;
 }
 
 // The |contexts| output parameter may be null. If the caller does not need
 // the |contexts| output, pass nullptr as the argument.
-void FindMvStack(
-    const Tile::Block& block, bool is_compound,
-    const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias,
-    const Array2D<TemporalMotionVector>& motion_field_mv,
-    CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize], int* num_mv_found,
-    MvContexts* contexts,
-    MotionVector global_mv[2]);  // 7.10.2
+void FindMvStack(const Tile::Block& block, bool is_compound,
+                 MvContexts* contexts);  // 7.10.2
 
 void FindWarpSamples(const Tile::Block& block, int* num_warp_samples,
                      int* num_samples_scanned,
@@ -56,12 +48,11 @@
 // Section 7.9.1 in the spec. But this is done per tile instead of for the whole
 // frame.
 void SetupMotionField(
-    const ObuSequenceHeader& sequence_header,
     const ObuFrameHeader& frame_header, const RefCountedBuffer& current_frame,
     const std::array<RefCountedBufferPtr, kNumReferenceFrameTypes>&
         reference_frames,
-    Array2D<TemporalMotionVector>* motion_field_mv, int row4x4_start,
-    int row4x4_end, int column4x4_start, int column4x4_end);
+    int row4x4_start, int row4x4_end, int column4x4_start, int column4x4_end,
+    TemporalMotionField* motion_field);
 
 }  // namespace libgav1
 
diff --git a/libgav1/src/obu_parser.cc b/libgav1/src/obu_parser.cc
index e2138d3..41df909 100644
--- a/libgav1/src/obu_parser.cc
+++ b/libgav1/src/obu_parser.cc
@@ -31,10 +31,6 @@
 namespace libgav1 {
 namespace {
 
-// This is set to 0 since we only support one layer. This should be part of
-// DecoderSettings if we support more than one layer.
-constexpr int kOperatingPoint = 0;
-
 // 5.9.16.
 // Find the smallest value of k such that block_size << k is greater than or
 // equal to target.
@@ -73,6 +69,36 @@
   return ((operating_point_idc >> (spatial_id + 8)) & 1) != 0;
 }
 
+// Returns the index of the last nonzero byte in the |data| buffer of |size|
+// bytes. If there is no nonzero byte in the |data| buffer, returns -1.
+int GetLastNonzeroByteIndex(const uint8_t* data, size_t size) {
+  // Scan backward for a nonzero byte.
+  if (size > INT_MAX) return -1;
+  int i = static_cast<int>(size) - 1;
+  while (i >= 0 && data[i] == 0) {
+    --i;
+  }
+  return i;
+}
+
+// A cleanup helper class that releases the frame buffer reference held in
+// |frame| in the destructor.
+class RefCountedBufferPtrCleanup {
+ public:
+  explicit RefCountedBufferPtrCleanup(RefCountedBufferPtr* frame)
+      : frame_(*frame) {}
+
+  // Not copyable or movable.
+  RefCountedBufferPtrCleanup(const RefCountedBufferPtrCleanup&) = delete;
+  RefCountedBufferPtrCleanup& operator=(const RefCountedBufferPtrCleanup&) =
+      delete;
+
+  ~RefCountedBufferPtrCleanup() { frame_ = nullptr; }
+
+ private:
+  RefCountedBufferPtr& frame_;
+};
+
 }  // namespace
 
 bool ObuSequenceHeader::ParametersChanged(const ObuSequenceHeader& old) const {
@@ -115,7 +141,7 @@
   ColorConfig* const color_config = &sequence_header->color_config;
   OBU_READ_BIT_OR_FAIL;
   const auto high_bitdepth = static_cast<bool>(scratch);
-  if (sequence_header->profile >= kProfile2 && high_bitdepth) {
+  if (sequence_header->profile == kProfile2 && high_bitdepth) {
     OBU_READ_BIT_OR_FAIL;
     const auto is_twelve_bit = static_cast<bool>(scratch);
     color_config->bitdepth = is_twelve_bit ? 12 : 10;
@@ -132,7 +158,7 @@
   const auto color_description_present_flag = static_cast<bool>(scratch);
   if (color_description_present_flag) {
     OBU_READ_LITERAL_OR_FAIL(8);
-    color_config->color_primaries = static_cast<ColorPrimaries>(scratch);
+    color_config->color_primary = static_cast<ColorPrimary>(scratch);
     OBU_READ_LITERAL_OR_FAIL(8);
     color_config->transfer_characteristics =
         static_cast<TransferCharacteristics>(scratch);
@@ -140,13 +166,14 @@
     color_config->matrix_coefficients =
         static_cast<MatrixCoefficients>(scratch);
   } else {
-    color_config->color_primaries = kColorPrimaryUnspecified;
-    color_config->transfer_characteristics = kTransferCharacteristicUnspecified;
-    color_config->matrix_coefficients = kMatrixCoefficientUnspecified;
+    color_config->color_primary = kColorPrimaryUnspecified;
+    color_config->transfer_characteristics =
+        kTransferCharacteristicsUnspecified;
+    color_config->matrix_coefficients = kMatrixCoefficientsUnspecified;
   }
   if (color_config->is_monochrome) {
     OBU_READ_BIT_OR_FAIL;
-    color_config->color_range = scratch;
+    color_config->color_range = static_cast<ColorRange>(scratch);
     // Set subsampling_x and subsampling_y to 1 for monochrome. This makes it
     // easy to allow monochrome to be supported in profile 0. Profile 0
     // requires subsampling_x and subsampling_y to be 1.
@@ -154,15 +181,16 @@
     color_config->subsampling_y = 1;
     color_config->chroma_sample_position = kChromaSamplePositionUnknown;
   } else {
-    if (color_config->color_primaries == kColorPrimaryBt709 &&
-        color_config->transfer_characteristics == kTransferCharacteristicSrgb &&
-        color_config->matrix_coefficients == kMatrixCoefficientIdentity) {
-      color_config->color_range = 1;
+    if (color_config->color_primary == kColorPrimaryBt709 &&
+        color_config->transfer_characteristics ==
+            kTransferCharacteristicsSrgb &&
+        color_config->matrix_coefficients == kMatrixCoefficientsIdentity) {
+      color_config->color_range = kColorRangeFull;
       color_config->subsampling_x = 0;
       color_config->subsampling_y = 0;
     } else {
       OBU_READ_BIT_OR_FAIL;
-      color_config->color_range = scratch;
+      color_config->color_range = static_cast<ColorRange>(scratch);
       if (sequence_header->profile == kProfile0) {
         color_config->subsampling_x = 1;
         color_config->subsampling_y = 1;
@@ -194,7 +222,7 @@
     OBU_READ_BIT_OR_FAIL;
     color_config->separate_uv_delta_q = static_cast<bool>(scratch);
   }
-  if (color_config->matrix_coefficients == kMatrixCoefficientIdentity &&
+  if (color_config->matrix_coefficients == kMatrixCoefficientsIdentity &&
       (color_config->subsampling_x != 0 || color_config->subsampling_y != 0)) {
     LIBGAV1_DLOG(ERROR,
                  "matrix_coefficients is MC_IDENTITY, but subsampling_x (%d) "
@@ -303,6 +331,13 @@
     const auto initial_display_delay_present_flag = static_cast<bool>(scratch);
     OBU_READ_LITERAL_OR_FAIL(5);
     sequence_header.operating_points = static_cast<int>(1 + scratch);
+    if (operating_point_ >= sequence_header.operating_points) {
+      LIBGAV1_DLOG(
+          ERROR,
+          "Invalid operating point: %d (valid range is [0,%d] inclusive).",
+          operating_point_, sequence_header.operating_points - 1);
+      return false;
+    }
     for (int i = 0; i < sequence_header.operating_points; ++i) {
       OBU_READ_LITERAL_OR_FAIL(12);
       sequence_header.operating_point_idc[i] = static_cast<int>(scratch);
@@ -412,6 +447,8 @@
     if (sequence_header.enable_order_hint) {
       OBU_READ_LITERAL_OR_FAIL(3);
       sequence_header.order_hint_bits = 1 + scratch;
+      sequence_header.order_hint_shift_bits =
+          Mod32(32 - sequence_header.order_hint_bits);
     }
   }
   OBU_READ_BIT_OR_FAIL;
@@ -440,7 +477,7 @@
   // OperatingPointIdc is equal to 0, then obu_extension_flag is equal to 0 for
   // all OBUs that follow this sequence header until the next sequence header.
   extension_disallowed_ =
-      (sequence_header_.operating_point_idc[kOperatingPoint] == 0);
+      (sequence_header_.operating_point_idc[operating_point_] == 0);
   return true;
 }
 
@@ -513,12 +550,12 @@
   int64_t scratch;
   // SuperRes.
   frame_header_.upscaled_width = frame_header_.width;
-  bool use_superres = false;
+  frame_header_.use_superres = false;
   if (sequence_header_.enable_superres) {
     OBU_READ_BIT_OR_FAIL;
-    use_superres = static_cast<bool>(scratch);
+    frame_header_.use_superres = static_cast<bool>(scratch);
   }
-  if (use_superres) {
+  if (frame_header_.use_superres) {
     OBU_READ_LITERAL_OR_FAIL(3);
     // 9 is the smallest value for the denominator.
     frame_header_.superres_scale_denominator = scratch + 9;
@@ -699,16 +736,15 @@
   used_frame[last_frame_idx] = true;
   used_frame[gold_frame_idx] = true;
 
+  assert(sequence_header_.order_hint_bits >= 1);
   const int current_frame_hint = 1 << (sequence_header_.order_hint_bits - 1);
-
   // shifted_order_hints contains the expected output order shifted such that
   // the current frame has hint equal to current_frame_hint.
   std::array<int, kNumReferenceFrameTypes> shifted_order_hints;
   for (int i = 0; i < kNumReferenceFrameTypes; ++i) {
-    const int reference_hint = decoder_state_.reference_order_hint[i];
     const int relative_distance = GetRelativeDistance(
-        reference_hint, frame_header_.order_hint,
-        sequence_header_.enable_order_hint, sequence_header_.order_hint_bits);
+        decoder_state_.reference_order_hint[i], frame_header_.order_hint,
+        sequence_header_.order_hint_shift_bits);
     shifted_order_hints[i] = current_frame_hint + relative_distance;
   }
 
@@ -831,8 +867,8 @@
   loop_filter->delta_enabled = static_cast<bool>(scratch);
   if (loop_filter->delta_enabled) {
     OBU_READ_BIT_OR_FAIL;
-    const auto loop_filter_delta_update = static_cast<bool>(scratch);
-    if (loop_filter_delta_update) {
+    loop_filter->delta_update = static_cast<bool>(scratch);
+    if (loop_filter->delta_update) {
       for (auto& ref_delta : loop_filter->ref_deltas) {
         OBU_READ_BIT_OR_FAIL;
         const auto update_ref_delta = static_cast<bool>(scratch);
@@ -858,6 +894,8 @@
         }
       }
     }
+  } else {
+    loop_filter->delta_update = false;
   }
   return true;
 }
@@ -975,9 +1013,13 @@
               Clip3(scratch_int, -kSegmentationFeatureMaxValues[j],
                     kSegmentationFeatureMaxValues[j]);
         } else {
-          OBU_READ_LITERAL_OR_FAIL(kSegmentationFeatureBits[j]);
-          segmentation->feature_data[i][j] = Clip3(
-              static_cast<int>(scratch), 0, kSegmentationFeatureMaxValues[j]);
+          if (kSegmentationFeatureBits[j] > 0) {
+            OBU_READ_LITERAL_OR_FAIL(kSegmentationFeatureBits[j]);
+            segmentation->feature_data[i][j] = Clip3(
+                static_cast<int>(scratch), 0, kSegmentationFeatureMaxValues[j]);
+          } else {
+            segmentation->feature_data[i][j] = 0;
+          }
         }
         segmentation->last_active_segment_id = i;
         if (j >= kSegmentFeatureReferenceFrame) {
@@ -1043,29 +1085,32 @@
 }
 
 bool ObuParser::ParseCdefParameters() {
+  const int coeff_shift = sequence_header_.color_config.bitdepth - 8;
   if (frame_header_.coded_lossless || frame_header_.allow_intrabc ||
       !sequence_header_.enable_cdef) {
-    frame_header_.cdef.damping = 3;
+    frame_header_.cdef.damping = 3 + coeff_shift;
     return true;
   }
   Cdef* const cdef = &frame_header_.cdef;
   int64_t scratch;
   OBU_READ_LITERAL_OR_FAIL(2);
-  cdef->damping = scratch + 3;
+  cdef->damping = scratch + 3 + coeff_shift;
   OBU_READ_LITERAL_OR_FAIL(2);
   cdef->bits = scratch;
   for (int i = 0; i < (1 << cdef->bits); ++i) {
     OBU_READ_LITERAL_OR_FAIL(4);
-    cdef->y_primary_strength[i] = scratch;
+    cdef->y_primary_strength[i] = scratch << coeff_shift;
     OBU_READ_LITERAL_OR_FAIL(2);
     cdef->y_secondary_strength[i] = scratch;
     if (cdef->y_secondary_strength[i] == 3) ++cdef->y_secondary_strength[i];
+    cdef->y_secondary_strength[i] <<= coeff_shift;
     if (sequence_header_.color_config.is_monochrome) continue;
     OBU_READ_LITERAL_OR_FAIL(4);
-    cdef->uv_primary_strength[i] = scratch;
+    cdef->uv_primary_strength[i] = scratch << coeff_shift;
     OBU_READ_LITERAL_OR_FAIL(2);
     cdef->uv_secondary_strength[i] = scratch;
     if (cdef->uv_secondary_strength[i] == 3) ++cdef->uv_secondary_strength[i];
+    cdef->uv_secondary_strength[i] <<= coeff_shift;
   }
   return true;
 }
@@ -1152,25 +1197,29 @@
   int forward_hint = -1;
   int backward_hint = -1;
   for (int i = 0; i < kNumInterReferenceFrameTypes; ++i) {
-    const int reference_hint =
+    const unsigned int reference_hint =
         decoder_state_
             .reference_order_hint[frame_header_.reference_frame_index[i]];
-    const int relative_distance = GetRelativeDistance(
-        reference_hint, frame_header_.order_hint,
-        sequence_header_.enable_order_hint, sequence_header_.order_hint_bits);
+    // TODO(linfengz): |relative_distance| equals
+    // current_frame_->reference_info()->
+    //     relative_distance_from[i + kReferenceFrameLast];
+    // However, the unit test ObuParserTest.SkipModeParameters() would fail.
+    // Will figure out how to initialize |current_frame_.reference_info_| in the
+    // RefCountedBuffer later.
+    const int relative_distance =
+        GetRelativeDistance(reference_hint, frame_header_.order_hint,
+                            sequence_header_.order_hint_shift_bits);
     if (relative_distance < 0) {
       if (forward_index < 0 ||
           GetRelativeDistance(reference_hint, forward_hint,
-                              sequence_header_.enable_order_hint,
-                              sequence_header_.order_hint_bits) > 0) {
+                              sequence_header_.order_hint_shift_bits) > 0) {
         forward_index = i;
         forward_hint = reference_hint;
       }
     } else if (relative_distance > 0) {
       if (backward_index < 0 ||
           GetRelativeDistance(reference_hint, backward_hint,
-                              sequence_header_.enable_order_hint,
-                              sequence_header_.order_hint_bits) < 0) {
+                              sequence_header_.order_hint_shift_bits) < 0) {
         backward_index = i;
         backward_hint = reference_hint;
       }
@@ -1189,16 +1238,14 @@
   int second_forward_index = -1;
   int second_forward_hint = -1;
   for (int i = 0; i < kNumInterReferenceFrameTypes; ++i) {
-    const int reference_hint =
+    const unsigned int reference_hint =
         decoder_state_
             .reference_order_hint[frame_header_.reference_frame_index[i]];
     if (GetRelativeDistance(reference_hint, forward_hint,
-                            sequence_header_.enable_order_hint,
-                            sequence_header_.order_hint_bits) < 0) {
+                            sequence_header_.order_hint_shift_bits) < 0) {
       if (second_forward_index < 0 ||
           GetRelativeDistance(reference_hint, second_forward_hint,
-                              sequence_header_.enable_order_hint,
-                              sequence_header_.order_hint_bits) > 0) {
+                              sequence_header_.order_hint_shift_bits) > 0) {
         second_forward_index = i;
         second_forward_hint = reference_hint;
       }
@@ -1497,7 +1544,7 @@
     for (int i = 0; i < num_pos_y; ++i) {
       OBU_READ_LITERAL_OR_FAIL(8);
       film_grain_params.auto_regression_coeff_y[i] =
-          static_cast<int>(scratch) - 128;
+          static_cast<int8_t>(scratch - 128);
     }
   }
   if (film_grain_params.chroma_scaling_from_luma ||
@@ -1505,7 +1552,7 @@
     for (int i = 0; i < num_pos_uv; ++i) {
       OBU_READ_LITERAL_OR_FAIL(8);
       film_grain_params.auto_regression_coeff_u[i] =
-          static_cast<int>(scratch) - 128;
+          static_cast<int8_t>(scratch - 128);
     }
   }
   if (film_grain_params.chroma_scaling_from_luma ||
@@ -1513,28 +1560,28 @@
     for (int i = 0; i < num_pos_uv; ++i) {
       OBU_READ_LITERAL_OR_FAIL(8);
       film_grain_params.auto_regression_coeff_v[i] =
-          static_cast<int>(scratch) - 128;
+          static_cast<int8_t>(scratch - 128);
     }
   }
   OBU_READ_LITERAL_OR_FAIL(2);
-  film_grain_params.auto_regression_shift = scratch + 6;
+  film_grain_params.auto_regression_shift = static_cast<uint8_t>(scratch + 6);
   OBU_READ_LITERAL_OR_FAIL(2);
   film_grain_params.grain_scale_shift = static_cast<int>(scratch);
   if (film_grain_params.num_u_points > 0) {
     OBU_READ_LITERAL_OR_FAIL(8);
-    film_grain_params.u_multiplier = static_cast<int>(scratch);
+    film_grain_params.u_multiplier = static_cast<int8_t>(scratch - 128);
     OBU_READ_LITERAL_OR_FAIL(8);
-    film_grain_params.u_luma_multiplier = static_cast<int>(scratch);
+    film_grain_params.u_luma_multiplier = static_cast<int8_t>(scratch - 128);
     OBU_READ_LITERAL_OR_FAIL(9);
-    film_grain_params.u_offset = static_cast<int>(scratch);
+    film_grain_params.u_offset = static_cast<int16_t>(scratch - 256);
   }
   if (film_grain_params.num_v_points > 0) {
     OBU_READ_LITERAL_OR_FAIL(8);
-    film_grain_params.v_multiplier = static_cast<int>(scratch);
+    film_grain_params.v_multiplier = static_cast<int8_t>(scratch - 128);
     OBU_READ_LITERAL_OR_FAIL(8);
-    film_grain_params.v_luma_multiplier = static_cast<int>(scratch);
+    film_grain_params.v_luma_multiplier = static_cast<int8_t>(scratch - 128);
     OBU_READ_LITERAL_OR_FAIL(9);
-    film_grain_params.v_offset = static_cast<int>(scratch);
+    film_grain_params.v_offset = static_cast<int16_t>(scratch - 256);
   }
   OBU_READ_BIT_OR_FAIL;
   film_grain_params.overlap_flag = static_cast<bool>(scratch);
@@ -1630,14 +1677,15 @@
       tile_info->tile_column_start[i] = sb_start << sb_shift;
       const int max_width =
           std::min(sb_columns - sb_start, static_cast<int>(sb_max_tile_width));
-      int sb_size;
-      if (!bit_reader_->DecodeUniform(max_width, &sb_size)) {
+      if (!bit_reader_->DecodeUniform(
+              max_width, &tile_info->tile_column_width_in_superblocks[i])) {
         LIBGAV1_DLOG(ERROR, "Not enough bits.");
         return false;
       }
-      ++sb_size;
-      widest_tile_sb = std::max(sb_size, widest_tile_sb);
-      sb_start += sb_size;
+      ++tile_info->tile_column_width_in_superblocks[i];
+      widest_tile_sb = std::max(tile_info->tile_column_width_in_superblocks[i],
+                                widest_tile_sb);
+      sb_start += tile_info->tile_column_width_in_superblocks[i];
     }
     tile_info->tile_column_start[i] = frame_header_.columns4x4;
     tile_info->tile_columns = i;
@@ -1656,19 +1704,23 @@
       }
       tile_info->tile_row_start[i] = sb_start << sb_shift;
       const int max_height = std::min(sb_rows - sb_start, max_tile_height_sb);
-      int sb_size;
-      if (!bit_reader_->DecodeUniform(max_height, &sb_size)) {
+      if (!bit_reader_->DecodeUniform(
+              max_height, &tile_info->tile_row_height_in_superblocks[i])) {
         LIBGAV1_DLOG(ERROR, "Not enough bits.");
         return false;
       }
-      ++sb_size;
-      sb_start += sb_size;
+      ++tile_info->tile_row_height_in_superblocks[i];
+      sb_start += tile_info->tile_row_height_in_superblocks[i];
     }
     tile_info->tile_row_start[i] = frame_header_.rows4x4;
     tile_info->tile_rows = i;
     tile_info->tile_rows_log2 = CeilLog2(tile_info->tile_rows);
   }
   tile_info->tile_count = tile_info->tile_rows * tile_info->tile_columns;
+  if (!tile_buffers_.reserve(tile_info->tile_count)) {
+    LIBGAV1_DLOG(ERROR, "Unable to allocate memory for tile_buffers_.");
+    return false;
+  }
   tile_info->context_update_id = 0;
   const int tile_bits =
       tile_info->tile_columns_log2 + tile_info->tile_rows_log2;
@@ -1702,6 +1754,11 @@
   int64_t scratch;
   if (sequence_header_.reduced_still_picture_header) {
     frame_header_.show_frame = true;
+    current_frame_ = buffer_pool_->GetFreeBuffer();
+    if (current_frame_ == nullptr) {
+      LIBGAV1_DLOG(ERROR, "Could not get current_frame from the buffer pool.");
+      return false;
+    }
   } else {
     OBU_READ_BIT_OR_FAIL;
     frame_header_.show_existing_frame = static_cast<bool>(scratch);
@@ -1733,9 +1790,9 @@
       }
       // Section 7.18.2. Note: This is also needed for Section 7.21 if
       // frame_type is kFrameKey.
-      decoder_state_.current_frame =
+      current_frame_ =
           decoder_state_.reference_frame[frame_header_.frame_to_show];
-      if (decoder_state_.current_frame == nullptr) {
+      if (current_frame_ == nullptr) {
         LIBGAV1_DLOG(ERROR, "Buffer %d does not contain a decoded frame",
                      frame_header_.frame_to_show);
         return false;
@@ -1743,19 +1800,19 @@
       // Section 6.8.2: It is a requirement of bitstream conformance that
       // when show_existing_frame is used to show a previous frame, that the
       // value of showable_frame for the previous frame was equal to 1.
-      if (!decoder_state_.current_frame->showable_frame()) {
+      if (!current_frame_->showable_frame()) {
         LIBGAV1_DLOG(ERROR, "Buffer %d does not contain a showable frame",
                      frame_header_.frame_to_show);
         return false;
       }
-      if (decoder_state_.current_frame->frame_type() == kFrameKey) {
+      if (current_frame_->frame_type() == kFrameKey) {
         frame_header_.refresh_frame_flags = 0xff;
         // Section 6.8.2: It is a requirement of bitstream conformance that
         // when show_existing_frame is used to show a previous frame with
         // RefFrameType[ frame_to_show_map_idx ] equal to KEY_FRAME, that
         // the frame is output via the show_existing_frame mechanism at most
         // once.
-        decoder_state_.current_frame->set_showable_frame(false);
+        current_frame_->set_showable_frame(false);
 
         // Section 7.21. Note: decoder_state_.current_frame_id must be set
         // only when frame_type is kFrameKey per the spec. Among all the
@@ -1769,9 +1826,14 @@
       }
       return true;
     }
+    current_frame_ = buffer_pool_->GetFreeBuffer();
+    if (current_frame_ == nullptr) {
+      LIBGAV1_DLOG(ERROR, "Could not get current_frame from the buffer pool.");
+      return false;
+    }
     OBU_READ_LITERAL_OR_FAIL(2);
     frame_header_.frame_type = static_cast<FrameType>(scratch);
-    decoder_state_.current_frame->set_frame_type(frame_header_.frame_type);
+    current_frame_->set_frame_type(frame_header_.frame_type);
     OBU_READ_BIT_OR_FAIL;
     frame_header_.show_frame = static_cast<bool>(scratch);
     if (frame_header_.show_frame &&
@@ -1787,8 +1849,7 @@
       OBU_READ_BIT_OR_FAIL;
       frame_header_.showable_frame = static_cast<bool>(scratch);
     }
-    decoder_state_.current_frame->set_showable_frame(
-        frame_header_.showable_frame);
+    current_frame_->set_showable_frame(frame_header_.showable_frame);
     if (frame_header_.frame_type == kFrameSwitch ||
         (frame_header_.frame_type == kFrameKey && frame_header_.show_frame)) {
       frame_header_.error_resilient_mode = true;
@@ -1800,7 +1861,6 @@
   if (frame_header_.frame_type == kFrameKey && frame_header_.show_frame) {
     decoder_state_.reference_valid.fill(false);
     decoder_state_.reference_order_hint.fill(0);
-    decoder_state_.current_frame->ClearOrderHints();
   }
   OBU_READ_BIT_OR_FAIL;
   frame_header_.enable_cdf_update = !static_cast<bool>(scratch);
@@ -1940,14 +2000,24 @@
         OBU_READ_LITERAL_OR_FAIL(3);
         frame_header_.reference_frame_index[i] = scratch;
       }
+      const int reference_frame_index = frame_header_.reference_frame_index[i];
+      assert(reference_frame_index >= 0);
+      // Section 6.8.2: It is a requirement of bitstream conformance that
+      // RefValid[ ref_frame_idx[ i ] ] is equal to 1 ...
+      // The remainder of the statement is handled by ParseSequenceHeader().
+      // Note if support for Annex C: Error resilience behavior is added this
+      // check should be omitted per C.5 Decoder consequences of processable
+      // frames.
+      if (!decoder_state_.reference_valid[reference_frame_index]) {
+        LIBGAV1_DLOG(ERROR, "ref_frame_idx[%d] (%d) is not valid.", i,
+                     reference_frame_index);
+        return false;
+      }
       // Check if the inter frame requests a nonexistent reference, whether or
       // not frame_refs_short_signaling is used.
-      assert(frame_header_.reference_frame_index[i] >= 0);
-      if (decoder_state_
-              .reference_frame[frame_header_.reference_frame_index[i]] ==
-          nullptr) {
+      if (decoder_state_.reference_frame[reference_frame_index] == nullptr) {
         LIBGAV1_DLOG(ERROR, "ref_frame_idx[%d] (%d) is not a decoded frame.", i,
-                     frame_header_.reference_frame_index[i]);
+                     reference_frame_index);
         return false;
       }
       if (sequence_header_.frame_id_numbers_present) {
@@ -1966,13 +2036,11 @@
         // Section 6.8.2: It is a requirement of bitstream conformance that
         // RefValid[ ref_frame_idx[ i ] ] is equal to 1, ...
         if (frame_header_.expected_frame_id[i] !=
-                decoder_state_.reference_frame_id
-                    [frame_header_.reference_frame_index[i]] ||
-            !decoder_state_
-                 .reference_valid[frame_header_.reference_frame_index[i]]) {
+                decoder_state_.reference_frame_id[reference_frame_index] ||
+            !decoder_state_.reference_valid[reference_frame_index]) {
           LIBGAV1_DLOG(ERROR,
                        "Reference buffer %d has a frame id number mismatch.",
-                       frame_header_.reference_frame_index[i]);
+                       reference_frame_index);
           return false;
         }
       }
@@ -2045,22 +2113,49 @@
   // At this point, we have parsed the frame and render sizes and computed
   // the image size, whether it's an intra or inter frame. So we can save
   // the sizes in the current frame now.
-  if (!decoder_state_.current_frame->SetFrameDimensions(frame_header_)) {
+  if (!current_frame_->SetFrameDimensions(frame_header_)) {
     LIBGAV1_DLOG(ERROR, "Setting current frame dimensions failed.");
     return false;
   }
   if (!IsIntraFrame(frame_header_.frame_type)) {
-    for (int i = 0; i < kNumInterReferenceFrameTypes; ++i) {
-      const auto reference_frame =
-          static_cast<ReferenceFrameType>(kReferenceFrameLast + i);
+    // Initialize the kReferenceFrameIntra type reference frame information to
+    // simplify the frame type validation in motion field projection.
+    // Set the kReferenceFrameIntra type |order_hint_| to
+    // |frame_header_.order_hint|. This guarantees that in SIMD implementations,
+    // the other reference frame information of the kReferenceFrameIntra type
+    // could be correctly initialized using the following loop with
+    // |frame_header_.order_hint| being the |hint|.
+    ReferenceInfo* const reference_info = current_frame_->reference_info();
+    reference_info->order_hint[kReferenceFrameIntra] = frame_header_.order_hint;
+    reference_info->relative_distance_from[kReferenceFrameIntra] = 0;
+    reference_info->relative_distance_to[kReferenceFrameIntra] = 0;
+    reference_info->skip_references[kReferenceFrameIntra] = true;
+    reference_info->projection_divisions[kReferenceFrameIntra] = 0;
+
+    for (int i = kReferenceFrameLast; i <= kNumInterReferenceFrameTypes; ++i) {
+      const auto reference_frame = static_cast<ReferenceFrameType>(i);
       const uint8_t hint =
-          decoder_state_
-              .reference_order_hint[frame_header_.reference_frame_index[i]];
-      decoder_state_.current_frame->set_order_hint(reference_frame, hint);
-      decoder_state_.reference_frame_sign_bias[reference_frame] =
+          decoder_state_.reference_order_hint
+              [frame_header_.reference_frame_index[i - kReferenceFrameLast]];
+      reference_info->order_hint[reference_frame] = hint;
+      const int relative_distance_from =
           GetRelativeDistance(hint, frame_header_.order_hint,
-                              sequence_header_.enable_order_hint,
-                              sequence_header_.order_hint_bits) > 0;
+                              sequence_header_.order_hint_shift_bits);
+      const int relative_distance_to =
+          GetRelativeDistance(frame_header_.order_hint, hint,
+                              sequence_header_.order_hint_shift_bits);
+      reference_info->relative_distance_from[reference_frame] =
+          relative_distance_from;
+      reference_info->relative_distance_to[reference_frame] =
+          relative_distance_to;
+      reference_info->skip_references[reference_frame] =
+          relative_distance_to > kMaxFrameDistance || relative_distance_to <= 0;
+      reference_info->projection_divisions[reference_frame] =
+          reference_info->skip_references[reference_frame]
+              ? 0
+              : kProjectionMvDivisionLookup[relative_distance_to];
+      decoder_state_.reference_frame_sign_bias[reference_frame] =
+          relative_distance_from > 0;
     }
   }
   if (frame_header_.enable_cdf_update &&
@@ -2079,18 +2174,25 @@
   if (!has_sequence_header_) return false;
   if (!ParseFrameParameters()) return false;
   if (frame_header_.show_existing_frame) return true;
+  assert(!obu_headers_.empty());
+  current_frame_->set_spatial_id(obu_headers_.back().spatial_id);
+  current_frame_->set_temporal_id(obu_headers_.back().temporal_id);
   bool status = ParseTileInfoSyntax() && ParseQuantizerParameters() &&
                 ParseSegmentationParameters();
   if (!status) return false;
-  decoder_state_.current_frame->SetSegmentationParameters(
-      frame_header_.segmentation);
+  current_frame_->SetSegmentationParameters(frame_header_.segmentation);
   status =
       ParseQuantizerIndexDeltaParameters() && ParseLoopFilterDeltaParameters();
   if (!status) return false;
   ComputeSegmentLosslessAndQIndex();
+  // Section 6.8.2: It is a requirement of bitstream conformance that
+  // delta_q_present is equal to 0 when CodedLossless is equal to 1.
+  if (frame_header_.coded_lossless && frame_header_.delta_q.present) {
+    return false;
+  }
   status = ParseLoopFilterParameters();
   if (!status) return false;
-  decoder_state_.current_frame->SetLoopFilterDeltas(frame_header_.loop_filter);
+  current_frame_->SetLoopFilterDeltas(frame_header_.loop_filter);
   status = ParseCdefParameters() && ParseLoopRestorationParameters() &&
            ParseTxModeSyntax() && ParseFrameReferenceModeSyntax() &&
            ParseSkipModeParameters() && ReadAllowWarpedMotion();
@@ -2100,22 +2202,212 @@
   frame_header_.reduced_tx_set = static_cast<bool>(scratch);
   status = ParseGlobalMotionParameters();
   if (!status) return false;
-  decoder_state_.current_frame->SetGlobalMotions(frame_header_.global_motion);
+  current_frame_->SetGlobalMotions(frame_header_.global_motion);
   status = ParseFilmGrainParameters();
   if (!status) return false;
   if (sequence_header_.film_grain_params_present) {
-    decoder_state_.current_frame->set_film_grain_params(
-        frame_header_.film_grain_params);
+    current_frame_->set_film_grain_params(frame_header_.film_grain_params);
   }
   return true;
 }
 
-bool ObuParser::ParseMetadata(size_t size) {
+bool ObuParser::ParsePadding(const uint8_t* data, size_t size) {
+  // The spec allows a padding OBU to be header-only (i.e., |size| = 0). So
+  // check trailing bits only if |size| > 0.
+  if (size == 0) return true;
+  // The payload of a padding OBU is byte aligned. Therefore the first
+  // trailing byte should be 0x80. See https://crbug.com/aomedia/2393.
+  const int i = GetLastNonzeroByteIndex(data, size);
+  if (i < 0) {
+    LIBGAV1_DLOG(ERROR, "Trailing bit is missing.");
+    return false;
+  }
+  if (data[i] != 0x80) {
+    LIBGAV1_DLOG(
+        ERROR,
+        "The last nonzero byte of the payload data is 0x%x, should be 0x80.",
+        data[i]);
+    return false;
+  }
+  // Skip all bits before the trailing bit.
+  bit_reader_->SkipBytes(i);
+  return true;
+}
+
+bool ObuParser::ParseMetadataScalability() {
   int64_t scratch;
-  OBU_READ_LITERAL_OR_FAIL(16);
-  size -= 2;
-  const auto type = static_cast<MetadataType>(scratch);
-  switch (type) {
+  // scalability_mode_idc
+  OBU_READ_LITERAL_OR_FAIL(8);
+  const auto scalability_mode_idc = static_cast<int>(scratch);
+  if (scalability_mode_idc == kScalabilitySS) {
+    // Parse scalability_structure().
+    // spatial_layers_cnt_minus_1
+    OBU_READ_LITERAL_OR_FAIL(2);
+    const auto spatial_layers_count = static_cast<int>(scratch) + 1;
+    // spatial_layer_dimensions_present_flag
+    OBU_READ_BIT_OR_FAIL;
+    const auto spatial_layer_dimensions_present_flag =
+        static_cast<bool>(scratch);
+    // spatial_layer_description_present_flag
+    OBU_READ_BIT_OR_FAIL;
+    const auto spatial_layer_description_present_flag =
+        static_cast<bool>(scratch);
+    // temporal_group_description_present_flag
+    OBU_READ_BIT_OR_FAIL;
+    const auto temporal_group_description_present_flag =
+        static_cast<bool>(scratch);
+    // scalability_structure_reserved_3bits
+    OBU_READ_LITERAL_OR_FAIL(3);
+    if (scratch != 0) {
+      LIBGAV1_DLOG(WARNING,
+                   "scalability_structure_reserved_3bits is not zero.");
+    }
+    if (spatial_layer_dimensions_present_flag) {
+      for (int i = 0; i < spatial_layers_count; ++i) {
+        // spatial_layer_max_width[i]
+        OBU_READ_LITERAL_OR_FAIL(16);
+        // spatial_layer_max_height[i]
+        OBU_READ_LITERAL_OR_FAIL(16);
+      }
+    }
+    if (spatial_layer_description_present_flag) {
+      for (int i = 0; i < spatial_layers_count; ++i) {
+        // spatial_layer_ref_id[i]
+        OBU_READ_LITERAL_OR_FAIL(8);
+      }
+    }
+    if (temporal_group_description_present_flag) {
+      // temporal_group_size
+      OBU_READ_LITERAL_OR_FAIL(8);
+      const auto temporal_group_size = static_cast<int>(scratch);
+      for (int i = 0; i < temporal_group_size; ++i) {
+        // temporal_group_temporal_id[i]
+        OBU_READ_LITERAL_OR_FAIL(3);
+        // temporal_group_temporal_switching_up_point_flag[i]
+        OBU_READ_BIT_OR_FAIL;
+        // temporal_group_spatial_switching_up_point_flag[i]
+        OBU_READ_BIT_OR_FAIL;
+        // temporal_group_ref_cnt[i]
+        OBU_READ_LITERAL_OR_FAIL(3);
+        const auto temporal_group_ref_count = static_cast<int>(scratch);
+        for (int j = 0; j < temporal_group_ref_count; ++j) {
+          // temporal_group_ref_pic_diff[i][j]
+          OBU_READ_LITERAL_OR_FAIL(8);
+        }
+      }
+    }
+  }
+  return true;
+}
+
+bool ObuParser::ParseMetadataTimecode() {
+  int64_t scratch;
+  // counting_type: should be the same for all pictures in the coded video
+  // sequence. 7..31 are reserved.
+  OBU_READ_LITERAL_OR_FAIL(5);
+  // full_timestamp_flag
+  OBU_READ_BIT_OR_FAIL;
+  const auto full_timestamp_flag = static_cast<bool>(scratch);
+  // discontinuity_flag
+  OBU_READ_BIT_OR_FAIL;
+  // cnt_dropped_flag
+  OBU_READ_BIT_OR_FAIL;
+  // n_frames
+  OBU_READ_LITERAL_OR_FAIL(9);
+  if (full_timestamp_flag) {
+    // seconds_value
+    OBU_READ_LITERAL_OR_FAIL(6);
+    const auto seconds_value = static_cast<int>(scratch);
+    if (seconds_value > 59) {
+      LIBGAV1_DLOG(ERROR, "Invalid seconds_value %d.", seconds_value);
+      return false;
+    }
+    // minutes_value
+    OBU_READ_LITERAL_OR_FAIL(6);
+    const auto minutes_value = static_cast<int>(scratch);
+    if (minutes_value > 59) {
+      LIBGAV1_DLOG(ERROR, "Invalid minutes_value %d.", minutes_value);
+      return false;
+    }
+    // hours_value
+    OBU_READ_LITERAL_OR_FAIL(5);
+    const auto hours_value = static_cast<int>(scratch);
+    if (hours_value > 23) {
+      LIBGAV1_DLOG(ERROR, "Invalid hours_value %d.", hours_value);
+      return false;
+    }
+  } else {
+    // seconds_flag
+    OBU_READ_BIT_OR_FAIL;
+    const auto seconds_flag = static_cast<bool>(scratch);
+    if (seconds_flag) {
+      // seconds_value
+      OBU_READ_LITERAL_OR_FAIL(6);
+      const auto seconds_value = static_cast<int>(scratch);
+      if (seconds_value > 59) {
+        LIBGAV1_DLOG(ERROR, "Invalid seconds_value %d.", seconds_value);
+        return false;
+      }
+      // minutes_flag
+      OBU_READ_BIT_OR_FAIL;
+      const auto minutes_flag = static_cast<bool>(scratch);
+      if (minutes_flag) {
+        // minutes_value
+        OBU_READ_LITERAL_OR_FAIL(6);
+        const auto minutes_value = static_cast<int>(scratch);
+        if (minutes_value > 59) {
+          LIBGAV1_DLOG(ERROR, "Invalid minutes_value %d.", minutes_value);
+          return false;
+        }
+        // hours_flag
+        OBU_READ_BIT_OR_FAIL;
+        const auto hours_flag = static_cast<bool>(scratch);
+        if (hours_flag) {
+          // hours_value
+          OBU_READ_LITERAL_OR_FAIL(5);
+          const auto hours_value = static_cast<int>(scratch);
+          if (hours_value > 23) {
+            LIBGAV1_DLOG(ERROR, "Invalid hours_value %d.", hours_value);
+            return false;
+          }
+        }
+      }
+    }
+  }
+  // time_offset_length: should be the same for all pictures in the coded
+  // video sequence.
+  OBU_READ_LITERAL_OR_FAIL(5);
+  const auto time_offset_length = static_cast<int>(scratch);
+  if (time_offset_length > 0) {
+    // time_offset_value
+    OBU_READ_LITERAL_OR_FAIL(time_offset_length);
+  }
+  // Compute clockTimestamp. Section 6.7.7:
+  //   When timing_info_present_flag is equal to 1 and discontinuity_flag is
+  //   equal to 0, the value of clockTimestamp shall be greater than or equal
+  //   to the value of clockTimestamp for the previous set of clock timestamp
+  //   syntax elements in output order.
+  return true;
+}
+
+bool ObuParser::ParseMetadata(const uint8_t* data, size_t size) {
+  const size_t start_offset = bit_reader_->byte_offset();
+  size_t metadata_type;
+  if (!bit_reader_->ReadUnsignedLeb128(&metadata_type)) {
+    LIBGAV1_DLOG(ERROR, "Could not read metadata_type.");
+    return false;
+  }
+  const size_t metadata_type_size = bit_reader_->byte_offset() - start_offset;
+  if (size < metadata_type_size) {
+    LIBGAV1_DLOG(
+        ERROR, "metadata_type is longer than metadata OBU payload %zu vs %zu.",
+        metadata_type_size, size);
+    return false;
+  }
+  data += metadata_type_size;
+  size -= metadata_type_size;
+  int64_t scratch;
+  switch (metadata_type) {
     case kMetadataTypeHdrContentLightLevel:
       OBU_READ_LITERAL_OR_FAIL(16);
       metadata_.max_cll = scratch;
@@ -2138,40 +2430,133 @@
       OBU_READ_LITERAL_OR_FAIL(32);
       metadata_.luminance_min = static_cast<uint32_t>(scratch);
       break;
-    default:
-      LIBGAV1_DLOG(ERROR, "Unknown metadata type: %u", type);
-      return false;
+    case kMetadataTypeScalability:
+      if (!ParseMetadataScalability()) return false;
+      break;
+    case kMetadataTypeItutT35: {
+      OBU_READ_LITERAL_OR_FAIL(8);
+      metadata_.itu_t_t35_country_code = static_cast<uint8_t>(scratch);
+      ++data;
+      --size;
+      if (metadata_.itu_t_t35_country_code == 0xFF) {
+        OBU_READ_LITERAL_OR_FAIL(8);
+        metadata_.itu_t_t35_country_code_extension_byte =
+            static_cast<uint8_t>(scratch);
+        ++data;
+        --size;
+      }
+      // Read itu_t_t35_payload_bytes. Section 6.7.2 of the spec says:
+      //   itu_t_t35_payload_bytes shall be bytes containing data registered as
+      //   specified in Recommendation ITU-T T.35.
+      // Therefore itu_t_t35_payload_bytes is byte aligned and the first
+      // trailing byte should be 0x80. Since the exact syntax of
+      // itu_t_t35_payload_bytes is not defined in the AV1 spec, identify the
+      // end of itu_t_t35_payload_bytes by searching for the trailing bit.
+      const int i = GetLastNonzeroByteIndex(data, size);
+      if (i < 0) {
+        LIBGAV1_DLOG(ERROR, "Trailing bit is missing.");
+        return false;
+      }
+      if (data[i] != 0x80) {
+        LIBGAV1_DLOG(
+            ERROR,
+            "itu_t_t35_payload_bytes is not byte aligned. The last nonzero "
+            "byte of the payload data is 0x%x, should be 0x80.",
+            data[i]);
+        return false;
+      }
+      if (i != 0) {
+        // data[0]..data[i - 1] are itu_t_t35_payload_bytes.
+        metadata_.itu_t_t35_payload_bytes.reset(new (std::nothrow) uint8_t[i]);
+        if (metadata_.itu_t_t35_payload_bytes == nullptr) {
+          LIBGAV1_DLOG(ERROR, "Allocation of itu_t_t35_payload_bytes failed.");
+          return false;
+        }
+        memcpy(metadata_.itu_t_t35_payload_bytes.get(), data, i);
+        metadata_.itu_t_t35_payload_size = i;
+      }
+      // Skip all bits before the trailing bit.
+      bit_reader_->SkipBytes(i);
+      break;
+    }
+    case kMetadataTypeTimecode:
+      if (!ParseMetadataTimecode()) return false;
+      break;
+    default: {
+      // metadata_type is equal to a value reserved for future use or a user
+      // private value.
+      //
+      // The Note in Section 5.8.1 says "Decoders should ignore the entire OBU
+      // if they do not understand the metadata_type." Find the trailing bit
+      // and skip all bits before the trailing bit.
+      const int i = GetLastNonzeroByteIndex(data, size);
+      if (i >= 0) {
+        // The last 1 bit in the last nonzero byte is the trailing bit. Skip
+        // all bits before the trailing bit.
+        const int n = CountTrailingZeros(data[i]);
+        bit_reader_->SkipBits(i * 8 + 7 - n);
+      }
+      break;
+    }
   }
   return true;
 }
 
-bool ObuParser::ValidateTileGroup() {
-  const auto& tile_group = tile_groups_.back();
-  if (tile_group.start != next_tile_group_start_ ||
-      tile_group.start > tile_group.end ||
-      tile_group.end >= frame_header_.tile_info.tile_count) {
+bool ObuParser::AddTileBuffers(int start, int end, size_t total_size,
+                               size_t tg_header_size,
+                               size_t bytes_consumed_so_far) {
+  // Validate that the tile group start and end are within the allowed range.
+  if (start != next_tile_group_start_ || start > end ||
+      end >= frame_header_.tile_info.tile_count) {
     LIBGAV1_DLOG(ERROR,
                  "Invalid tile group start %d or end %d: expected tile group "
                  "start %d, tile_count %d.",
-                 tile_group.start, tile_group.end, next_tile_group_start_,
+                 start, end, next_tile_group_start_,
                  frame_header_.tile_info.tile_count);
     return false;
   }
-  next_tile_group_start_ = tile_group.end + 1;
-  return true;
-}
+  next_tile_group_start_ = end + 1;
 
-bool ObuParser::SetTileDataOffset(size_t total_size, size_t tg_header_size,
-                                  size_t bytes_consumed_so_far) {
   if (total_size < tg_header_size) {
     LIBGAV1_DLOG(ERROR, "total_size (%zu) is less than tg_header_size (%zu).)",
                  total_size, tg_header_size);
     return false;
   }
-  auto& tile_group = tile_groups_.back();
-  tile_group.data_size = total_size - tg_header_size;
-  tile_group.data_offset = bytes_consumed_so_far + tg_header_size;
-  tile_group.data = data_ + tile_group.data_offset;
+  size_t bytes_left = total_size - tg_header_size;
+  const uint8_t* data = data_ + bytes_consumed_so_far + tg_header_size;
+  for (int tile_number = start; tile_number <= end; ++tile_number) {
+    size_t tile_size = 0;
+    if (tile_number != end) {
+      RawBitReader bit_reader(data, bytes_left);
+      if (!bit_reader.ReadLittleEndian(frame_header_.tile_info.tile_size_bytes,
+                                       &tile_size)) {
+        LIBGAV1_DLOG(ERROR, "Could not read tile size for tile #%d",
+                     tile_number);
+        return false;
+      }
+      ++tile_size;
+      data += frame_header_.tile_info.tile_size_bytes;
+      bytes_left -= frame_header_.tile_info.tile_size_bytes;
+      if (tile_size > bytes_left) {
+        LIBGAV1_DLOG(ERROR, "Invalid tile size %zu for tile #%d", tile_size,
+                     tile_number);
+        return false;
+      }
+    } else {
+      tile_size = bytes_left;
+      if (tile_size == 0) {
+        LIBGAV1_DLOG(ERROR, "Invalid tile size %zu for tile #%d", tile_size,
+                     tile_number);
+        return false;
+      }
+    }
+    // The memory for this has been allocated in ParseTileInfoSyntax(). So it is
+    // safe to use push_back_unchecked here.
+    tile_buffers_.push_back_unchecked({data, tile_size});
+    data += tile_size;
+    bytes_left -= tile_size;
+  }
+  bit_reader_->SkipBytes(total_size - tg_header_size);
   return true;
 }
 
@@ -2180,29 +2565,19 @@
   const size_t start_offset = bit_reader_->byte_offset();
   const int tile_bits =
       tile_info->tile_columns_log2 + tile_info->tile_rows_log2;
-  if (!tile_groups_.emplace_back()) {
-    LIBGAV1_DLOG(ERROR, "Could not add an element to tile_groups_.");
-    return false;
-  }
-  auto& tile_group = tile_groups_.back();
   if (tile_bits == 0) {
-    tile_group.start = 0;
-    tile_group.end = 0;
-    if (!ValidateTileGroup()) return false;
-    return SetTileDataOffset(size, 0, bytes_consumed_so_far);
+    return AddTileBuffers(0, 0, size, 0, bytes_consumed_so_far);
   }
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
   const auto tile_start_and_end_present_flag = static_cast<bool>(scratch);
   if (!tile_start_and_end_present_flag) {
-    tile_group.start = 0;
-    tile_group.end = tile_info->tile_count - 1;
-    if (!ValidateTileGroup()) return false;
     if (!bit_reader_->AlignToNextByte()) {
       LIBGAV1_DLOG(ERROR, "Byte alignment has non zero bits.");
       return false;
     }
-    return SetTileDataOffset(size, 1, bytes_consumed_so_far);
+    return AddTileBuffers(0, tile_info->tile_count - 1, size, 1,
+                          bytes_consumed_so_far);
   }
   if (obu_headers_.back().type == kObuFrame) {
     // 6.10.1: If obu_type is equal to OBU_FRAME, it is a requirement of
@@ -2213,16 +2588,16 @@
     return false;
   }
   OBU_READ_LITERAL_OR_FAIL(tile_bits);
-  tile_group.start = static_cast<int>(scratch);
+  const int start = static_cast<int>(scratch);
   OBU_READ_LITERAL_OR_FAIL(tile_bits);
-  tile_group.end = static_cast<int>(scratch);
-  if (!ValidateTileGroup()) return false;
+  const int end = static_cast<int>(scratch);
   if (!bit_reader_->AlignToNextByte()) {
     LIBGAV1_DLOG(ERROR, "Byte alignment has non zero bits.");
     return false;
   }
   const size_t tg_header_size = bit_reader_->byte_offset() - start_offset;
-  return SetTileDataOffset(size, tg_header_size, bytes_consumed_so_far);
+  return AddTileBuffers(start, end, size, tg_header_size,
+                        bytes_consumed_so_far);
 }
 
 bool ObuParser::ParseHeader() {
@@ -2237,17 +2612,10 @@
   OBU_READ_BIT_OR_FAIL;
   const auto extension_flag = static_cast<bool>(scratch);
   OBU_READ_BIT_OR_FAIL;
-  const auto has_size_field = static_cast<bool>(scratch);
-  if (!has_size_field) {
-    LIBGAV1_DLOG(
-        ERROR,
-        "has_size_field is zero. libgav1 does not support such streams.");
-    return false;
-  }
+  obu_header.has_size_field = static_cast<bool>(scratch);
   OBU_READ_BIT_OR_FAIL;  // reserved.
   if (scratch != 0) {
-    LIBGAV1_DLOG(ERROR, "obu_reserved_1bit is not zero.");
-    return false;
+    LIBGAV1_DLOG(WARNING, "obu_reserved_1bit is not zero.");
   }
   obu_header.has_extension = extension_flag;
   if (extension_flag) {
@@ -2262,8 +2630,7 @@
     obu_header.spatial_id = scratch;
     OBU_READ_LITERAL_OR_FAIL(3);  // reserved.
     if (scratch != 0) {
-      LIBGAV1_DLOG(ERROR, "extension_header_reserved_3bits is not zero.");
-      return false;
+      LIBGAV1_DLOG(WARNING, "extension_header_reserved_3bits is not zero.");
     }
   } else {
     obu_header.temporal_id = 0;
@@ -2285,17 +2652,22 @@
 
 bool ObuParser::HasData() const { return size_ > 0; }
 
-bool ObuParser::ParseOneFrame() {
-  if (data_ == nullptr || size_ == 0) return false;
+StatusCode ObuParser::ParseOneFrame(RefCountedBufferPtr* const current_frame) {
+  if (data_ == nullptr || size_ == 0) return kStatusInvalidArgument;
+
+  assert(current_frame_ == nullptr);
+  // This is used to release any references held in case of parsing failure.
+  RefCountedBufferPtrCleanup current_frame_cleanup(&current_frame_);
+
   const uint8_t* data = data_;
   size_t size = size_;
 
   // Clear everything except the sequence header.
   obu_headers_.clear();
   frame_header_ = {};
-  tile_groups_.clear();
+  metadata_ = {};
+  tile_buffers_.clear();
   next_tile_group_start_ = 0;
-  // TODO(b/120903866): |metadata_| must be reset here.
 
   bool parsed_one_full_frame = false;
   bool seen_frame_header = false;
@@ -2304,34 +2676,41 @@
   while (size > 0 && !parsed_one_full_frame) {
     if (!InitBitReader(data, size)) {
       LIBGAV1_DLOG(ERROR, "Failed to initialize bit reader.");
-      return false;
+      return kStatusOutOfMemory;
     }
     if (!ParseHeader()) {
       LIBGAV1_DLOG(ERROR, "Failed to parse OBU Header.");
-      return false;
+      return kStatusBitstreamError;
+    }
+    const ObuHeader& obu_header = obu_headers_.back();
+    if (!obu_header.has_size_field) {
+      LIBGAV1_DLOG(
+          ERROR,
+          "has_size_field is zero. libgav1 does not support such streams.");
+      return kStatusUnimplemented;
     }
     const size_t obu_header_size = bit_reader_->byte_offset();
     size_t obu_size;
     if (!bit_reader_->ReadUnsignedLeb128(&obu_size)) {
       LIBGAV1_DLOG(ERROR, "Could not read OBU size.");
-      return false;
+      return kStatusBitstreamError;
     }
     const size_t obu_length_size = bit_reader_->byte_offset() - obu_header_size;
     if (size - bit_reader_->byte_offset() < obu_size) {
       LIBGAV1_DLOG(ERROR, "Not enough bits left to parse OBU %zu vs %zu.",
                    size - bit_reader_->bit_offset(), obu_size);
-      return false;
+      return kStatusBitstreamError;
     }
 
-    const ObuHeader& obu_header = obu_headers_.back();
     const ObuType obu_type = obu_header.type;
     if (obu_type != kObuSequenceHeader && obu_type != kObuTemporalDelimiter &&
         has_sequence_header_ &&
-        sequence_header_.operating_point_idc[kOperatingPoint] != 0 &&
+        sequence_header_.operating_point_idc[operating_point_] != 0 &&
         obu_header.has_extension &&
-        (!InTemporalLayer(sequence_header_.operating_point_idc[kOperatingPoint],
-                          obu_header.temporal_id) ||
-         !InSpatialLayer(sequence_header_.operating_point_idc[kOperatingPoint],
+        (!InTemporalLayer(
+             sequence_header_.operating_point_idc[operating_point_],
+             obu_header.temporal_id) ||
+         !InSpatialLayer(sequence_header_.operating_point_idc[operating_point_],
                          obu_header.spatial_id))) {
       obu_headers_.pop_back();
       bit_reader_->SkipBytes(obu_size);
@@ -2341,6 +2720,10 @@
     }
 
     const size_t obu_start_position = bit_reader_->bit_offset();
+    // The bit_reader_ is byte aligned after reading obu_header and obu_size.
+    // Therefore the byte offset can be computed as obu_start_position >> 3
+    // below.
+    assert((obu_start_position & 7) == 0);
     bool obu_skipped = false;
     switch (obu_type) {
       case kObuTemporalDelimiter:
@@ -2348,18 +2731,25 @@
       case kObuSequenceHeader:
         if (!ParseSequenceHeader(seen_frame_header)) {
           LIBGAV1_DLOG(ERROR, "Failed to parse SequenceHeader OBU.");
-          return false;
+          return kStatusBitstreamError;
+        }
+        if (sequence_header_.color_config.bitdepth > LIBGAV1_MAX_BITDEPTH) {
+          LIBGAV1_DLOG(
+              ERROR,
+              "Bitdepth %d is not supported. The maximum bitdepth is %d.",
+              sequence_header_.color_config.bitdepth, LIBGAV1_MAX_BITDEPTH);
+          return kStatusUnimplemented;
         }
         break;
       case kObuFrameHeader:
         if (seen_frame_header) {
           LIBGAV1_DLOG(ERROR,
                        "Frame header found but frame header was already seen.");
-          return false;
+          return kStatusBitstreamError;
         }
         if (!ParseFrameHeader()) {
           LIBGAV1_DLOG(ERROR, "Failed to parse FrameHeader OBU.");
-          return false;
+          return kStatusBitstreamError;
         }
         frame_header = &data[obu_start_position >> 3];
         frame_header_size_in_bits =
@@ -2372,7 +2762,7 @@
           LIBGAV1_DLOG(ERROR,
                        "Redundant frame header found but frame header was not "
                        "yet seen.");
-          return false;
+          return kStatusBitstreamError;
         }
         const size_t fh_size = (frame_header_size_in_bits + 7) >> 3;
         if (obu_size < fh_size ||
@@ -2380,7 +2770,7 @@
                 0) {
           LIBGAV1_DLOG(ERROR,
                        "Redundant frame header differs from frame header.");
-          return false;
+          return kStatusBitstreamError;
         }
         bit_reader_->SkipBits(frame_header_size_in_bits);
         break;
@@ -2390,35 +2780,34 @@
         if (seen_frame_header) {
           LIBGAV1_DLOG(ERROR,
                        "Frame header found but frame header was already seen.");
-          return false;
+          return kStatusBitstreamError;
         }
         if (!ParseFrameHeader()) {
           LIBGAV1_DLOG(ERROR, "Failed to parse FrameHeader in Frame OBU.");
-          return false;
+          return kStatusBitstreamError;
         }
         // Section 6.8.2: If obu_type is equal to OBU_FRAME, it is a
         // requirement of bitstream conformance that show_existing_frame is
         // equal to 0.
         if (frame_header_.show_existing_frame) {
           LIBGAV1_DLOG(ERROR, "Frame OBU cannot set show_existing_frame to 1.");
-          return false;
+          return kStatusBitstreamError;
         }
         if (!bit_reader_->AlignToNextByte()) {
           LIBGAV1_DLOG(ERROR, "Byte alignment has non zero bits.");
-          return false;
+          return kStatusBitstreamError;
         }
         const size_t fh_size = bit_reader_->byte_offset() - fh_start_offset;
         if (fh_size >= obu_size) {
           LIBGAV1_DLOG(ERROR, "Frame header size (%zu) >= obu_size (%zu).",
                        fh_size, obu_size);
-          return false;
+          return kStatusBitstreamError;
         }
         if (!ParseTileGroup(obu_size - fh_size,
                             size_ - size + bit_reader_->byte_offset())) {
           LIBGAV1_DLOG(ERROR, "Failed to parse TileGroup in Frame OBU.");
-          return false;
+          return kStatusBitstreamError;
         }
-        bit_reader_->SkipBytes(tile_groups_.back().data_size);
         parsed_one_full_frame = true;
         break;
       }
@@ -2426,21 +2815,25 @@
         if (!ParseTileGroup(obu_size,
                             size_ - size + bit_reader_->byte_offset())) {
           LIBGAV1_DLOG(ERROR, "Failed to parse TileGroup OBU.");
-          return false;
+          return kStatusBitstreamError;
         }
-        bit_reader_->SkipBytes(tile_groups_.back().data_size);
         parsed_one_full_frame =
-            (tile_groups_.back().end == frame_header_.tile_info.tile_count - 1);
+            (next_tile_group_start_ == frame_header_.tile_info.tile_count);
         break;
       case kObuTileList:
         LIBGAV1_DLOG(ERROR, "Decoding of tile list OBUs is not supported.");
-        return false;
+        return kStatusUnimplemented;
       case kObuPadding:
-      // TODO(b/120903866): Fix ParseMetadata() and then invoke that for the
-      // kObuMetadata case.
+        if (!ParsePadding(&data[obu_start_position >> 3], obu_size)) {
+          LIBGAV1_DLOG(ERROR, "Failed to parse Padding OBU.");
+          return kStatusBitstreamError;
+        }
+        break;
       case kObuMetadata:
-        bit_reader_->SkipBytes(obu_size);
-        obu_skipped = true;
+        if (!ParseMetadata(&data[obu_start_position >> 3], obu_size)) {
+          LIBGAV1_DLOG(ERROR, "Failed to parse Metadata OBU.");
+          return kStatusBitstreamError;
+        }
         break;
       default:
         // Skip reserved OBUs. Section 6.2.2: Reserved units are for future use
@@ -2459,14 +2852,14 @@
             "Parsed OBU size (%zu bits) is greater than expected OBU size "
             "(%zu bytes) obu_type: %d.",
             parsed_obu_size_in_bits, obu_size, obu_type);
-        return false;
+        return kStatusBitstreamError;
       }
       if (!bit_reader_->VerifyAndSkipTrailingBits(obu_size * 8 -
                                                   parsed_obu_size_in_bits)) {
         LIBGAV1_DLOG(ERROR,
                      "Error when verifying trailing bits for obu type: %d",
                      obu_type);
-        return false;
+        return kStatusBitstreamError;
       }
     }
     const size_t bytes_consumed = bit_reader_->byte_offset();
@@ -2477,18 +2870,19 @@
                    "OBU size (%zu) and consumed size (%zu) does not match for "
                    "obu_type: %d.",
                    obu_size, consumed_obu_size, obu_type);
-      return false;
+      return kStatusBitstreamError;
     }
     data += bytes_consumed;
     size -= bytes_consumed;
   }
   if (!parsed_one_full_frame && seen_frame_header) {
     LIBGAV1_DLOG(ERROR, "The last tile group in the frame was not received.");
-    return false;
+    return kStatusBitstreamError;
   }
   data_ = data;
   size_ = size;
-  return true;
+  *current_frame = std::move(current_frame_);
+  return kStatusOk;
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/obu_parser.h b/libgav1/src/obu_parser.h
index 1671b13..22a2396 100644
--- a/libgav1/src/obu_parser.h
+++ b/libgav1/src/obu_parser.h
@@ -23,8 +23,11 @@
 #include <memory>
 #include <type_traits>
 
-#include "src/decoder_buffer.h"
+#include "src/buffer_pool.h"
+#include "src/decoder_state.h"
 #include "src/dsp/common.h"
+#include "src/gav1/decoder_buffer.h"
+#include "src/gav1/status_code.h"
 #include "src/quantizer.h"
 #include "src/utils/common.h"
 #include "src/utils/compiler_attributes.h"
@@ -35,17 +38,12 @@
 
 namespace libgav1 {
 
-struct DecoderState;
-
 // structs and enums related to Open Bitstream Units (OBU).
 
 enum {
   kMinimumMajorBitstreamLevel = 2,
-  kMaxOperatingPoints = 32,
   kSelectScreenContentTools = 2,
   kSelectIntegerMv = 2,
-  kLoopFilterMaxModeDeltas = 2,
-  kMaxCdefStrengths = 8,
   kLoopRestorationTileSizeMax = 256,
   kGlobalMotionAlphaBits = 12,
   kGlobalMotionTranslationBits = 12,
@@ -55,14 +53,16 @@
   kGlobalMotionTranslationOnlyPrecisionBits = 3,
   kMaxTileWidth = 4096,
   kMaxTileArea = 4096 * 2304,
-  kMaxTileColumns = 64,
-  kMaxTileRows = 64,
-  kPrimaryReferenceNone = 7
+  kPrimaryReferenceNone = 7,
+  // A special value of the scalability_mode_idc syntax element that indicates
+  // the picture prediction structure is specified in scalability_structure().
+  kScalabilitySS = 14
 };  // anonymous enum
 
 struct ObuHeader {
   ObuType type;
   bool has_extension;
+  bool has_size_field;
   int8_t temporal_id;
   int8_t spatial_id;
 };
@@ -87,81 +87,17 @@
   uint8_t minor;  // Range: 0-3.
 };
 
-enum ColorPrimaries : uint8_t {
-  // 0 is reserved.
-  kColorPrimaryBt709 = 1,
-  kColorPrimaryUnspecified,
-  // 3 is reserved.
-  kColorPrimaryBt470M = 4,
-  kColorPrimaryBt470Bg,
-  kColorPrimaryBt601,
-  kColorPrimarySmpte240,
-  kColorPrimaryGenericFilm,
-  kColorPrimaryBt2020,
-  kColorPrimaryXyz,
-  kColorPrimarySmpte431,
-  kColorPrimarySmpte432,
-  // 13-21 are reserved.
-  kColorPrimaryEbu3213 = 22,
-  // 23-254 are reserved.
-  kMaxColorPrimaries = 255
-};
-
-enum TransferCharacteristics : uint8_t {
-  // 0 is reserved.
-  kTransferCharacteristicBt709 = 1,
-  kTransferCharacteristicUnspecified,
-  // 3 is reserved.
-  kTransferCharacteristicBt470M = 4,
-  kTransferCharacteristicBt470Bg,
-  kTransferCharacteristicBt601,
-  kTransferCharacteristicSmpte240,
-  kTransferCharacteristicLinear,
-  kTransferCharacteristicLog100,
-  kTransferCharacteristicLog100Sqrt10,
-  kTransferCharacteristicIec61966,
-  kTransferCharacteristicBt1361,
-  kTransferCharacteristicSrgb,
-  kTransferCharacteristicBt2020TenBit,
-  kTransferCharacteristicBt2020TwelveBit,
-  kTransferCharacteristicSmpte2084,
-  kTransferCharacteristicSmpte428,
-  kTransferCharacteristicHlg,
-  // 19-254 are reserved.
-  kMaxTransferCharacteristics = 255
-};
-
-enum MatrixCoefficients : uint8_t {
-  kMatrixCoefficientIdentity,
-  kMatrixCoefficientBt709,
-  kMatrixCoefficientUnspecified,
-  // 3 is reserved.
-  kMatrixCoefficientFcc = 4,
-  kMatrixCoefficientBt470BG,
-  kMatrixCoefficientBt601,
-  kMatrixCoefficientSmpte240,
-  kMatrixCoefficientSmpteYcgco,
-  kMatrixCoefficientBt2020Ncl,
-  kMatrixCoefficientBt2020Cl,
-  kMatrixCoefficientSmpte2085,
-  kMatrixCoefficientChromatNcl,
-  kMatrixCoefficientChromatCl,
-  kMatrixCoefficientIctcp,
-  // 15-254 are reserved.
-  kMaxMatrixCoefficients = 255
-};
-
 struct ColorConfig {
   int8_t bitdepth;
   bool is_monochrome;
-  ColorPrimaries color_primaries;
+  ColorPrimary color_primary;
   TransferCharacteristics transfer_characteristics;
   MatrixCoefficients matrix_coefficients;
   // A binary value (0 or 1) that is associated with the VideoFullRangeFlag
   // variable specified in ISO/IEC 23091-4/ITUT H.273.
   // * 0: the studio swing representation.
   // * 1: the full swing representation.
-  int8_t color_range;
+  ColorRange color_range;
   int8_t subsampling_x;
   int8_t subsampling_y;
   ChromaSamplePosition chroma_sample_position;
@@ -227,6 +163,9 @@
   // If enable_order_hint is true, order_hint_bits is in the range [1, 8].
   // If enable_order_hint is false, order_hint_bits is 0.
   int8_t order_hint_bits;
+  // order_hint_shift_bits equals (32 - order_hint_bits) % 32.
+  // This is used frequently in GetRelativeDistance().
+  uint8_t order_hint_shift_bits;
   bool enable_jnt_comp;
   bool enable_ref_frame_mvs;
   bool choose_screen_content_tools;
@@ -251,7 +190,6 @@
   // call.
   OperatingParameters operating_parameters;
 };
-
 // Verify it is safe to use offsetof with ObuSequenceHeader and to use memcmp
 // to compare two ObuSequenceHeader objects.
 static_assert(std::is_standard_layout<ObuSequenceHeader>::value, "");
@@ -266,189 +204,20 @@
                       sizeof(OperatingParameters),
               "");
 
-// Loop filter parameters:
-//
-// If level[0] and level[1] are both equal to 0, the loop filter process is
-// not invoked.
-//
-// |sharpness| and |delta_enabled| are only used by the loop filter process.
-//
-// The |ref_deltas| and |mode_deltas| arrays are used not only by the loop
-// filter process but also by the reference frame update and loading
-// processes. The loop filter process uses |ref_deltas| and |mode_deltas| only
-// when |delta_enabled| is true.
-struct LoopFilter {
-  // Contains loop filter strength values in the range of [0, 63].
-  std::array<int8_t, kFrameLfCount> level;
-  // Indicates the sharpness level in the range of [0, 7].
-  int8_t sharpness;
-  // Whether the filter level depends on the mode and reference frame used to
-  // predict a block.
-  bool delta_enabled;
-  // Contains the adjustment needed for the filter level based on the chosen
-  // reference frame, in the range of [-64, 63].
-  std::array<int8_t, kNumReferenceFrameTypes> ref_deltas;
-  // Contains the adjustment needed for the filter level based on the chosen
-  // mode, in the range of [-64, 63].
-  std::array<int8_t, kLoopFilterMaxModeDeltas> mode_deltas;
-};
-
-struct Delta {
-  bool present;
-  uint8_t scale;
-  bool multi;
-};
-
-struct Cdef {
-  uint8_t damping;
-  uint8_t bits;
-  uint8_t y_primary_strength[kMaxCdefStrengths];
-  uint8_t y_secondary_strength[kMaxCdefStrengths];
-  uint8_t uv_primary_strength[kMaxCdefStrengths];
-  uint8_t uv_secondary_strength[kMaxCdefStrengths];
-};
-
-enum GlobalMotionTransformationType : uint8_t {
-  kGlobalMotionTransformationTypeIdentity,
-  kGlobalMotionTransformationTypeTranslation,
-  kGlobalMotionTransformationTypeRotZoom,
-  kGlobalMotionTransformationTypeAffine,
-  kNumGlobalMotionTransformationTypes
-};
-
-// Global motion and warped motion parameters. See the paper for more info:
-// S. Parker, Y. Chen, D. Barker, P. de Rivaz, D. Mukherjee, "Global and locally
-// adaptive warped motion compensation in video compression", Proc. IEEE
-// International Conference on Image Processing (ICIP), pp. 275-279, Sep. 2017.
-struct GlobalMotion {
-  GlobalMotionTransformationType type;
-  int32_t params[6];
-
-  // Represent two shearing operations. Computed from |params| by SetupShear().
-  //
-  // The least significant six (= kWarpParamRoundingBits) bits are all zeros.
-  // (This means alpha, beta, gamma, and delta could be represented by a 10-bit
-  // signed integer.) The minimum value is INT16_MIN (= -32768) and the maximum
-  // value is 32704 = 0x7fc0, the largest int16_t value whose least significant
-  // six bits are all zeros.
-  //
-  // Valid warp parameters (as validated by SetupShear()) have smaller ranges.
-  // Their absolute values are less than 2^14 (= 16384). (This follows from
-  // the warpValid check at the end of Section 7.11.3.6.)
-  //
-  // NOTE: Section 7.11.3.6 of the spec allows a maximum value of 32768, which
-  // is outside the range of int16_t. When cast to int16_t, 32768 becomes
-  // -32768. This potential int16_t overflow does not matter because either
-  // 32768 or -32768 causes SetupShear() to return false,
-  int16_t alpha;
-  int16_t beta;
-  int16_t gamma;
-  int16_t delta;
-};
-
-struct TileInfo {
-  bool uniform_spacing;
-  int sb_rows;
-  int sb_columns;
-  int tile_count;
-  int tile_columns_log2;
-  int tile_columns;
-  int tile_column_start[kMaxTileColumns + 1];
-  int tile_rows_log2;
-  int tile_rows;
-  int tile_row_start[kMaxTileRows + 1];
-  int16_t context_update_id;
-  uint8_t tile_size_bytes;
-};
-
-struct ObuFrameHeader {
-  uint16_t display_frame_id;
-  uint16_t current_frame_id;
-  int64_t frame_offset;
-  uint16_t expected_frame_id[kNumInterReferenceFrameTypes];
-  int32_t width;
-  int32_t height;
-  int32_t columns4x4;
-  int32_t rows4x4;
-  // The render size (render_width and render_height) is a hint to the
-  // application about the desired display size. It has no effect on the
-  // decoding process.
-  int32_t render_width;
-  int32_t render_height;
-  int32_t upscaled_width;
-  LoopRestoration loop_restoration;
-  uint32_t buffer_removal_time[kMaxOperatingPoints];
-  uint32_t frame_presentation_time;
-  // Note: global_motion[0] (for kReferenceFrameIntra) is not used.
-  std::array<GlobalMotion, kNumReferenceFrameTypes> global_motion;
-  TileInfo tile_info;
-  QuantizerParameters quantizer;
-  Segmentation segmentation;
-  bool show_existing_frame;
-  // frame_to_show is in the range [0, 7]. Only used if show_existing_frame is
-  // true.
-  int8_t frame_to_show;
-  FrameType frame_type;
-  bool show_frame;
-  bool showable_frame;
-  bool error_resilient_mode;
-  bool enable_cdf_update;
-  bool frame_size_override_flag;
-  // The order_hint syntax element in the uncompressed header. If
-  // show_existing_frame is false, the OrderHint variable in the spec is equal
-  // to this field, and so this field can be used in place of OrderHint when
-  // show_existing_frame is known to be false, such as during tile decoding.
-  uint8_t order_hint;
-  int8_t primary_reference_frame;
-  bool render_and_frame_size_different;
-  uint8_t superres_scale_denominator;
-  bool allow_screen_content_tools;
-  bool allow_intrabc;
-  bool frame_refs_short_signaling;
-  // A bitmask that specifies which reference frame slots will be updated with
-  // the current frame after it is decoded.
-  uint8_t refresh_frame_flags;
-  static_assert(sizeof(ObuFrameHeader::refresh_frame_flags) * 8 ==
-                    kNumReferenceFrameTypes,
-                "");
-  bool found_reference;
-  int8_t force_integer_mv;
-  bool allow_high_precision_mv;
-  InterpolationFilter interpolation_filter;
-  bool is_motion_mode_switchable;
-  bool use_ref_frame_mvs;
-  bool enable_frame_end_update_cdf;
-  // True if all segments are losslessly encoded at the coded resolution.
-  bool coded_lossless;
-  // True if all segments are losslessly encoded at the upscaled resolution.
-  bool upscaled_lossless;
-  TxMode tx_mode;
-  // True means that the mode info for inter blocks contains the syntax
-  // element comp_mode that indicates whether to use single or compound
-  // prediction. False means that all inter blocks will use single prediction.
-  bool reference_mode_select;
-  // The frames to use for compound prediction when skip_mode is true.
-  ReferenceFrameType skip_mode_frame[2];
-  bool skip_mode_present;
-  bool reduced_tx_set;
-  bool allow_warped_motion;
-  Delta delta_q;
-  Delta delta_lf;
-  // A valid value of reference_frame_index[i] is in the range [0, 7]. -1
-  // indicates an invalid value.
-  int8_t reference_frame_index[kNumInterReferenceFrameTypes];
-  // The ref_order_hint[ i ] syntax element in the uncompressed header.
-  // Specifies the expected output order hint for each reference frame.
-  uint8_t reference_order_hint[kNumReferenceFrameTypes];
-  LoopFilter loop_filter;
-  Cdef cdef;
-  FilmGrainParams film_grain_params;
+struct TileBuffer {
+  const uint8_t* data;
+  size_t size;
 };
 
 enum MetadataType : uint8_t {
   // 0 is reserved for AOM use.
   kMetadataTypeHdrContentLightLevel = 1,
-  kMetadataTypeHdrMasteringDisplayColorVolume
+  kMetadataTypeHdrMasteringDisplayColorVolume = 2,
+  kMetadataTypeScalability = 3,
+  kMetadataTypeItutT35 = 4,
+  kMetadataTypeTimecode = 5,
+  // 6-31 are unregistered user private.
+  // 32 and greater are reserved for AOM use.
 };
 
 struct ObuMetadata {
@@ -462,28 +231,28 @@
   uint16_t white_point_chromaticity_y;
   uint32_t luminance_max;
   uint32_t luminance_min;
-};
-
-struct ObuTileGroup {
-  int start;
-  int end;
-  // Pointer to the start of Tile Group data.
-  const uint8_t* data;
-  // Size of the Tile Group data (excluding the Tile Group headers).
-  size_t data_size;
-  // Offset of the start of Tile Group data relative to |ObuParser->data_|.
-  size_t data_offset;
+  // ITU-T T.35.
+  uint8_t itu_t_t35_country_code;
+  uint8_t itu_t_t35_country_code_extension_byte;  // Valid if
+                                                  // itu_t_t35_country_code is
+                                                  // 0xFF.
+  std::unique_ptr<uint8_t[]> itu_t_t35_payload_bytes;
+  size_t itu_t_t35_payload_size;
 };
 
 class ObuParser : public Allocable {
  public:
-  ObuParser(const uint8_t* const data, size_t size,
-            DecoderState* const decoder_state)
-      : data_(data), size_(size), decoder_state_(*decoder_state) {}
+  ObuParser(const uint8_t* const data, size_t size, int operating_point,
+            BufferPool* const buffer_pool, DecoderState* const decoder_state)
+      : data_(data),
+        size_(size),
+        operating_point_(operating_point),
+        buffer_pool_(buffer_pool),
+        decoder_state_(*decoder_state) {}
 
-  // Copyable and Movable.
-  ObuParser(const ObuParser& rhs) = default;
-  ObuParser& operator=(const ObuParser& rhs) = default;
+  // Not copyable or movable.
+  ObuParser(const ObuParser& rhs) = delete;
+  ObuParser& operator=(const ObuParser& rhs) = delete;
 
   // Returns true if there is more data that needs to be parsed.
   bool HasData() const;
@@ -496,16 +265,17 @@
   //   * A kFrameHeader with show_existing_frame = true is seen.
   //
   // If the parsing is successful, relevant fields will be populated. The fields
-  // are valid only if the return value is true. Returns true on success, false
-  // otherwise.
-  bool ParseOneFrame();
+  // are valid only if the return value is kStatusOk. Returns kStatusOk on
+  // success, an error status otherwise. On success, |current_frame| will be
+  // populated with a valid frame buffer.
+  StatusCode ParseOneFrame(RefCountedBufferPtr* current_frame);
 
   // Getters. Only valid if ParseOneFrame() completes successfully.
   const Vector<ObuHeader>& obu_headers() const { return obu_headers_; }
   const ObuSequenceHeader& sequence_header() const { return sequence_header_; }
   const ObuFrameHeader& frame_header() const { return frame_header_; }
+  const Vector<TileBuffer>& tile_buffers() const { return tile_buffers_; }
   const ObuMetadata& metadata() const { return metadata_; }
-  const Vector<ObuTileGroup>& tile_groups() const { return tile_groups_; }
 
   // Setters.
   void set_sequence_header(const ObuSequenceHeader& sequence_header) {
@@ -513,6 +283,11 @@
     has_sequence_header_ = true;
   }
 
+  // Moves |tile_buffers_| into |tile_buffers|.
+  void MoveTileBuffer(Vector<TileBuffer>* tile_buffers) {
+    *tile_buffers = std::move(tile_buffers_);
+  }
+
  private:
   // Initializes the bit reader. This is a function of its own to make unit
   // testing of private functions simpler.
@@ -524,11 +299,11 @@
   bool ParseTimingInfo(ObuSequenceHeader* sequence_header);        // 5.5.3.
   bool ParseDecoderModelInfo(ObuSequenceHeader* sequence_header);  // 5.5.4.
   bool ParseOperatingParameters(ObuSequenceHeader* sequence_header,
-                                int index);  // 5.5.5.
+                                int index);          // 5.5.5.
   bool ParseSequenceHeader(bool seen_frame_header);  // 5.5.1.
-  bool ParseFrameParameters();               // 5.9.2, 5.9.7 and 5.9.10.
-  void MarkInvalidReferenceFrames();         // 5.9.4.
-  bool ParseFrameSizeAndRenderSize();        // 5.9.5 and 5.9.6.
+  bool ParseFrameParameters();                       // 5.9.2, 5.9.7 and 5.9.10.
+  void MarkInvalidReferenceFrames();                 // 5.9.4.
+  bool ParseFrameSizeAndRenderSize();                // 5.9.5 and 5.9.6.
   bool ParseSuperResParametersAndComputeImageSize();  // 5.9.8 and 5.9.9.
   // Checks the bitstream conformance requirement in Section 6.8.6.
   bool ValidateInterFrameSize() const;
@@ -573,27 +348,38 @@
   bool ParseFilmGrainParameters();     // 5.9.30.
   bool ParseTileInfoSyntax();          // 5.9.15.
   bool ParseFrameHeader();             // 5.9.
-  bool ParseMetadata(size_t size);     // 5.8.
-  // Validates the |start| and |end| fields of the current tile group. If
-  // valid, updates next_tile_group_start_ and returns true. Otherwise,
-  // returns false.
-  bool ValidateTileGroup();
-  bool SetTileDataOffset(size_t total_size, size_t tg_header_size,
-                         size_t bytes_consumed_so_far);
+  // |data| and |size| specify the payload data of the padding OBU.
+  // NOTE: Although the payload data is available in the bit_reader_ member,
+  // it is also passed to ParsePadding() as function parameters so that
+  // ParsePadding() can find the trailing bit of the OBU and skip over the
+  // payload data as an opaque chunk of data.
+  bool ParsePadding(const uint8_t* data, size_t size);  // 5.7.
+  bool ParseMetadataScalability();                      // 5.8.5 and 5.8.6.
+  bool ParseMetadataTimecode();                         // 5.8.7.
+  // |data| and |size| specify the payload data of the metadata OBU.
+  // NOTE: Although the payload data is available in the bit_reader_ member,
+  // it is also passed to ParseMetadata() as function parameters so that
+  // ParseMetadata() can find the trailing bit of the OBU and either extract
+  // or skip over the payload data as an opaque chunk of data.
+  bool ParseMetadata(const uint8_t* data, size_t size);  // 5.8.
+  // Adds and populates the TileBuffer for each tile in the tile group.
+  bool AddTileBuffers(int start, int end, size_t total_size,
+                      size_t tg_header_size, size_t bytes_consumed_so_far);
   bool ParseTileGroup(size_t size, size_t bytes_consumed_so_far);  // 5.11.1.
 
   // Parser elements.
   std::unique_ptr<RawBitReader> bit_reader_;
   const uint8_t* data_;
   size_t size_;
+  const int operating_point_;
 
   // OBU elements. Only valid if ParseOneFrame() completes successfully.
   Vector<ObuHeader> obu_headers_;
   ObuSequenceHeader sequence_header_ = {};
   ObuFrameHeader frame_header_ = {};
+  Vector<TileBuffer> tile_buffers_;
   ObuMetadata metadata_ = {};
-  Vector<ObuTileGroup> tile_groups_;
-  // The expected |start| value of the next ObuTileGroup.
+  // The expected starting tile number of the next Tile Group.
   int next_tile_group_start_ = 0;
   // If true, the sequence_header_ field is valid.
   bool has_sequence_header_ = false;
@@ -601,7 +387,14 @@
   // 0. Set to true when parsing a sequence header if OperatingPointIdc is 0.
   bool extension_disallowed_ = false;
 
+  BufferPool* const buffer_pool_;
   DecoderState& decoder_state_;
+  // Used by ParseOneFrame() to populate the current frame that is being
+  // decoded. The invariant maintained is that this variable will be nullptr at
+  // the beginning and at the end of each call to ParseOneFrame(). This ensures
+  // that the ObuParser is not holding on to any references to the current
+  // frame once the ParseOneFrame() call is complete.
+  RefCountedBufferPtr current_frame_;
 
   // For unit testing private functions.
   friend class ObuParserTest;
diff --git a/libgav1/src/post_filter.cc b/libgav1/src/post_filter.cc
deleted file mode 100644
index 945a906..0000000
--- a/libgav1/src/post_filter.cc
+++ /dev/null
@@ -1,1593 +0,0 @@
-// Copyright 2019 The libgav1 Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "src/post_filter.h"
-
-#include <algorithm>
-#include <atomic>
-#include <cassert>
-#include <cstddef>
-#include <cstdint>
-#include <cstring>
-#include <memory>
-
-#include "src/dsp/constants.h"
-#include "src/utils/array_2d.h"
-#include "src/utils/blocking_counter.h"
-#include "src/utils/constants.h"
-#include "src/utils/logging.h"
-#include "src/utils/memory.h"
-#include "src/utils/types.h"
-
-namespace libgav1 {
-namespace {
-
-constexpr uint8_t kCdefUvDirection[2][2][8] = {
-    {{0, 1, 2, 3, 4, 5, 6, 7}, {1, 2, 2, 2, 3, 4, 6, 0}},
-    {{7, 0, 2, 4, 5, 6, 6, 6}, {0, 1, 2, 3, 4, 5, 6, 7}}};
-
-template <typename Pixel>
-void ExtendFrame(uint8_t* const frame_start, const int width, const int height,
-                 ptrdiff_t stride, const int left, const int right,
-                 const int top, const int bottom) {
-  auto* const start = reinterpret_cast<Pixel*>(frame_start);
-  const Pixel* src = start;
-  Pixel* dst = start - left;
-  stride /= sizeof(Pixel);
-  // Copy to left and right borders.
-  for (int y = 0; y < height; ++y) {
-    Memset(dst, src[0], left);
-    Memset(dst + (left + width), src[width - 1], right);
-    src += stride;
-    dst += stride;
-  }
-  // Copy to top borders.
-  src = start - left;
-  dst = start - left - top * stride;
-  for (int y = 0; y < top; ++y) {
-    memcpy(dst, src, sizeof(Pixel) * stride);
-    dst += stride;
-  }
-  // Copy to bottom borders.
-  dst = start - left + height * stride;
-  src = dst - stride;
-  for (int y = 0; y < bottom; ++y) {
-    memcpy(dst, src, sizeof(Pixel) * stride);
-    dst += stride;
-  }
-}
-
-template <typename Pixel>
-void CopyPlane(const uint8_t* source, int source_stride, const int width,
-               const int height, uint8_t* dest, int dest_stride) {
-  auto* dst = reinterpret_cast<Pixel*>(dest);
-  const auto* src = reinterpret_cast<const Pixel*>(source);
-  source_stride /= sizeof(Pixel);
-  dest_stride /= sizeof(Pixel);
-  for (int y = 0; y < height; ++y) {
-    memcpy(dst, src, width * sizeof(Pixel));
-    src += source_stride;
-    dst += dest_stride;
-  }
-}
-
-template <int bitdepth, typename Pixel>
-void ComputeSuperRes(const uint8_t* source, uint32_t source_stride,
-                     const int upscaled_width, const int height,
-                     const int initial_subpixel_x, const int step,
-                     uint8_t* dest, uint32_t dest_stride) {
-  const auto* src = reinterpret_cast<const Pixel*>(source);
-  auto* dst = reinterpret_cast<Pixel*>(dest);
-  source_stride /= sizeof(Pixel);
-  dest_stride /= sizeof(Pixel);
-  src -= DivideBy2(kSuperResFilterTaps);
-  for (int y = 0; y < height; ++y) {
-    int subpixel_x = initial_subpixel_x;
-    for (int x = 0; x < upscaled_width; ++x) {
-      int sum = 0;
-      const Pixel* const src_x = &src[subpixel_x >> kSuperResScaleBits];
-      const int src_x_subpixel =
-          (subpixel_x & kSuperResScaleMask) >> kSuperResExtraBits;
-      for (int i = 0; i < kSuperResFilterTaps; ++i) {
-        sum += src_x[i] * kUpscaleFilter[src_x_subpixel][i];
-      }
-      dst[x] = Clip3(RightShiftWithRounding(sum, kFilterBits), 0,
-                     (1 << bitdepth) - 1);
-      subpixel_x += step;
-    }
-    src += source_stride;
-    dst += dest_stride;
-  }
-}
-
-}  // namespace
-
-// Static data member definitions.
-constexpr int PostFilter::kCdefLargeValue;
-
-bool PostFilter::ApplyFiltering() {
-  if (DoDeblock() && !ApplyDeblockFilter()) return false;
-  if (DoCdef() && !ApplyCdef()) return false;
-  if (DoSuperRes() && !ApplySuperRes()) return false;
-  if (DoRestoration() && !ApplyLoopRestoration()) return false;
-  // Extend frame boundary for inter frame convolution, referencing.
-  for (int plane = kPlaneY; plane < planes_; ++plane) {
-    const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-    const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-    const int plane_width =
-        RightShiftWithRounding(upscaled_width_, subsampling_x);
-    const int plane_height = RightShiftWithRounding(height_, subsampling_y);
-    assert(source_buffer_->left_border(plane) >= kMinLeftBorderPixels &&
-           source_buffer_->right_border(plane) >= kMinRightBorderPixels);
-    ExtendFrameBoundary(
-        source_buffer_->data(plane), plane_width, plane_height,
-        source_buffer_->stride(plane), source_buffer_->left_border(plane),
-        source_buffer_->right_border(plane), source_buffer_->top_border(plane),
-        source_buffer_->bottom_border(plane));
-  }
-  return true;
-}
-
-bool PostFilter::DoRestoration() const {
-  return DoRestoration(loop_restoration_, do_post_filter_mask_, planes_);
-}
-
-bool PostFilter::DoRestoration(const LoopRestoration& loop_restoration,
-                               uint8_t do_post_filter_mask, int num_planes) {
-  if ((do_post_filter_mask & 0x08) == 0) return false;
-  if (num_planes == kMaxPlanesMonochrome) {
-    return loop_restoration.type[kPlaneY] != kLoopRestorationTypeNone;
-  }
-  return loop_restoration.type[kPlaneY] != kLoopRestorationTypeNone ||
-         loop_restoration.type[kPlaneU] != kLoopRestorationTypeNone ||
-         loop_restoration.type[kPlaneV] != kLoopRestorationTypeNone;
-}
-
-void PostFilter::ExtendFrameBoundary(uint8_t* const frame_start,
-                                     const int width, const int height,
-                                     const ptrdiff_t stride, const int left,
-                                     const int right, const int top,
-                                     const int bottom) {
-  if (bitdepth_ == 8) {
-    ExtendFrame<uint8_t>(frame_start, width, height, stride, left, right, top,
-                         bottom);
-  } else {
-    ExtendFrame<uint16_t>(frame_start, width, height, stride, left, right, top,
-                          bottom);
-  }
-}
-
-void PostFilter::DeblockFilterWorker(const DeblockFilterJob* jobs, int num_jobs,
-                                     std::atomic<int>* job_counter,
-                                     DeblockFilter deblock_filter) {
-  int job_index;
-  while ((job_index = job_counter->fetch_add(1, std::memory_order_relaxed)) <
-         num_jobs) {
-    const DeblockFilterJob& job = jobs[job_index];
-    for (int column4x4 = 0, column_unit = 0;
-         column4x4 < frame_header_.columns4x4;
-         column4x4 += kNum4x4InLoopFilterMaskUnit, ++column_unit) {
-      const int unit_id = GetDeblockUnitId(job.row_unit, column_unit);
-      (this->*deblock_filter)(static_cast<Plane>(job.plane), job.row4x4,
-                              column4x4, unit_id);
-    }
-  }
-}
-
-bool PostFilter::ApplyDeblockFilterThreaded() {
-  const int jobs_per_plane = DivideBy16(frame_header_.rows4x4 + 15);
-  const int num_workers = thread_pool_->num_threads();
-  int planes[kMaxPlanes];
-  planes[0] = kPlaneY;
-  int num_planes = 1;
-  for (int plane = kPlaneU; plane < planes_; ++plane) {
-    if (frame_header_.loop_filter.level[plane + 1] != 0) {
-      planes[num_planes++] = plane;
-    }
-  }
-  const int num_jobs = num_planes * jobs_per_plane;
-  std::unique_ptr<DeblockFilterJob[]> jobs_unique_ptr(
-      new (std::nothrow) DeblockFilterJob[num_jobs]);
-  if (jobs_unique_ptr == nullptr) return false;
-  DeblockFilterJob* jobs = jobs_unique_ptr.get();
-  // The vertical filters are not dependent on each other. So simply schedule
-  // them for all possible rows.
-  //
-  // The horizontal filter for a row/column depends on the vertical filter being
-  // finished for the blocks to the top right and to the right. To work around
-  // this synchronization, we simply wait for the vertical filter to finish for
-  // all rows. Now, the horizontal filters can also be scheduled
-  // unconditionally similar to the vertical filters.
-  //
-  // The only synchronization involved is to know when the each directional
-  // filter is complete for the entire frame.
-  for (DeblockFilter deblock_filter : {&PostFilter::VerticalDeblockFilter,
-                                       &PostFilter::HorizontalDeblockFilter}) {
-    int job_index = 0;
-    for (int i = 0; i < num_planes; ++i) {
-      const int plane = planes[i];
-      for (int row4x4 = 0, row_unit = 0; row4x4 < frame_header_.rows4x4;
-           row4x4 += kNum4x4InLoopFilterMaskUnit, ++row_unit) {
-        assert(job_index < num_jobs);
-        DeblockFilterJob& job = jobs[job_index++];
-        job.plane = plane;
-        job.row4x4 = row4x4;
-        job.row_unit = row_unit;
-      }
-    }
-    assert(job_index == num_jobs);
-    std::atomic<int> job_counter(0);
-    BlockingCounter pending_workers(num_workers);
-    for (int i = 0; i < num_workers; ++i) {
-      thread_pool_->Schedule([this, jobs, num_jobs, &job_counter,
-                              deblock_filter, &pending_workers]() {
-        DeblockFilterWorker(jobs, num_jobs, &job_counter, deblock_filter);
-        pending_workers.Decrement();
-      });
-    }
-    // Run the jobs on the current thread.
-    DeblockFilterWorker(jobs, num_jobs, &job_counter, deblock_filter);
-    // Wait for the threadpool jobs to finish.
-    pending_workers.Wait();
-  }
-  return true;
-}
-
-bool PostFilter::ApplyDeblockFilter() {
-  InitDeblockFilterParams();
-
-  if (thread_pool_ != nullptr) {
-    return ApplyDeblockFilterThreaded();
-  }
-
-  for (int plane = kPlaneY; plane < planes_; ++plane) {
-    if (plane != kPlaneY && frame_header_.loop_filter.level[plane + 1] == 0) {
-      continue;
-    }
-
-    // Iterate through each 64x64 block and apply deblock filtering.
-    for (int row4x4 = 0, row_unit = 0; row4x4 < frame_header_.rows4x4;
-         row4x4 += kNum4x4InLoopFilterMaskUnit, ++row_unit) {
-      int column4x4;
-      int column_unit;
-      for (column4x4 = 0, column_unit = 0; column4x4 < frame_header_.columns4x4;
-           column4x4 += kNum4x4InLoopFilterMaskUnit, ++column_unit) {
-        // First apply vertical filtering
-        const int unit_id = GetDeblockUnitId(row_unit, column_unit);
-        VerticalDeblockFilter(static_cast<Plane>(plane), row4x4, column4x4,
-                              unit_id);
-
-        // Delay one superblock to apply horizontal filtering.
-        if (column4x4 != 0) {
-          HorizontalDeblockFilter(static_cast<Plane>(plane), row4x4,
-                                  column4x4 - kNum4x4InLoopFilterMaskUnit,
-                                  unit_id - 1);
-        }
-      }
-      // Horizontal filtering for the last 64x64 block.
-      const int unit_id = GetDeblockUnitId(row_unit, column_unit - 1);
-      HorizontalDeblockFilter(static_cast<Plane>(plane), row4x4,
-                              column4x4 - kNum4x4InLoopFilterMaskUnit, unit_id);
-    }
-  }
-  return true;
-}
-
-void PostFilter::ComputeDeblockFilterLevels(
-    const int8_t delta_lf[kFrameLfCount],
-    uint8_t deblock_filter_levels[kMaxSegments][kFrameLfCount]
-                                 [kNumReferenceFrameTypes][2]) const {
-  if (!DoDeblock()) return;
-  for (int segment_id = 0;
-       segment_id < (frame_header_.segmentation.enabled ? kMaxSegments : 1);
-       ++segment_id) {
-    int level_index = 0;
-    for (; level_index < 2; ++level_index) {
-      LoopFilterMask::ComputeDeblockFilterLevels(
-          frame_header_, segment_id, level_index, delta_lf,
-          deblock_filter_levels[segment_id][level_index]);
-    }
-    for (; level_index < kFrameLfCount; ++level_index) {
-      if (frame_header_.loop_filter.level[level_index] != 0) {
-        LoopFilterMask::ComputeDeblockFilterLevels(
-            frame_header_, segment_id, level_index, delta_lf,
-            deblock_filter_levels[segment_id][level_index]);
-      }
-    }
-  }
-}
-
-uint8_t* PostFilter::GetCdefBufferAndStride(
-    const int start_x, const int start_y, const int plane,
-    const int subsampling_x, const int subsampling_y,
-    const int window_buffer_plane_size, const int vertical_shift,
-    const int horizontal_shift, int* cdef_stride) {
-  if (!DoRestoration() && thread_pool_ != nullptr) {
-    // write output to threaded_window_buffer.
-    *cdef_stride = window_buffer_width_ * pixel_size_;
-    const int column_window = start_x % (window_buffer_width_ >> subsampling_x);
-    const int row_window = start_y % (window_buffer_height_ >> subsampling_y);
-    return threaded_window_buffer_ + plane * window_buffer_plane_size +
-           row_window * (*cdef_stride) + column_window * pixel_size_;
-  }
-  // write output to cdef_buffer_.
-  *cdef_stride = cdef_buffer_->stride(plane);
-  // In-place cdef is applied by writing the output to the top-left
-  // corner, if restoration is not present. In this case,
-  // cdef_buffer_ == source_buffer_.
-  const ptrdiff_t buffer_offset =
-      DoRestoration()
-          ? 0
-          : vertical_shift * (*cdef_stride) + horizontal_shift * pixel_size_;
-  return cdef_buffer_->data(plane) + start_y * (*cdef_stride) +
-         start_x * pixel_size_ + buffer_offset;
-}
-
-template <typename Pixel>
-void PostFilter::ApplyCdefForOneUnit(uint16_t* cdef_block, const int index,
-                                     const int block_width4x4,
-                                     const int block_height4x4,
-                                     const int row4x4_start,
-                                     const int column4x4_start) {
-  const int coeff_shift = bitdepth_ - 8;
-  const int step = kNum4x4BlocksWide[kBlock8x8];
-  const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
-  const int vertical_shift = -kCdefBorder;
-  const int window_buffer_plane_size =
-      window_buffer_width_ * window_buffer_height_ * pixel_size_;
-
-  if (index == -1) {
-    for (int plane = kPlaneY; plane < planes_; ++plane) {
-      const int subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-      const int subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-      const int start_x = MultiplyBy4(column4x4_start) >> subsampling_x;
-      const int start_y = MultiplyBy4(row4x4_start) >> subsampling_y;
-      int cdef_stride;
-      uint8_t* const cdef_buffer = GetCdefBufferAndStride(
-          start_x, start_y, plane, subsampling_x, subsampling_y,
-          window_buffer_plane_size, vertical_shift, horizontal_shift,
-          &cdef_stride);
-      const int src_stride = source_buffer_->stride(plane);
-      uint8_t* const src_buffer = source_buffer_->data(plane) +
-                                  start_y * src_stride + start_x * pixel_size_;
-      const int block_width = MultiplyBy4(block_width4x4) >> subsampling_x;
-      const int block_height = MultiplyBy4(block_height4x4) >> subsampling_y;
-      for (int y = 0; y < block_height; ++y) {
-        memcpy(cdef_buffer + y * cdef_stride, src_buffer + y * src_stride,
-               block_width * pixel_size_);
-      }
-    }
-    return;
-  }
-
-  PrepareCdefBlock<Pixel>(source_buffer_, planes_, subsampling_x_,
-                          subsampling_y_, frame_header_.width,
-                          frame_header_.height, block_width4x4, block_height4x4,
-                          row4x4_start, column4x4_start, cdef_block,
-                          kRestorationProcessingUnitSizeWithBorders);
-
-  for (int row4x4 = row4x4_start; row4x4 < row4x4_start + block_height4x4;
-       row4x4 += step) {
-    for (int column4x4 = column4x4_start;
-         column4x4 < column4x4_start + block_width4x4; column4x4 += step) {
-      const bool skip =
-          block_parameters_.Find(row4x4, column4x4) != nullptr &&
-          block_parameters_.Find(row4x4 + 1, column4x4) != nullptr &&
-          block_parameters_.Find(row4x4, column4x4 + 1) != nullptr &&
-          block_parameters_.Find(row4x4 + 1, column4x4 + 1) != nullptr &&
-          block_parameters_.Find(row4x4, column4x4)->skip &&
-          block_parameters_.Find(row4x4 + 1, column4x4)->skip &&
-          block_parameters_.Find(row4x4, column4x4 + 1)->skip &&
-          block_parameters_.Find(row4x4 + 1, column4x4 + 1)->skip;
-      int damping = frame_header_.cdef.damping + coeff_shift;
-      int direction_y;
-      int direction;
-      int variance;
-      uint8_t primary_strength;
-      uint8_t secondary_strength;
-
-      for (int plane = kPlaneY; plane < planes_; ++plane) {
-        const int subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-        const int subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-        const int start_x = MultiplyBy4(column4x4) >> subsampling_x;
-        const int start_y = MultiplyBy4(row4x4) >> subsampling_y;
-        const int block_width = 8 >> subsampling_x;
-        const int block_height = 8 >> subsampling_y;
-        int cdef_stride;
-        uint8_t* const cdef_buffer = GetCdefBufferAndStride(
-            start_x, start_y, plane, subsampling_x, subsampling_y,
-            window_buffer_plane_size, vertical_shift, horizontal_shift,
-            &cdef_stride);
-        const int src_stride = source_buffer_->stride(plane);
-        uint8_t* const src_buffer = source_buffer_->data(plane) +
-                                    start_y * src_stride +
-                                    start_x * pixel_size_;
-
-        if (skip) {  // No cdef filtering.
-          for (int y = 0; y < block_height; ++y) {
-            memcpy(cdef_buffer + y * cdef_stride, src_buffer + y * src_stride,
-                   block_width * pixel_size_);
-          }
-          continue;
-        }
-
-        if (plane == kPlaneY) {
-          dsp_.cdef_direction(src_buffer, src_stride, &direction_y, &variance);
-          primary_strength = frame_header_.cdef.y_primary_strength[index]
-                             << coeff_shift;
-          secondary_strength = frame_header_.cdef.y_secondary_strength[index]
-                               << coeff_shift;
-          direction = (primary_strength == 0) ? 0 : direction_y;
-          const int variance_strength =
-              ((variance >> 6) != 0) ? std::min(FloorLog2(variance >> 6), 12)
-                                     : 0;
-          primary_strength =
-              (variance != 0)
-                  ? (primary_strength * (4 + variance_strength) + 8) >> 4
-                  : 0;
-        } else {
-          primary_strength = frame_header_.cdef.uv_primary_strength[index]
-                             << coeff_shift;
-          secondary_strength = frame_header_.cdef.uv_secondary_strength[index]
-                               << coeff_shift;
-          direction = (primary_strength == 0)
-                          ? 0
-                          : kCdefUvDirection[subsampling_x_][subsampling_y_]
-                                            [direction_y];
-          damping = frame_header_.cdef.damping + coeff_shift - 1;
-        }
-
-        if ((primary_strength | secondary_strength) == 0) {
-          for (int y = 0; y < block_height; ++y) {
-            memcpy(cdef_buffer + y * cdef_stride, src_buffer + y * src_stride,
-                   block_width * pixel_size_);
-          }
-          continue;
-        }
-        uint16_t* cdef_src =
-            cdef_block + plane * kRestorationProcessingUnitSizeWithBorders *
-                             kRestorationProcessingUnitSizeWithBorders;
-        cdef_src += kCdefBorder * kRestorationProcessingUnitSizeWithBorders +
-                    kCdefBorder;
-        cdef_src += (MultiplyBy4(row4x4 - row4x4_start) >> subsampling_y) *
-                        kRestorationProcessingUnitSizeWithBorders +
-                    (MultiplyBy4(column4x4 - column4x4_start) >> subsampling_x);
-        dsp_.cdef_filter(cdef_src, kRestorationProcessingUnitSizeWithBorders,
-                         frame_header_.rows4x4, frame_header_.columns4x4,
-                         start_x, start_y, subsampling_x, subsampling_y,
-                         primary_strength, secondary_strength, damping,
-                         direction, cdef_buffer, cdef_stride);
-      }
-    }
-  }
-}
-
-template <typename Pixel>
-void PostFilter::ApplyCdefForOneRowInWindow(const int row4x4,
-                                            const int column4x4_start) {
-  const int step_64x64 = 16;  // = 64/4.
-  uint16_t cdef_block[kRestorationProcessingUnitSizeWithBorders *
-                      kRestorationProcessingUnitSizeWithBorders * 3];
-
-  for (int column4x4_64x64 = 0;
-       column4x4_64x64 < std::min(DivideBy4(window_buffer_width_),
-                                  frame_header_.columns4x4 - column4x4_start);
-       column4x4_64x64 += step_64x64) {
-    const int column4x4 = column4x4_start + column4x4_64x64;
-    const int index = cdef_index_[DivideBy16(row4x4)][DivideBy16(column4x4)];
-    const int block_width4x4 =
-        std::min(step_64x64, frame_header_.columns4x4 - column4x4);
-    const int block_height4x4 =
-        std::min(step_64x64, frame_header_.rows4x4 - row4x4);
-
-    ApplyCdefForOneUnit<Pixel>(cdef_block, index, block_width4x4,
-                               block_height4x4, row4x4, column4x4);
-  }
-}
-
-// Each thread processes one row inside the window.
-// Y, U, V planes are processed together inside one thread.
-template <typename Pixel>
-bool PostFilter::ApplyCdefThreaded() {
-  assert((window_buffer_height_ & 63) == 0);
-  const int num_workers = thread_pool_->num_threads();
-  const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
-  const int vertical_shift = -kCdefBorder;
-  const int window_buffer_plane_size =
-      window_buffer_width_ * window_buffer_height_ * pixel_size_;
-  const int window_buffer_height4x4 = DivideBy4(window_buffer_height_);
-  const int step_64x64 = 16;  // = 64/4.
-  for (int row4x4 = 0; row4x4 < frame_header_.rows4x4;
-       row4x4 += window_buffer_height4x4) {
-    const int actual_window_height4x4 =
-        std::min(window_buffer_height4x4, frame_header_.rows4x4 - row4x4);
-    const int vertical_units_per_window =
-        DivideBy16(actual_window_height4x4 + 15);
-    for (int column4x4 = 0; column4x4 < frame_header_.columns4x4;
-         column4x4 += DivideBy4(window_buffer_width_)) {
-      const int jobs_for_threadpool =
-          vertical_units_per_window * num_workers / (num_workers + 1);
-      BlockingCounter pending_jobs(jobs_for_threadpool);
-      int job_count = 0;
-      for (int row64x64 = 0; row64x64 < actual_window_height4x4;
-           row64x64 += step_64x64) {
-        if (job_count < jobs_for_threadpool) {
-          thread_pool_->Schedule(
-              [this, row4x4, column4x4, row64x64, &pending_jobs]() {
-                ApplyCdefForOneRowInWindow<Pixel>(row4x4 + row64x64, column4x4);
-                pending_jobs.Decrement();
-              });
-        } else {
-          ApplyCdefForOneRowInWindow<Pixel>(row4x4 + row64x64, column4x4);
-        }
-        ++job_count;
-      }
-      pending_jobs.Wait();
-      if (DoRestoration()) continue;
-
-      // Copy |threaded_window_buffer_| to cdef_buffer_ (== source_buffer_).
-      assert(cdef_buffer_ == source_buffer_);
-      for (int plane = kPlaneY; plane < planes_; ++plane) {
-        const int cdef_stride = cdef_buffer_->stride(plane);
-        const ptrdiff_t buffer_offset =
-            vertical_shift * cdef_stride + horizontal_shift * pixel_size_;
-        const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-        const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-        const int plane_row = MultiplyBy4(row4x4) >> subsampling_y;
-        const int plane_column = MultiplyBy4(column4x4) >> subsampling_x;
-        int copy_width = std::min(frame_header_.columns4x4 - column4x4,
-                                  DivideBy4(window_buffer_width_));
-        copy_width = MultiplyBy4(copy_width) >> subsampling_x;
-        int copy_height =
-            std::min(frame_header_.rows4x4 - row4x4, window_buffer_height4x4);
-        copy_height = MultiplyBy4(copy_height) >> subsampling_y;
-        CopyPlane<Pixel>(
-            threaded_window_buffer_ + plane * window_buffer_plane_size,
-            window_buffer_width_ * pixel_size_, copy_width, copy_height,
-            cdef_buffer_->data(plane) + plane_row * cdef_stride +
-                plane_column * pixel_size_ + buffer_offset,
-            cdef_stride);
-      }
-    }
-  }
-  if (!DoRestoration()) {
-    for (int plane = kPlaneY; plane < planes_; ++plane) {
-      if (!cdef_buffer_->ShiftBuffer(plane, horizontal_shift, vertical_shift)) {
-        LIBGAV1_DLOG(ERROR,
-                     "Error shifting frame buffer head pointer at plane: %d",
-                     plane);
-        return false;
-      }
-    }
-  }
-
-  return true;
-}
-
-bool PostFilter::ApplyCdef() {
-  if (!DoRestoration()) {
-    cdef_buffer_ = source_buffer_;
-  } else {
-    if (!cdef_filtered_buffer_.Realloc(
-            bitdepth_, planes_ == kMaxPlanesMonochrome, upscaled_width_,
-            height_, subsampling_x_, subsampling_y_, kBorderPixels,
-            /*byte_alignment=*/0, nullptr, nullptr, nullptr)) {
-      return false;
-    }
-    cdef_buffer_ = &cdef_filtered_buffer_;
-  }
-
-  if (thread_pool_ != nullptr) {
-#if LIBGAV1_MAX_BITDEPTH >= 10
-    if (bitdepth_ >= 10) {
-      return ApplyCdefThreaded<uint16_t>();
-    }
-#endif
-    return ApplyCdefThreaded<uint8_t>();
-  }
-
-  const int step_64x64 = 16;  // = 64/4.
-  // Apply cdef on each 8x8 Y block and
-  // (8 >> subsampling_x)x(8 >> subsampling_y) UV block.
-  for (int row4x4 = 0; row4x4 < frame_header_.rows4x4; row4x4 += step_64x64) {
-    for (int column4x4 = 0; column4x4 < frame_header_.columns4x4;
-         column4x4 += step_64x64) {
-      const int index = cdef_index_[DivideBy16(row4x4)][DivideBy16(column4x4)];
-      const int block_width4x4 =
-          std::min(step_64x64, frame_header_.columns4x4 - column4x4);
-      const int block_height4x4 =
-          std::min(step_64x64, frame_header_.rows4x4 - row4x4);
-
-#if LIBGAV1_MAX_BITDEPTH >= 10
-      if (bitdepth_ >= 10) {
-        ApplyCdefForOneUnit<uint16_t>(cdef_block_, index, block_width4x4,
-                                      block_height4x4, row4x4, column4x4);
-        continue;
-      }
-#endif  // LIBGAV1_MAX_BITDEPTH >= 10
-      ApplyCdefForOneUnit<uint8_t>(cdef_block_, index, block_width4x4,
-                                   block_height4x4, row4x4, column4x4);
-    }
-  }
-  if (!DoRestoration()) {
-    const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
-    const int vertical_shift = -kCdefBorder;
-    for (int plane = kPlaneY; plane < planes_; ++plane) {
-      if (!source_buffer_->ShiftBuffer(plane, horizontal_shift,
-                                       vertical_shift)) {
-        LIBGAV1_DLOG(ERROR,
-                     "Error shifting frame buffer head pointer at plane: %d",
-                     plane);
-        return false;
-      }
-    }
-  }
-  return true;
-}
-
-void PostFilter::FrameSuperRes(YuvBuffer* const input_buffer) {
-  // Copy input_buffer to super_res_buffer_.
-  for (int plane = kPlaneY; plane < planes_; ++plane) {
-    const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-    const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-    const int border_height = kBorderPixels >> subsampling_y;
-    const int border_width = kBorderPixels >> subsampling_x;
-    const int plane_width =
-        MultiplyBy4(frame_header_.columns4x4) >> subsampling_x;
-    const int plane_height =
-        MultiplyBy4(frame_header_.rows4x4) >> subsampling_y;
-    if (bitdepth_ == 8) {
-      CopyPlane<uint8_t>(input_buffer->data(plane), input_buffer->stride(plane),
-                         plane_width, plane_height,
-                         super_res_buffer_.data(plane),
-                         super_res_buffer_.stride(plane));
-    } else {
-      CopyPlane<uint16_t>(input_buffer->data(plane),
-                          input_buffer->stride(plane), plane_width,
-                          plane_height, super_res_buffer_.data(plane),
-                          super_res_buffer_.stride(plane));
-    }
-    ExtendFrameBoundary(super_res_buffer_.data(plane), plane_width,
-                        plane_height, super_res_buffer_.stride(plane),
-                        border_width, border_width, border_height,
-                        border_height);
-  }
-
-  // Upscale filter and write to frame.
-  for (int plane = kPlaneY; plane < planes_; ++plane) {
-    const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-    const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-    const int downscaled_width = RightShiftWithRounding(width_, subsampling_x);
-    const int upscaled_width =
-        RightShiftWithRounding(upscaled_width_, subsampling_x);
-    const int plane_height = RightShiftWithRounding(height_, subsampling_y);
-    const int superres_width = downscaled_width << kSuperResScaleBits;
-    const int step = (superres_width + upscaled_width / 2) / upscaled_width;
-    const int error = step * upscaled_width - superres_width;
-    int initial_subpixel_x =
-        (-((upscaled_width - downscaled_width) << (kSuperResScaleBits - 1)) +
-         DivideBy2(upscaled_width)) /
-            upscaled_width +
-        (1 << (kSuperResExtraBits - 1)) - error / 2;
-    initial_subpixel_x &= kSuperResScaleMask;
-    if (bitdepth_ == 8) {
-      ComputeSuperRes<8, uint8_t>(
-          super_res_buffer_.data(plane), super_res_buffer_.stride(plane),
-          upscaled_width, plane_height, initial_subpixel_x, step,
-          input_buffer->data(plane), input_buffer->stride(plane));
-    } else {
-      ComputeSuperRes<10, uint16_t>(
-          super_res_buffer_.data(plane), super_res_buffer_.stride(plane),
-          upscaled_width, plane_height, initial_subpixel_x, step,
-          input_buffer->data(plane), input_buffer->stride(plane));
-    }
-  }
-  // Extend original frame, copy to borders.
-  for (int plane = kPlaneY; plane < planes_; ++plane) {
-    const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-    uint8_t* const frame_start = input_buffer->data(plane);
-    const int plane_width =
-        RightShiftWithRounding(upscaled_width_, subsampling_x);
-    ExtendFrameBoundary(
-        frame_start, plane_width, input_buffer->displayed_height(plane),
-        input_buffer->stride(plane), input_buffer->left_border(plane),
-        input_buffer->right_border(plane), input_buffer->top_border(plane),
-        input_buffer->bottom_border(plane));
-  }
-}
-
-bool PostFilter::ApplySuperRes() {
-  if (!super_res_buffer_.Realloc(bitdepth_, planes_ == kMaxPlanesMonochrome,
-                                 MultiplyBy4(frame_header_.columns4x4),
-                                 MultiplyBy4(frame_header_.rows4x4),
-                                 subsampling_x_, subsampling_y_, kBorderPixels,
-                                 /*byte_alignment=*/0, nullptr, nullptr,
-                                 nullptr)) {
-    return false;
-  }
-  // cdef_buffer_ points to the buffer after cdef process (regardless whether
-  // cdef filtering is actually applied).
-  // source_buffer_ points to the deblocked buffer.
-  if (DoCdef()) {
-    // If loop restoration is present, it requires both deblocked buffer and
-    // cdef filtered buffer. Otherwise, only cdef filtered buffer is required.
-    FrameSuperRes(cdef_buffer_);
-    if (DoRestoration()) FrameSuperRes(source_buffer_);
-  } else {
-    FrameSuperRes(source_buffer_);
-  }
-  return true;
-}
-
-template <typename Pixel>
-void PostFilter::ApplyLoopRestorationForOneRowInWindow(
-    uint8_t* const cdef_buffer, const ptrdiff_t cdef_buffer_stride,
-    uint8_t* const deblock_buffer, const ptrdiff_t deblock_buffer_stride,
-    const Plane plane, const int plane_height, const int plane_width,
-    const int x, const int y, const int row, const int unit_row,
-    const int current_process_unit_height, const int process_unit_width,
-    const int window_width, const int plane_unit_size,
-    const int num_horizontal_units) {
-  for (int column = 0; column < window_width; column += process_unit_width) {
-    const int unit_x = x + column;
-    const int unit_column =
-        std::min(unit_x / plane_unit_size, num_horizontal_units - 1);
-    const int unit_id = unit_row * num_horizontal_units + unit_column;
-    const LoopRestorationType type =
-        restoration_info_
-            ->loop_restoration_info(static_cast<Plane>(plane), unit_id)
-            .type;
-    const int current_process_unit_width =
-        (unit_x + process_unit_width <= plane_width) ? process_unit_width
-                                                     : plane_width - unit_x;
-    ApplyLoopRestorationForOneUnit<Pixel>(
-        cdef_buffer, cdef_buffer_stride, deblock_buffer, deblock_buffer_stride,
-        plane, plane_height, unit_id, type, x, y, row, column,
-        current_process_unit_width, current_process_unit_height,
-        process_unit_width, window_buffer_width_);
-  }
-}
-
-template <typename Pixel>
-void PostFilter::ApplyLoopRestorationForOneUnit(
-    uint8_t* const cdef_buffer, const ptrdiff_t cdef_buffer_stride,
-    uint8_t* const deblock_buffer, const ptrdiff_t deblock_buffer_stride,
-    const Plane plane, const int plane_height, const int unit_id,
-    const LoopRestorationType type, const int x, const int y, const int row,
-    const int column, const int current_process_unit_width,
-    const int current_process_unit_height, const int plane_process_unit_width,
-    const int window_width) {
-  const int unit_x = x + column;
-  const int unit_y = y + row;
-  uint8_t* cdef_unit_buffer =
-      cdef_buffer + unit_y * cdef_buffer_stride + unit_x * pixel_size_;
-  Array2DView<Pixel> loop_restored_window(
-      window_buffer_height_, window_buffer_width_,
-      reinterpret_cast<Pixel*>(threaded_window_buffer_));
-  if (type == kLoopRestorationTypeNone) {
-    Pixel* dest = &loop_restored_window[row][column];
-    for (int k = 0; k < current_process_unit_height; ++k) {
-      memcpy(dest, cdef_unit_buffer, current_process_unit_width * pixel_size_);
-      dest += window_width;
-      cdef_unit_buffer += cdef_buffer_stride;
-    }
-    return;
-  }
-
-  // The SIMD implementation of wiener filter (currently WienerFilter_SSE4_1())
-  // over-reads 6 bytes, so add 6 extra bytes at the end of block_buffer for 8
-  // bit.
-  alignas(alignof(uint16_t))
-      uint8_t block_buffer[kRestorationProcessingUnitSizeWithBorders *
-                               kRestorationProcessingUnitSizeWithBorders *
-                               sizeof(Pixel) +
-                           ((sizeof(Pixel) == 1) ? 6 : 0)];
-  const ptrdiff_t block_buffer_stride =
-      kRestorationProcessingUnitSizeWithBorders * pixel_size_;
-  IntermediateBuffers intermediate_buffers;
-
-  RestorationBuffer restoration_buffer = {
-      {intermediate_buffers.box_filter.output[0],
-       intermediate_buffers.box_filter.output[1]},
-      plane_process_unit_width,
-      {intermediate_buffers.box_filter.intermediate_a,
-       intermediate_buffers.box_filter.intermediate_b},
-      kRestorationProcessingUnitSizeWithBorders + kRestorationPadding,
-      intermediate_buffers.wiener,
-      kMaxSuperBlockSizeInPixels};
-  uint8_t* deblock_unit_buffer =
-      deblock_buffer + unit_y * deblock_buffer_stride + unit_x * pixel_size_;
-  assert(type == kLoopRestorationTypeSgrProj ||
-         type == kLoopRestorationTypeWiener);
-  const dsp::LoopRestorationFunc restoration_func =
-      dsp_.loop_restorations[type - 2];
-  PrepareLoopRestorationBlock<Pixel>(
-      cdef_unit_buffer, cdef_buffer_stride, deblock_unit_buffer,
-      deblock_buffer_stride, block_buffer, block_buffer_stride,
-      current_process_unit_width, current_process_unit_height, unit_y == 0,
-      unit_y + current_process_unit_height >= plane_height);
-  restoration_func(reinterpret_cast<const uint8_t*>(
-                       block_buffer + kRestorationBorder * block_buffer_stride +
-                       kRestorationBorder * pixel_size_),
-                   &loop_restored_window[row][column],
-                   restoration_info_->loop_restoration_info(
-                       static_cast<Plane>(plane), unit_id),
-                   block_buffer_stride, window_width * pixel_size_,
-                   current_process_unit_width, current_process_unit_height,
-                   &restoration_buffer);
-}
-
-// Multi-thread version of loop restoration, based on a moving window of size
-// |window_buffer_width_|x|window_buffer_height_|. Inside the moving window, we
-// create a filtering job for each row and each filtering job is submitted to
-// the thread pool. Each free thread takes one job from the thread pool and
-// completes filtering until all jobs are finished. This approach requires an
-// extra buffer (|threaded_window_buffer_|) to hold the filtering output, whose
-// size is the size of the window. It also needs block buffers (i.e.,
-// |block_buffer| and |intermediate_buffers| in
-// ApplyLoopRestorationForOneUnit()) to store intermediate results in loop
-// restoration for each thread. After all units inside the window are filtered,
-// the output is written to the frame buffer.
-template <typename Pixel>
-bool PostFilter::ApplyLoopRestorationThreaded() {
-  if (!DoCdef()) cdef_buffer_ = source_buffer_;
-  const int plane_process_unit_width[kMaxPlanes] = {
-      kRestorationProcessingUnitSize,
-      kRestorationProcessingUnitSize >> subsampling_x_,
-      kRestorationProcessingUnitSize >> subsampling_x_};
-  const int plane_process_unit_height[kMaxPlanes] = {
-      kRestorationProcessingUnitSize,
-      kRestorationProcessingUnitSize >> subsampling_y_,
-      kRestorationProcessingUnitSize >> subsampling_y_};
-
-  const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
-  const int vertical_shift = -kRestorationBorder;
-  for (int plane = kPlaneY; plane < planes_; ++plane) {
-    if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
-      if (!DoCdef()) continue;
-      CopyPlane<Pixel>(cdef_buffer_->data(plane), cdef_buffer_->stride(plane),
-                       cdef_buffer_->displayed_width(plane),
-                       cdef_buffer_->displayed_height(plane),
-                       source_buffer_->data(plane),
-                       source_buffer_->stride(plane));
-      continue;
-    }
-
-    const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-    const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-    const int unit_height_offset = kRestorationUnitOffset >> subsampling_y;
-    uint8_t* src_buffer = source_buffer_->data(plane);
-    const int src_stride = source_buffer_->stride(plane);
-    uint8_t* cdef_buffer = cdef_buffer_->data(plane);
-    const int cdef_buffer_stride = cdef_buffer_->stride(plane);
-    uint8_t* deblock_buffer = source_buffer_->data(plane);
-    const int deblock_buffer_stride = source_buffer_->stride(plane);
-    const int plane_unit_size = loop_restoration_.unit_size[plane];
-    const int num_vertical_units =
-        restoration_info_->num_vertical_units(static_cast<Plane>(plane));
-    const int num_horizontal_units =
-        restoration_info_->num_horizontal_units(static_cast<Plane>(plane));
-    const int plane_width =
-        RightShiftWithRounding(upscaled_width_, subsampling_x);
-    const int plane_height = RightShiftWithRounding(height_, subsampling_y);
-    const ptrdiff_t src_unit_buffer_offset =
-        vertical_shift * src_stride + horizontal_shift * pixel_size_;
-    ExtendFrameBoundary(cdef_buffer, plane_width, plane_height,
-                        cdef_buffer_stride, kRestorationBorder,
-                        kRestorationBorder, kRestorationBorder,
-                        kRestorationBorder);
-    if (DoCdef()) {
-      ExtendFrameBoundary(deblock_buffer, plane_width, plane_height,
-                          deblock_buffer_stride, kRestorationBorder,
-                          kRestorationBorder, kRestorationBorder,
-                          kRestorationBorder);
-    }
-
-    const int num_workers = thread_pool_->num_threads();
-    for (int y = 0; y < plane_height; y += window_buffer_height_) {
-      const int actual_window_height =
-          std::min(window_buffer_height_ - ((y == 0) ? unit_height_offset : 0),
-                   plane_height - y);
-      int vertical_units_per_window =
-          (actual_window_height + plane_process_unit_height[plane] - 1) /
-          plane_process_unit_height[plane];
-      if (y == 0) {
-        // The first row of loop restoration processing units is not 64x64, but
-        // 64x56 (|unit_height_offset| = 8 rows less than other restoration
-        // processing units). For u/v with subsampling, the size is halved. To
-        // compute the number of vertical units per window, we need to take a
-        // special handling for it.
-        const int height_without_first_unit =
-            actual_window_height -
-            std::min(actual_window_height,
-                     plane_process_unit_height[plane] - unit_height_offset);
-        vertical_units_per_window =
-            (height_without_first_unit + plane_process_unit_height[plane] - 1) /
-                plane_process_unit_height[plane] +
-            1;
-      }
-      for (int x = 0; x < plane_width; x += window_buffer_width_) {
-        const int actual_window_width =
-            std::min(window_buffer_width_, plane_width - x);
-        const int jobs_for_threadpool =
-            vertical_units_per_window * num_workers / (num_workers + 1);
-        assert(jobs_for_threadpool < vertical_units_per_window);
-        BlockingCounter pending_jobs(jobs_for_threadpool);
-        int job_count = 0;
-        int current_process_unit_height;
-        for (int row = 0; row < actual_window_height;
-             row += current_process_unit_height) {
-          const int unit_y = y + row;
-          const int expected_height = plane_process_unit_height[plane] +
-                                      ((unit_y == 0) ? -unit_height_offset : 0);
-          current_process_unit_height =
-              (unit_y + expected_height <= plane_height)
-                  ? expected_height
-                  : plane_height - unit_y;
-          const int unit_row =
-              std::min((unit_y + unit_height_offset) / plane_unit_size,
-                       num_vertical_units - 1);
-          const int process_unit_width = plane_process_unit_width[plane];
-
-          if (job_count < jobs_for_threadpool) {
-            thread_pool_->Schedule(
-                [this, cdef_buffer, cdef_buffer_stride, deblock_buffer,
-                 deblock_buffer_stride, process_unit_width,
-                 current_process_unit_height, actual_window_width,
-                 plane_unit_size, num_horizontal_units, x, y, row, unit_row,
-                 plane_height, plane_width, plane, &pending_jobs]() {
-                  ApplyLoopRestorationForOneRowInWindow<Pixel>(
-                      cdef_buffer, cdef_buffer_stride, deblock_buffer,
-                      deblock_buffer_stride, static_cast<Plane>(plane),
-                      plane_height, plane_width, x, y, row, unit_row,
-                      current_process_unit_height, process_unit_width,
-                      actual_window_width, plane_unit_size,
-                      num_horizontal_units);
-                  pending_jobs.Decrement();
-                });
-          } else {
-            ApplyLoopRestorationForOneRowInWindow<Pixel>(
-                cdef_buffer, cdef_buffer_stride, deblock_buffer,
-                deblock_buffer_stride, static_cast<Plane>(plane), plane_height,
-                plane_width, x, y, row, unit_row, current_process_unit_height,
-                process_unit_width, actual_window_width, plane_unit_size,
-                num_horizontal_units);
-          }
-          ++job_count;
-        }
-        // Wait for all jobs of current window to finish.
-        pending_jobs.Wait();
-        // Copy |threaded_window_buffer_| to output frame.
-        CopyPlane<Pixel>(threaded_window_buffer_,
-                         window_buffer_width_ * pixel_size_,
-                         actual_window_width, actual_window_height,
-                         src_buffer + y * src_stride + x * pixel_size_ +
-                             src_unit_buffer_offset,
-                         src_stride);
-      }
-      if (y == 0) y -= unit_height_offset;
-    }
-    if (!source_buffer_->ShiftBuffer(plane, horizontal_shift, vertical_shift)) {
-      LIBGAV1_DLOG(ERROR,
-                   "Error shifting frame buffer head pointer at plane: %d",
-                   plane);
-      return false;
-    }
-  }
-  return true;
-}
-
-bool PostFilter::ApplyLoopRestoration() {
-  if (thread_pool_ != nullptr) {
-    assert(threaded_window_buffer_ != nullptr);
-#if LIBGAV1_MAX_BITDEPTH >= 10
-    if (bitdepth_ >= 10) {
-      return ApplyLoopRestorationThreaded<uint16_t>();
-    }
-#endif
-    return ApplyLoopRestorationThreaded<uint8_t>();
-  }
-
-  if (!DoCdef()) cdef_buffer_ = source_buffer_;
-  const ptrdiff_t block_buffer_stride =
-      kRestorationProcessingUnitSizeWithBorders * pixel_size_;
-  const int plane_process_unit_width[kMaxPlanes] = {
-      kRestorationProcessingUnitSize,
-      kRestorationProcessingUnitSize >> subsampling_x_,
-      kRestorationProcessingUnitSize >> subsampling_x_};
-  const int plane_process_unit_height[kMaxPlanes] = {
-      kRestorationProcessingUnitSize,
-      kRestorationProcessingUnitSize >> subsampling_y_,
-      kRestorationProcessingUnitSize >> subsampling_y_};
-  IntermediateBuffers intermediate_buffers;
-  RestorationBuffer restoration_buffer = {
-      {intermediate_buffers.box_filter.output[0],
-       intermediate_buffers.box_filter.output[1]},
-      plane_process_unit_width[kPlaneY],
-      {intermediate_buffers.box_filter.intermediate_a,
-       intermediate_buffers.box_filter.intermediate_b},
-      kRestorationProcessingUnitSizeWithBorders + kRestorationPadding,
-      intermediate_buffers.wiener,
-      kMaxSuperBlockSizeInPixels};
-
-  for (int plane = kPlaneY; plane < planes_; ++plane) {
-    if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
-      if (!DoCdef()) continue;
-      if (cdef_buffer_->bitdepth() == 8) {
-        CopyPlane<uint8_t>(
-            cdef_buffer_->data(plane), cdef_buffer_->stride(plane),
-            cdef_buffer_->displayed_width(plane),
-            cdef_buffer_->displayed_height(plane), source_buffer_->data(plane),
-            source_buffer_->stride(plane));
-#if LIBGAV1_MAX_BITDEPTH >= 10
-      } else {
-        CopyPlane<uint16_t>(
-            cdef_buffer_->data(plane), cdef_buffer_->stride(plane),
-            cdef_buffer_->displayed_width(plane),
-            cdef_buffer_->displayed_height(plane), source_buffer_->data(plane),
-            source_buffer_->stride(plane));
-#endif
-      }
-      continue;
-    }
-    const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-    const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-    const int unit_height_offset = kRestorationUnitOffset >> subsampling_y;
-    restoration_buffer.box_filter_process_output_stride =
-        plane_process_unit_width[plane];
-    uint8_t* src_buffer = source_buffer_->data(plane);
-    const ptrdiff_t src_stride = source_buffer_->stride(plane);
-    uint8_t* cdef_buffer = cdef_buffer_->data(plane);
-    const ptrdiff_t cdef_buffer_stride = cdef_buffer_->stride(plane);
-    uint8_t* deblock_buffer = source_buffer_->data(plane);
-    const ptrdiff_t deblock_buffer_stride = source_buffer_->stride(plane);
-    const int plane_unit_size = loop_restoration_.unit_size[plane];
-    const int num_vertical_units =
-        restoration_info_->num_vertical_units(static_cast<Plane>(plane));
-    const int num_horizontal_units =
-        restoration_info_->num_horizontal_units(static_cast<Plane>(plane));
-    const int plane_width =
-        RightShiftWithRounding(upscaled_width_, subsampling_x);
-    const int plane_height = RightShiftWithRounding(height_, subsampling_y);
-    ExtendFrameBoundary(cdef_buffer, plane_width, plane_height,
-                        cdef_buffer_stride, kRestorationBorder,
-                        kRestorationBorder, kRestorationBorder,
-                        kRestorationBorder);
-    if (DoCdef()) {
-      ExtendFrameBoundary(deblock_buffer, plane_width, plane_height,
-                          deblock_buffer_stride, kRestorationBorder,
-                          kRestorationBorder, kRestorationBorder,
-                          kRestorationBorder);
-    }
-
-    int loop_restored_rows = 0;
-    const int horizontal_shift = -source_buffer_->alignment() / pixel_size_;
-    const int vertical_shift = -kRestorationBorder;
-    const ptrdiff_t src_unit_buffer_offset =
-        vertical_shift * src_stride + horizontal_shift * pixel_size_;
-    for (int unit_row = 0; unit_row < num_vertical_units; ++unit_row) {
-      int current_unit_height = plane_unit_size;
-      // Note [1]: we need to identify the entire restoration area. So the
-      // condition check of finding the boundary is first. In contrast, Note [2]
-      // is a case where condition check of the first row is first.
-      if (unit_row == num_vertical_units - 1) {
-        // Take care of the last row. The max height of last row units could be
-        // 3/2 unit_size.
-        current_unit_height = plane_height - loop_restored_rows;
-      } else if (unit_row == 0) {
-        // The size of restoration units in the first row has to subtract the
-        // height offset.
-        current_unit_height -= unit_height_offset;
-      }
-
-      for (int unit_column = 0; unit_column < num_horizontal_units;
-           ++unit_column) {
-        const int unit_id = unit_row * num_horizontal_units + unit_column;
-        const LoopRestorationType type =
-            restoration_info_
-                ->loop_restoration_info(static_cast<Plane>(plane), unit_id)
-                .type;
-        uint8_t* src_unit_buffer =
-            src_buffer + unit_column * plane_unit_size * pixel_size_;
-        uint8_t* cdef_unit_buffer =
-            cdef_buffer + unit_column * plane_unit_size * pixel_size_;
-        uint8_t* deblock_unit_buffer =
-            deblock_buffer + unit_column * plane_unit_size * pixel_size_;
-
-        // Take care of the last column. The max width of last column unit
-        // could be 3/2 unit_size.
-        const int current_unit_width =
-            (unit_column == num_horizontal_units - 1)
-                ? plane_width - plane_unit_size * unit_column
-                : plane_unit_size;
-
-        if (type == kLoopRestorationTypeNone) {
-          for (int y = 0; y < current_unit_height; ++y) {
-            memcpy(src_unit_buffer + src_unit_buffer_offset, cdef_unit_buffer,
-                   current_unit_width * pixel_size_);
-            src_unit_buffer += src_stride;
-            cdef_unit_buffer += cdef_buffer_stride;
-          }
-          continue;
-        }
-
-        assert(type == kLoopRestorationTypeWiener ||
-               type == kLoopRestorationTypeSgrProj);
-        const dsp::LoopRestorationFunc restoration_func =
-            dsp_.loop_restorations[type - 2];
-        for (int row = 0; row < current_unit_height;) {
-          const int current_process_unit_height =
-              plane_process_unit_height[plane] +
-              ((unit_row + row == 0) ? -unit_height_offset : 0);
-
-          for (int column = 0; column < current_unit_width;
-               column += plane_process_unit_width[plane]) {
-            const int processing_unit_width = std::min(
-                plane_process_unit_width[plane], current_unit_width - column);
-            int processing_unit_height = plane_process_unit_height[plane];
-            // Note [2]: the height of processing units in the first row has
-            // special cases where the frame height is less than
-            // plane_process_unit_height[plane].
-            if (unit_row + row == 0) {
-              processing_unit_height = std::min(
-                  plane_process_unit_height[plane] - unit_height_offset,
-                  current_unit_height);
-            } else if (current_unit_height - row <
-                       plane_process_unit_height[plane]) {
-              // The height of last row of processing units.
-              processing_unit_height = current_unit_height - row;
-            }
-            // We apply in-place loop restoration, by copying the source block
-            // to a buffer and computing loop restoration on it. The restored
-            // pixel values are then stored to the frame buffer. However,
-            // loop restoration requires (a) 3 pixel extension on current 64x64
-            // processing unit, (b) unrestored pixels.
-            // To address this, we store the restored pixels not onto the start
-            // of current block on the source frame buffer, say point A,
-            // but to its top by three pixels and to the left by
-            // alignment/pixel_size_ pixels, say point B, such that
-            // next processing unit can fetch 3 pixel border of unrestored
-            // values. And we need to adjust the input frame buffer pointer to
-            // its left and top corner, point B.
-            uint8_t* const cdef_process_unit_buffer =
-                cdef_unit_buffer + column * pixel_size_;
-            uint8_t* const deblock_process_unit_buffer =
-                deblock_unit_buffer + column * pixel_size_;
-            const bool frame_top_border = unit_row + row == 0;
-            const bool frame_bottom_border =
-                (unit_row == num_vertical_units - 1) &&
-                (row + current_process_unit_height >= current_unit_height);
-            if (bitdepth_ == 8) {
-              PrepareLoopRestorationBlock<uint8_t>(
-                  cdef_process_unit_buffer, cdef_buffer_stride,
-                  deblock_process_unit_buffer, deblock_buffer_stride,
-                  block_buffer_, block_buffer_stride, processing_unit_width,
-                  processing_unit_height, frame_top_border,
-                  frame_bottom_border);
-            } else {
-              PrepareLoopRestorationBlock<uint16_t>(
-                  cdef_process_unit_buffer, cdef_buffer_stride,
-                  deblock_process_unit_buffer, deblock_buffer_stride,
-                  block_buffer_, block_buffer_stride, processing_unit_width,
-                  processing_unit_height, frame_top_border,
-                  frame_bottom_border);
-            }
-            restoration_func(
-                reinterpret_cast<const uint8_t*>(
-                    block_buffer_ + kRestorationBorder * block_buffer_stride +
-                    kRestorationBorder * pixel_size_),
-                src_unit_buffer + column * pixel_size_ + src_unit_buffer_offset,
-                restoration_info_->loop_restoration_info(
-                    static_cast<Plane>(plane), unit_id),
-                block_buffer_stride, src_stride, processing_unit_width,
-                processing_unit_height, &restoration_buffer);
-          }
-          row += current_process_unit_height;
-          src_unit_buffer += current_process_unit_height * src_stride;
-          cdef_unit_buffer += current_process_unit_height * cdef_buffer_stride;
-          deblock_unit_buffer +=
-              current_process_unit_height * deblock_buffer_stride;
-        }
-      }
-      loop_restored_rows += current_unit_height;
-      src_buffer += current_unit_height * src_stride;
-      cdef_buffer += current_unit_height * cdef_buffer_stride;
-      deblock_buffer += current_unit_height * deblock_buffer_stride;
-    }
-    // Adjust frame buffer pointer once a plane is loop restored.
-    // If loop restoration is applied to a plane, we write the filtered frame
-    // to the upper-left side of original source_buffer_->data().
-    // The new buffer pointer is still within the physical frame buffer.
-    // Here negative shifts are used, to indicate shifting towards the
-    // upper-left corner. Shifts are in pixels.
-    if (!source_buffer_->ShiftBuffer(plane, horizontal_shift, vertical_shift)) {
-      LIBGAV1_DLOG(ERROR,
-                   "Error shifting frame buffer head pointer at plane: %d",
-                   plane);
-      return false;
-    }
-  }
-
-  return true;
-}
-
-void PostFilter::HorizontalDeblockFilter(Plane plane, int row4x4_start,
-                                         int column4x4_start, int unit_id) {
-  const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-  const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-  const int row_step = 1 << subsampling_y;
-  const int column_step = 1 << subsampling_x;
-  const size_t src_step = 4 * pixel_size_;
-  const ptrdiff_t row_stride = MultiplyBy4(source_buffer_->stride(plane));
-  const ptrdiff_t src_stride = source_buffer_->stride(plane);
-  uint8_t* src = SetBufferOffset(source_buffer_, plane, row4x4_start,
-                                 column4x4_start, subsampling_x, subsampling_y);
-  const uint64_t single_row_mask = 0xffff;
-  // 3 (11), 5 (0101).
-  const uint64_t two_block_mask = (subsampling_x > 0) ? 5 : 3;
-  const LoopFilterType type = kLoopFilterTypeHorizontal;
-  // Subsampled UV samples correspond to the right/bottom position of
-  // Y samples.
-  const int column = subsampling_x;
-
-  // AV1 smallest transform size is 4x4, thus minimum horizontal edge size is
-  // 4x4. For SIMD implementation, sse2 could compute 8 pixels at the same time.
-  // __m128i = 8 x uint16_t, AVX2 could compute 16 pixels at the same time.
-  // __m256i = 16 x uint16_t, assuming pixel type is 16 bit. It means we could
-  // filter 2 horizontal edges using sse2 and 4 edges using AVX2.
-  // The bitmask enables us to call different SIMD implementations to filter
-  // 1 edge, or 2 edges or 4 edges.
-  // TODO(chengchen): Here, the implementation only consider 1 and 2 edges.
-  // Add support for 4 edges. More branches involved, for example, if input is
-  // 8 bit, __m128i = 16 x 8 bit, we could apply filtering for 4 edges using
-  // sse2, 8 edges using AVX2. If input is 16 bit, __m128 = 8 x 16 bit, then
-  // we apply filtering for 2 edges using sse2, and 4 edges using AVX2.
-  for (int row4x4 = 0; MultiplyBy4(row4x4_start + row4x4) < height_ &&
-                       row4x4 < kNum4x4InLoopFilterMaskUnit;
-       row4x4 += row_step) {
-    if (row4x4_start + row4x4 == 0) {
-      src += row_stride;
-      continue;
-    }
-    // Subsampled UV samples correspond to the right/bottom position of
-    // Y samples.
-    const int row = GetDeblockPosition(row4x4, subsampling_y);
-    const int index = GetIndex(row);
-    const int shift = GetShift(row, column);
-    const int level_offset = LoopFilterMask::GetLevelOffset(row, column);
-    // Mask of current row. mask4x4 represents the vertical filter length for
-    // the current horizontal edge is 4, and we needs to apply 3-tap filtering.
-    // Similarly, mask8x8 and mask16x16 represent filter lengths are 8 and 16.
-    uint64_t mask4x4 =
-        (masks_->GetTop(unit_id, plane, kLoopFilterTransformSizeId4x4, index) >>
-         shift) &
-        single_row_mask;
-    uint64_t mask8x8 =
-        (masks_->GetTop(unit_id, plane, kLoopFilterTransformSizeId8x8, index) >>
-         shift) &
-        single_row_mask;
-    uint64_t mask16x16 =
-        (masks_->GetTop(unit_id, plane, kLoopFilterTransformSizeId16x16,
-                        index) >>
-         shift) &
-        single_row_mask;
-    // mask4x4, mask8x8, mask16x16 are mutually exclusive.
-    assert((mask4x4 & mask8x8) == 0 && (mask4x4 & mask16x16) == 0 &&
-           (mask8x8 & mask16x16) == 0);
-    // Apply deblock filter for one row.
-    uint8_t* src_row = src;
-    int column_offset = 0;
-    for (uint64_t mask = mask4x4 | mask8x8 | mask16x16; mask != 0;) {
-      int edge_count = 1;
-      if ((mask & 1) != 0) {
-        // Filter parameters of current edge.
-        const uint8_t level = masks_->GetLevel(unit_id, plane, type,
-                                               level_offset + column_offset);
-        int outer_thresh_0;
-        int inner_thresh_0;
-        int hev_thresh_0;
-        GetDeblockFilterParams(level, &outer_thresh_0, &inner_thresh_0,
-                               &hev_thresh_0);
-        // Filter parameters of next edge. Clip the index to avoid over
-        // reading at the edge of the block. The values will be unused in that
-        // case.
-        const int level_next_index = level_offset + column_offset + column_step;
-        const uint8_t level_next =
-            masks_->GetLevel(unit_id, plane, type, level_next_index & 0xff);
-        int outer_thresh_1;
-        int inner_thresh_1;
-        int hev_thresh_1;
-        GetDeblockFilterParams(level_next, &outer_thresh_1, &inner_thresh_1,
-                               &hev_thresh_1);
-
-        if ((mask16x16 & 1) != 0) {
-          const dsp::LoopFilterSize size = (plane == kPlaneY)
-                                               ? dsp::kLoopFilterSize14
-                                               : dsp::kLoopFilterSize6;
-          const dsp::LoopFilterFunc filter_func = dsp_.loop_filters[size][type];
-          if ((mask16x16 & two_block_mask) == two_block_mask) {
-            edge_count = 2;
-            // Apply filtering for two edges.
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-            filter_func(src_row + src_step, src_stride, outer_thresh_1,
-                        inner_thresh_1, hev_thresh_1);
-          } else {
-            // Apply single edge filtering.
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-          }
-        }
-
-        if ((mask8x8 & 1) != 0) {
-          const dsp::LoopFilterSize size =
-              plane == kPlaneY ? dsp::kLoopFilterSize8 : dsp::kLoopFilterSize6;
-          const dsp::LoopFilterFunc filter_func = dsp_.loop_filters[size][type];
-          if ((mask8x8 & two_block_mask) == two_block_mask) {
-            edge_count = 2;
-            // Apply filtering for two edges.
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-            filter_func(src_row + src_step, src_stride, outer_thresh_1,
-                        inner_thresh_1, hev_thresh_1);
-          } else {
-            // Apply single edge filtering.
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-          }
-        }
-
-        if ((mask4x4 & 1) != 0) {
-          const dsp::LoopFilterSize size = dsp::kLoopFilterSize4;
-          const dsp::LoopFilterFunc filter_func = dsp_.loop_filters[size][type];
-          if ((mask4x4 & two_block_mask) == two_block_mask) {
-            edge_count = 2;
-            // Apply filtering for two edges.
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-            filter_func(src_row + src_step, src_stride, outer_thresh_1,
-                        inner_thresh_1, hev_thresh_1);
-          } else {
-            // Apply single edge filtering.
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-          }
-        }
-      }
-
-      const int step = edge_count * column_step;
-      mask4x4 >>= step;
-      mask8x8 >>= step;
-      mask16x16 >>= step;
-      mask >>= step;
-      column_offset += step;
-      src_row += MultiplyBy4(edge_count) * pixel_size_;
-    }
-    src += row_stride;
-  }
-}
-
-void PostFilter::VerticalDeblockFilter(Plane plane, int row4x4_start,
-                                       int column4x4_start, int unit_id) {
-  const int8_t subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x_;
-  const int8_t subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y_;
-  const int row_step = 1 << subsampling_y;
-  const int two_row_step = row_step << 1;
-  const int column_step = 1 << subsampling_x;
-  const size_t src_step = (bitdepth_ == 8) ? 4 : 4 * sizeof(uint16_t);
-  const ptrdiff_t row_stride = MultiplyBy4(source_buffer_->stride(plane));
-  const ptrdiff_t two_row_stride = row_stride << 1;
-  const ptrdiff_t src_stride = source_buffer_->stride(plane);
-  uint8_t* src = SetBufferOffset(source_buffer_, plane, row4x4_start,
-                                 column4x4_start, subsampling_x, subsampling_y);
-  const uint64_t single_row_mask = 0xffff;
-  const LoopFilterType type = kLoopFilterTypeVertical;
-  // Subsampled UV samples correspond to the right/bottom position of
-  // Y samples.
-  const int column = subsampling_x;
-
-  // AV1 smallest transform size is 4x4, thus minimum vertical edge size is 4x4.
-  // For SIMD implementation, sse2 could compute 8 pixels at the same time.
-  // __m128i = 8 x uint16_t, AVX2 could compute 16 pixels at the same time.
-  // __m256i = 16 x uint16_t, assuming pixel type is 16 bit. It means we could
-  // filter 2 vertical edges using sse2 and 4 edges using AVX2.
-  // The bitmask enables us to call different SIMD implementations to filter
-  // 1 edge, or 2 edges or 4 edges.
-  // TODO(chengchen): Here, the implementation only consider 1 and 2 edges.
-  // Add support for 4 edges. More branches involved, for example, if input is
-  // 8 bit, __m128i = 16 x 8 bit, we could apply filtering for 4 edges using
-  // sse2, 8 edges using AVX2. If input is 16 bit, __m128 = 8 x 16 bit, then
-  // we apply filtering for 2 edges using sse2, and 4 edges using AVX2.
-  for (int row4x4 = 0; MultiplyBy4(row4x4_start + row4x4) < height_ &&
-                       row4x4 < kNum4x4InLoopFilterMaskUnit;
-       row4x4 += two_row_step) {
-    // Subsampled UV samples correspond to the right/bottom position of
-    // Y samples.
-    const int row = GetDeblockPosition(row4x4, subsampling_y);
-    const int row_next = row + row_step;
-    const int index = GetIndex(row);
-    const int shift = GetShift(row, column);
-    const int level_offset = LoopFilterMask::GetLevelOffset(row, column);
-    const int index_next = GetIndex(row_next);
-    const int shift_next_row = GetShift(row_next, column);
-    const int level_offset_next_row =
-        LoopFilterMask::GetLevelOffset(row_next, column);
-    // TODO(chengchen): replace 0, 1, 2 to meaningful enum names.
-    // mask of current row. mask4x4 represents the horizontal filter length for
-    // the current vertical edge is 4, and we needs to apply 3-tap filtering.
-    // Similarly, mask8x8 and mask16x16 represent filter lengths are 8 and 16.
-    uint64_t mask4x4_0 =
-        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId4x4,
-                         index) >>
-         shift) &
-        single_row_mask;
-    uint64_t mask8x8_0 =
-        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId8x8,
-                         index) >>
-         shift) &
-        single_row_mask;
-    uint64_t mask16x16_0 =
-        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId16x16,
-                         index) >>
-         shift) &
-        single_row_mask;
-    // mask4x4, mask8x8, mask16x16 are mutually exclusive.
-    assert((mask4x4_0 & mask8x8_0) == 0 && (mask4x4_0 & mask16x16_0) == 0 &&
-           (mask8x8_0 & mask16x16_0) == 0);
-    // mask of the next row. With mask of current and the next row, we can call
-    // the corresponding SIMD function to apply filtering for two vertical
-    // edges together.
-    uint64_t mask4x4_1 =
-        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId4x4,
-                         index_next) >>
-         shift_next_row) &
-        single_row_mask;
-    uint64_t mask8x8_1 =
-        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId8x8,
-                         index_next) >>
-         shift_next_row) &
-        single_row_mask;
-    uint64_t mask16x16_1 =
-        (masks_->GetLeft(unit_id, plane, kLoopFilterTransformSizeId16x16,
-                         index_next) >>
-         shift_next_row) &
-        single_row_mask;
-    // mask4x4, mask8x8, mask16x16 are mutually exclusive.
-    assert((mask4x4_1 & mask8x8_1) == 0 && (mask4x4_1 & mask16x16_1) == 0 &&
-           (mask8x8_1 & mask16x16_1) == 0);
-    // Apply deblock filter for two rows.
-    uint8_t* src_row = src;
-    int column_offset = 0;
-    for (uint64_t mask = mask4x4_0 | mask8x8_0 | mask16x16_0 | mask4x4_1 |
-                         mask8x8_1 | mask16x16_1;
-         mask != 0;) {
-      if ((mask & 1) != 0) {
-        // Filter parameters of current row.
-        const uint8_t level = masks_->GetLevel(unit_id, plane, type,
-                                               level_offset + column_offset);
-        int outer_thresh_0;
-        int inner_thresh_0;
-        int hev_thresh_0;
-        GetDeblockFilterParams(level, &outer_thresh_0, &inner_thresh_0,
-                               &hev_thresh_0);
-        // Filter parameters of next row. Clip the index to avoid over
-        // reading at the edge of the block. The values will be unused in that
-        // case.
-        const int level_next_index = level_offset_next_row + column_offset;
-        const uint8_t level_next =
-            masks_->GetLevel(unit_id, plane, type, level_next_index & 0xff);
-        int outer_thresh_1;
-        int inner_thresh_1;
-        int hev_thresh_1;
-        GetDeblockFilterParams(level_next, &outer_thresh_1, &inner_thresh_1,
-                               &hev_thresh_1);
-        uint8_t* const src_row_next = src_row + row_stride;
-
-        if (((mask16x16_0 | mask16x16_1) & 1) != 0) {
-          const dsp::LoopFilterSize size = (plane == kPlaneY)
-                                               ? dsp::kLoopFilterSize14
-                                               : dsp::kLoopFilterSize6;
-          const dsp::LoopFilterFunc filter_func = dsp_.loop_filters[size][type];
-          if ((mask16x16_0 & mask16x16_1 & 1) != 0) {
-            // Apply dual vertical edge filtering.
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-            filter_func(src_row_next, src_stride, outer_thresh_1,
-                        inner_thresh_1, hev_thresh_1);
-          } else if ((mask16x16_0 & 1) != 0) {
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-          } else {
-            filter_func(src_row_next, src_stride, outer_thresh_1,
-                        inner_thresh_1, hev_thresh_1);
-          }
-        }
-
-        if (((mask8x8_0 | mask8x8_1) & 1) != 0) {
-          const dsp::LoopFilterSize size = (plane == kPlaneY)
-                                               ? dsp::kLoopFilterSize8
-                                               : dsp::kLoopFilterSize6;
-          const dsp::LoopFilterFunc filter_func = dsp_.loop_filters[size][type];
-          if ((mask8x8_0 & mask8x8_1 & 1) != 0) {
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-            filter_func(src_row_next, src_stride, outer_thresh_1,
-                        inner_thresh_1, hev_thresh_1);
-          } else if ((mask8x8_0 & 1) != 0) {
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-          } else {
-            filter_func(src_row_next, src_stride, outer_thresh_1,
-                        inner_thresh_1, hev_thresh_1);
-          }
-        }
-
-        if (((mask4x4_0 | mask4x4_1) & 1) != 0) {
-          const dsp::LoopFilterSize size = dsp::kLoopFilterSize4;
-          const dsp::LoopFilterFunc filter_func = dsp_.loop_filters[size][type];
-          if ((mask4x4_0 & mask4x4_1 & 1) != 0) {
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-            filter_func(src_row_next, src_stride, outer_thresh_1,
-                        inner_thresh_1, hev_thresh_1);
-          } else if ((mask4x4_0 & 1) != 0) {
-            filter_func(src_row, src_stride, outer_thresh_0, inner_thresh_0,
-                        hev_thresh_0);
-          } else {
-            filter_func(src_row_next, src_stride, outer_thresh_1,
-                        inner_thresh_1, hev_thresh_1);
-          }
-        }
-      }
-
-      mask4x4_0 >>= column_step;
-      mask8x8_0 >>= column_step;
-      mask16x16_0 >>= column_step;
-      mask4x4_1 >>= column_step;
-      mask8x8_1 >>= column_step;
-      mask16x16_1 >>= column_step;
-      mask >>= column_step;
-      column_offset += column_step;
-      src_row += src_step;
-    }
-    src += two_row_stride;
-  }
-}
-
-void PostFilter::InitDeblockFilterParams() {
-  const int8_t sharpness = frame_header_.loop_filter.sharpness;
-  assert(0 <= sharpness && sharpness < 8);
-  const int shift = DivideBy4(sharpness + 3);  // ceil(sharpness / 4.0)
-  for (int level = 0; level <= kMaxLoopFilterValue; ++level) {
-    uint8_t limit = level >> shift;
-    if (sharpness > 0) {
-      limit = Clip3(limit, 1, 9 - sharpness);
-    } else {
-      limit = std::max(limit, static_cast<uint8_t>(1));
-    }
-    inner_thresh_[level] = limit;
-    outer_thresh_[level] = 2 * (level + 2) + limit;
-    hev_thresh_[level] = level >> 4;
-  }
-}
-
-void PostFilter::GetDeblockFilterParams(uint8_t level, int* outer_thresh,
-                                        int* inner_thresh,
-                                        int* hev_thresh) const {
-  *outer_thresh = outer_thresh_[level];
-  *inner_thresh = inner_thresh_[level];
-  *hev_thresh = hev_thresh_[level];
-}
-
-}  // namespace libgav1
diff --git a/libgav1/src/post_filter.h b/libgav1/src/post_filter.h
index c4146cb..d300049 100644
--- a/libgav1/src/post_filter.h
+++ b/libgav1/src/post_filter.h
@@ -27,7 +27,7 @@
 
 #include "src/dsp/common.h"
 #include "src/dsp/dsp.h"
-#include "src/loop_filter_mask.h"
+#include "src/frame_scratch_buffer.h"
 #include "src/loop_restoration_info.h"
 #include "src/obu_parser.h"
 #include "src/utils/array_2d.h"
@@ -46,55 +46,37 @@
 // and loop restoration.
 // Historically, for example in libaom, loop filter refers to deblock filter.
 // To avoid name conflicts, we call this class PostFilter (post processing).
-// Input info includes deblock parameters (bit masks), CDEF
-// parameters, super resolution parameters and loop restoration parameters.
 // In-loop post filtering order is:
 // deblock --> CDEF --> super resolution--> loop restoration.
 // When CDEF and super resolution is not used, we can combine deblock
 // and restoration together to only filter frame buffer once.
 class PostFilter {
  public:
-  static constexpr int kCdefLargeValue = 30000;
-
   // This class does not take ownership of the masks/restoration_info, but it
   // may change their values.
+  //
+  // The overall flow of data in this class (for both single and multi-threaded
+  // cases) is as follows:
+  //   -> Input: |frame_buffer_|.
+  //   -> Initialize |source_buffer_|, |cdef_buffer_| and
+  //      |loop_restoration_buffer_|.
+  //   -> Deblocking:
+  //      * Input: |source_buffer_|
+  //      * Output: |source_buffer_|
+  //   -> CDEF:
+  //      * Input: |source_buffer_|
+  //      * Output: |cdef_buffer_|
+  //   -> SuperRes:
+  //      * Input: |cdef_buffer_|
+  //      * Output: |cdef_buffer_|
+  //   -> Loop Restoration:
+  //      * Input: |cdef_buffer_|
+  //      * Output: |loop_restoration_buffer_|.
+  //   -> Now |frame_buffer_| contains the filtered frame.
   PostFilter(const ObuFrameHeader& frame_header,
              const ObuSequenceHeader& sequence_header,
-             LoopFilterMask* const masks, const Array2D<int16_t>& cdef_index,
-             LoopRestorationInfo* const restoration_info,
-             BlockParametersHolder* block_parameters,
-             YuvBuffer* const source_buffer, const dsp::Dsp* dsp,
-             ThreadPool* const thread_pool,
-             uint8_t* const threaded_window_buffer, int do_post_filter_mask)
-      : frame_header_(frame_header),
-        loop_restoration_(frame_header.loop_restoration),
-        dsp_(*dsp),
-        // Deblocking filter always uses 64x64 as step size.
-        num_64x64_blocks_per_row_(DivideBy64(frame_header.width + 63)),
-        upscaled_width_(frame_header.upscaled_width),
-        width_(frame_header.width),
-        height_(frame_header.height),
-        bitdepth_(sequence_header.color_config.bitdepth),
-        subsampling_x_(sequence_header.color_config.subsampling_x),
-        subsampling_y_(sequence_header.color_config.subsampling_y),
-        planes_(sequence_header.color_config.is_monochrome
-                    ? kMaxPlanesMonochrome
-                    : kMaxPlanes),
-        pixel_size_(static_cast<int>((bitdepth_ == 8) ? sizeof(uint8_t)
-                                                      : sizeof(uint16_t))),
-        masks_(masks),
-        cdef_index_(cdef_index),
-        threaded_window_buffer_(threaded_window_buffer),
-        restoration_info_(restoration_info),
-        window_buffer_width_(GetWindowBufferWidth(thread_pool, frame_header)),
-        window_buffer_height_(GetWindowBufferHeight(thread_pool, frame_header)),
-        block_parameters_(*block_parameters),
-        source_buffer_(source_buffer),
-        do_post_filter_mask_(do_post_filter_mask),
-        thread_pool_(thread_pool) {
-    const int8_t zero_delta_lf[kFrameLfCount] = {};
-    ComputeDeblockFilterLevels(zero_delta_lf, deblock_filter_levels_);
-  }
+             FrameScratchBuffer* frame_scratch_buffer, YuvBuffer* frame_buffer,
+             const dsp::Dsp* dsp, int do_post_filter_mask);
 
   // non copyable/movable.
   PostFilter(const PostFilter&) = delete;
@@ -102,63 +84,111 @@
   PostFilter(PostFilter&&) = delete;
   PostFilter& operator=(PostFilter&&) = delete;
 
-  // The overall function that applies all post processing filtering.
+  // The overall function that applies all post processing filtering with
+  // multiple threads.
   // * The filtering order is:
   //   deblock --> CDEF --> super resolution--> loop restoration.
-  // * The output of each filter is the input for the following filter.
-  //   A special case is that loop restoration needs both the deblocked frame
-  //   and the cdef filtered frame:
+  // * The output of each filter is the input for the following filter. A
+  //   special case is that loop restoration needs a few rows of the deblocked
+  //   frame and the entire cdef filtered frame:
   //   deblock --> CDEF --> super resolution --> loop restoration.
   //              |                                 ^
   //              |                                 |
   //              -----------> super resolution -----
   // * Any of these filters could be present or absent.
-  // Two pointers are used in this class: source_buffer_ and cdef_buffer_.
-  // * source_buffer_ always points to the input frame buffer, which holds the
-  //   (upscaled) deblocked frame buffer during the ApplyFiltering() method.
-  // * cdef_buffer_ points to the (upscaled) cdef filtered frame buffer,
-  //   however, if cdef is not present, cdef_buffer_ is the same as
-  //   source_buffer_.
-  // Each filter:
-  // * Deblock: in-place filtering. Input and output are both source_buffer_.
-  // * Cdef: allocates cdef_filtered_buffer_.
-  //         Sets cdef_buffer_ to cdef_filtered_buffer_.
-  //         Input is source_buffer_. Output is cdef_buffer_.
-  // * SuperRes: allocates super_res_buffer_.
-  //             Inputs are source_buffer_ and cdef_buffer_.
-  //             FrameSuperRes takes one input and applies super resolution.
-  //             When FrameSuperRes is called, super_res_buffer_ is the
-  //             intermediate buffer to hold a copy of the input.
-  //             Super resolution process is applied and result
-  //             is written to the input buffer.
-  //             Therefore, contents of inputs are changed, but their meanings
-  //             remain.
-  // * Restoration: near in-place filtering. Allocates a local block for loop
-  //                restoration units, which is 64x64.
-  //                Inputs are source_buffer_ and cdef_buffer_.
-  //                Ouput is source_buffer_.
-  bool ApplyFiltering();
-  bool DoCdef() const { return DoCdef(frame_header_, do_post_filter_mask_); }
+  // * |frame_buffer_| points to the decoded frame buffer. When
+  //   ApplyFilteringThreaded() is called, |frame_buffer_| is modified by each
+  //   of the filters as described below.
+  // Filter behavior (multi-threaded):
+  // * Deblock: In-place filtering. The output is written to |source_buffer_|.
+  //            If cdef and loop restoration are both on, then 4 rows (as
+  //            specified by |kDeblockedRowsForLoopRestoration|) in every 64x64
+  //            block is copied into |deblock_buffer_|.
+  // * Cdef: Filtering output is written into |threaded_window_buffer_| and then
+  //         copied into the |cdef_buffer_| (which is just |source_buffer_| with
+  //         a shift to the top-left).
+  // * SuperRes: Near in-place filtering (with an additional line buffer for
+  //             each row). The output is written to |cdef_buffer_|.
+  // * Restoration: Uses the |cdef_buffer_| and |deblock_buffer_| as the input
+  //                and the output is written into the
+  //                |threaded_window_buffer_|. It is then copied to the
+  //                |loop_restoration_buffer_| (which is just |cdef_buffer_|
+  //                with a shift to the top-left).
+  void ApplyFilteringThreaded();
+
+  // Does the overall post processing filter for one superblock row starting at
+  // |row4x4| with height 4*|sb4x4|. If |do_deblock| is false, deblocking filter
+  // will not be applied.
+  //
+  // Filter behavior (single-threaded):
+  // * Deblock: In-place filtering. The output is written to |source_buffer_|.
+  //            If cdef and loop restoration are both on, then 4 rows (as
+  //            specified by |kDeblockedRowsForLoopRestoration|) in every 64x64
+  //            block is copied into |deblock_buffer_|.
+  // * Cdef: In-place filtering. The output is written into |cdef_buffer_|
+  //         (which is just |source_buffer_| with a shift to the top-left).
+  // * SuperRes: Near in-place filtering (with an additional line buffer for
+  //             each row). The output is written to |cdef_buffer_|.
+  // * Restoration: Near in-place filtering. Uses a local block of size 64x64.
+  //                Uses the |cdef_buffer_| and |deblock_buffer_| as the input
+  //                and the output is written into |loop_restoration_buffer_|
+  //                (which is just |source_buffer_| with a shift to the
+  //                top-left).
+  // Returns the index of the last row whose post processing is complete and can
+  // be used for referencing.
+  int ApplyFilteringForOneSuperBlockRow(int row4x4, int sb4x4, bool is_last_row,
+                                        bool do_deblock);
+
+  // Apply deblocking filter in one direction (specified by |loop_filter_type|)
+  // for the superblock row starting at |row4x4_start| for columns starting from
+  // |column4x4_start| in increments of 16 (or 8 for chroma with subsampling)
+  // until the smallest multiple of 16 that is >= |column4x4_end| or until
+  // |frame_header_.columns4x4|, whichever is lower. This function must be
+  // called only if |DoDeblock()| returns true.
+  void ApplyDeblockFilter(LoopFilterType loop_filter_type, int row4x4_start,
+                          int column4x4_start, int column4x4_end, int sb4x4);
+
   static bool DoCdef(const ObuFrameHeader& frame_header,
                      int do_post_filter_mask) {
-    return (do_post_filter_mask & 0x02) != 0 &&
-           (frame_header.cdef.bits > 0 ||
+    return (frame_header.cdef.bits > 0 ||
             frame_header.cdef.y_primary_strength[0] > 0 ||
             frame_header.cdef.y_secondary_strength[0] > 0 ||
             frame_header.cdef.uv_primary_strength[0] > 0 ||
-            frame_header.cdef.uv_secondary_strength[0] > 0);
+            frame_header.cdef.uv_secondary_strength[0] > 0) &&
+           (do_post_filter_mask & 0x02) != 0;
   }
+  bool DoCdef() const { return DoCdef(frame_header_, do_post_filter_mask_); }
   // If filter levels for Y plane (0 for vertical, 1 for horizontal),
   // are all zero, deblock filter will not be applied.
   static bool DoDeblock(const ObuFrameHeader& frame_header,
                         uint8_t do_post_filter_mask) {
-    return (do_post_filter_mask & 0x01) != 0 &&
-           (frame_header.loop_filter.level[0] > 0 ||
-            frame_header.loop_filter.level[1] > 0);
+    return (frame_header.loop_filter.level[0] > 0 ||
+            frame_header.loop_filter.level[1] > 0) &&
+           (do_post_filter_mask & 0x01) != 0;
   }
   bool DoDeblock() const {
     return DoDeblock(frame_header_, do_post_filter_mask_);
   }
+
+  // This function takes the cdef filtered buffer and the deblocked buffer to
+  // prepare a block as input for loop restoration.
+  // In striped loop restoration:
+  // The filtering needs to fetch the area of size (width + 6) x (height + 4),
+  // in which (width + 6) x height area is from upscaled frame
+  // (superres_buffer). Top 2 rows and bottom 2 rows are from deblocked frame
+  // (deblock_buffer). Special cases are: (1). when it is the top border, the
+  // top 2 rows are from cdef filtered frame. (2). when it is the bottom border,
+  // the bottom 2 rows are from cdef filtered frame. This function is called
+  // only when cdef is applied for this frame.
+  template <typename Pixel>
+  static void PrepareLoopRestorationBlock(const Pixel* src_buffer,
+                                          ptrdiff_t src_stride,
+                                          const Pixel* deblock_buffer,
+                                          ptrdiff_t deblock_stride, Pixel* dst,
+                                          ptrdiff_t dst_stride, int width,
+                                          int height, bool frame_top_border,
+                                          bool frame_bottom_border);
+
   uint8_t GetZeroDeltaDeblockFilterLevel(int segment_id, int level_index,
                                          ReferenceFrameType type,
                                          int mode_id) const {
@@ -170,40 +200,52 @@
       const int8_t delta_lf[kFrameLfCount],
       uint8_t deblock_filter_levels[kMaxSegments][kFrameLfCount]
                                    [kNumReferenceFrameTypes][2]) const;
-  bool DoRestoration() const;
   // Returns true if loop restoration will be performed for the given parameters
   // and mask.
   static bool DoRestoration(const LoopRestoration& loop_restoration,
-                            uint8_t do_post_filter_mask, int num_planes);
-  bool DoSuperRes() const {
-    return (do_post_filter_mask_ & 0x04) != 0 && width_ != upscaled_width_;
+                            uint8_t do_post_filter_mask, int num_planes) {
+    if (num_planes == kMaxPlanesMonochrome) {
+      return loop_restoration.type[kPlaneY] != kLoopRestorationTypeNone &&
+             (do_post_filter_mask & 0x08) != 0;
+    }
+    return (loop_restoration.type[kPlaneY] != kLoopRestorationTypeNone ||
+            loop_restoration.type[kPlaneU] != kLoopRestorationTypeNone ||
+            loop_restoration.type[kPlaneV] != kLoopRestorationTypeNone) &&
+           (do_post_filter_mask & 0x08) != 0;
   }
-  LoopFilterMask* masks() const { return masks_; }
-  LoopRestorationInfo* restoration_info() const { return restoration_info_; }
-  static uint8_t* SetBufferOffset(YuvBuffer* buffer, Plane plane, int row4x4,
-                                  int column4x4, int8_t subsampling_x,
-                                  int8_t subsampling_y) {
-    const size_t pixel_size =
-        (buffer->bitdepth() == 8) ? sizeof(uint8_t) : sizeof(uint16_t);
-    return buffer->data(plane) +
-           RowOrColumn4x4ToPixel(row4x4, plane, subsampling_y) *
-               buffer->stride(plane) +
-           RowOrColumn4x4ToPixel(column4x4, plane, subsampling_x) * pixel_size;
+  bool DoRestoration() const {
+    return DoRestoration(loop_restoration_, do_post_filter_mask_, planes_);
   }
 
-  // Extends frame, sets border pixel values to its closest frame boundary.
-  // Loop restoration needs three pixels border around current block.
-  // If we are at the left boundary of the frame, we extend the frame 3
-  // pixels to the left, and copy current pixel value to them.
-  // If we are at the top boundary of the frame, we need to extend the frame
-  // by three rows. They are copies of the first line of pixels.
-  // Similarly for the right and bottom boundary.
-  // The frame buffer should already be large enough to hold the extension.
-  // Super resolution needs to fill frame border as well. The border size
-  // is kBorderPixels.
-  void ExtendFrameBoundary(uint8_t* frame_start, int width, int height,
-                           ptrdiff_t stride, int left, int right, int top,
-                           int bottom);
+  // Returns a pointer to the unfiltered buffer. This is used by the Tile class
+  // to determine where to write the output of the tile decoding process taking
+  // in-place filtering offsets into consideration.
+  uint8_t* GetUnfilteredBuffer(int plane) { return source_buffer_[plane]; }
+  const YuvBuffer& frame_buffer() const { return frame_buffer_; }
+
+  // Returns true if SuperRes will be performed for the given frame header and
+  // mask.
+  static bool DoSuperRes(const ObuFrameHeader& frame_header,
+                         uint8_t do_post_filter_mask) {
+    return frame_header.width != frame_header.upscaled_width &&
+           (do_post_filter_mask & 0x04) != 0;
+  }
+  bool DoSuperRes() const {
+    return DoSuperRes(frame_header_, do_post_filter_mask_);
+  }
+  LoopRestorationInfo* restoration_info() const { return restoration_info_; }
+  uint8_t* GetBufferOffset(uint8_t* base_buffer, int stride, Plane plane,
+                           int row4x4, int column4x4) const {
+    return base_buffer +
+           RowOrColumn4x4ToPixel(row4x4, plane, subsampling_y_[plane]) *
+               stride +
+           RowOrColumn4x4ToPixel(column4x4, plane, subsampling_x_[plane]) *
+               pixel_size_;
+  }
+  uint8_t* GetSourceBuffer(Plane plane, int row4x4, int column4x4) const {
+    return GetBufferOffset(source_buffer_[plane], frame_buffer_.stride(plane),
+                           plane, row4x4, column4x4);
+  }
 
   static int GetWindowBufferWidth(const ThreadPool* const thread_pool,
                                   const ObuFrameHeader& frame_header) {
@@ -225,63 +267,168 @@
     return std::min(adjusted_frame_height, window_height);
   }
 
+  template <typename Pixel>
+  static void ExtendFrame(Pixel* frame_start, int width, int height,
+                          ptrdiff_t stride, int left, int right, int top,
+                          int bottom);
+
  private:
   // The type of the HorizontalDeblockFilter and VerticalDeblockFilter member
   // functions.
-  using DeblockFilter = void (PostFilter::*)(Plane plane, int row4x4_start,
-                                             int column4x4_start, int unit_id);
-  // Represents a job for a worker thread to apply the deblock filter.
-  struct DeblockFilterJob : public Allocable {
-    int plane;
-    int row4x4;
-    int row_unit;
-  };
-  // Buffers for loop restoration intermediate results. Depending on the filter
-  // type, only one member of the union is used.
-  union IntermediateBuffers {
-    // For Wiener filter.
-    // The array |intermediate| in Section 7.17.4, the intermediate results
-    // between the horizontal and vertical filters.
-    alignas(kMaxAlignment)
-        uint16_t wiener[(kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1) *
-                        kMaxSuperBlockSizeInPixels];
-    // For self-guided filter.
-    struct {
-      // The arrays flt0 and flt1 in Section 7.17.2, the outputs of the box
-      // filter process in pass 0 and pass 1.
-      alignas(
-          kMaxAlignment) int32_t output[2][kMaxBoxFilterProcessOutputPixels];
-      // The 2d arrays A and B in Section 7.17.3, the intermediate results in
-      // the box filter process. Reused for pass 0 and pass 1.
-      alignas(kMaxAlignment) uint32_t
-          intermediate_a[kBoxFilterProcessIntermediatePixels];
-      alignas(kMaxAlignment) uint32_t
-          intermediate_b[kBoxFilterProcessIntermediatePixels];
-    } box_filter;
-  };
+  using DeblockFilter = void (PostFilter::*)(int row4x4_start,
+                                             int column4x4_start);
+  // The lookup table for picking the deblock filter, according to deblock
+  // filter type.
+  const DeblockFilter deblock_filter_func_[2] = {
+      &PostFilter::VerticalDeblockFilter, &PostFilter::HorizontalDeblockFilter};
 
-  bool ApplyDeblockFilter();
-  void DeblockFilterWorker(const DeblockFilterJob* jobs, int num_jobs,
-                           std::atomic<int>* job_counter,
+  // Functions common to all post filters.
+
+  // Extends the frame by setting the border pixel values to the one from its
+  // closest frame boundary.
+  void ExtendFrameBoundary(uint8_t* frame_start, int width, int height,
+                           ptrdiff_t stride, int left, int right, int top,
+                           int bottom) const;
+  // Extend frame boundary for referencing if the frame will be saved as a
+  // reference frame.
+  void ExtendBordersForReferenceFrame();
+  // Copies the deblocked pixels needed for loop restoration.
+  void CopyDeblockedPixels(Plane plane, int row4x4);
+  // Copies the border for one superblock row. If |for_loop_restoration| is
+  // true, then it assumes that the border extension is being performed for the
+  // input of the loop restoration process. If |for_loop_restoration| is false,
+  // then it assumes that the border extension is being performed for using the
+  // current frame as a reference frame. In this case, |progress_row_| is also
+  // updated.
+  void CopyBordersForOneSuperBlockRow(int row4x4, int sb4x4,
+                                      bool for_loop_restoration);
+  // Sets up the |deblock_buffer_| for loop restoration.
+  void SetupDeblockBuffer(int row4x4_start, int sb4x4);
+  // Returns true if we can perform border extension in loop (i.e.) without
+  // waiting until the entire frame is decoded. If intra_block_copy is true, we
+  // do in-loop border extension only if the upscaled_width is the same as 4 *
+  // columns4x4. Otherwise, we cannot do in loop border extension since those
+  // pixels may be used by intra block copy.
+  bool DoBorderExtensionInLoop() const {
+    return !frame_header_.allow_intrabc ||
+           frame_header_.upscaled_width ==
+               MultiplyBy4(frame_header_.columns4x4);
+  }
+  template <typename Pixel>
+  void CopyPlane(const Pixel* src, ptrdiff_t src_stride, int width, int height,
+                 Pixel* dst, ptrdiff_t dst_stride) {
+    for (int y = 0; y < height; ++y) {
+      memcpy(dst, src, width * sizeof(Pixel));
+      src += src_stride;
+      dst += dst_stride;
+    }
+  }
+
+  // Functions for the Deblocking filter.
+
+  static int GetIndex(int row4x4) { return DivideBy4(row4x4); }
+  static int GetShift(int row4x4, int column4x4) {
+    return ((row4x4 & 3) << 4) | column4x4;
+  }
+  int GetDeblockUnitId(int row_unit, int column_unit) const {
+    return row_unit * num_64x64_blocks_per_row_ + column_unit;
+  }
+  bool GetHorizontalDeblockFilterEdgeInfo(int row4x4, int column4x4,
+                                          uint8_t* level, int* step,
+                                          int* filter_length) const;
+  void GetHorizontalDeblockFilterEdgeInfoUV(int row4x4, int column4x4,
+                                            uint8_t* level_u, uint8_t* level_v,
+                                            int* step,
+                                            int* filter_length) const;
+  bool GetVerticalDeblockFilterEdgeInfo(int row4x4, int column4x4,
+                                        BlockParameters* const* bp_ptr,
+                                        uint8_t* level, int* step,
+                                        int* filter_length) const;
+  void GetVerticalDeblockFilterEdgeInfoUV(int column4x4,
+                                          BlockParameters* const* bp_ptr,
+                                          uint8_t* level_u, uint8_t* level_v,
+                                          int* step, int* filter_length) const;
+  void HorizontalDeblockFilter(int row4x4_start, int column4x4_start);
+  void VerticalDeblockFilter(int row4x4_start, int column4x4_start);
+  // HorizontalDeblockFilter and VerticalDeblockFilter must have the correct
+  // signature.
+  static_assert(std::is_same<decltype(&PostFilter::HorizontalDeblockFilter),
+                             DeblockFilter>::value,
+                "");
+  static_assert(std::is_same<decltype(&PostFilter::VerticalDeblockFilter),
+                             DeblockFilter>::value,
+                "");
+  // Applies deblock filtering for the superblock row starting at |row4x4| with
+  // a height of 4*|sb4x4|.
+  void ApplyDeblockFilterForOneSuperBlockRow(int row4x4, int sb4x4);
+  void DeblockFilterWorker(int jobs_per_plane, const Plane* planes,
+                           int num_planes, std::atomic<int>* job_counter,
                            DeblockFilter deblock_filter);
-  bool ApplyDeblockFilterThreaded();
+  void ApplyDeblockFilterThreaded();
+
+  // Functions for the cdef filter.
 
   uint8_t* GetCdefBufferAndStride(int start_x, int start_y, int plane,
-                                  int subsampling_x, int subsampling_y,
                                   int window_buffer_plane_size,
-                                  int vertical_shift, int horizontal_shift,
-                                  int* cdef_stride);
+                                  int* cdef_stride) const;
+  // This function prepares the input source block for cdef filtering. The input
+  // source block contains a 12x12 block, with the inner 8x8 as the desired
+  // filter region. It pads the block if the 12x12 block includes out of frame
+  // pixels with a large value. This achieves the required behavior defined in
+  // section 5.11.52 of the spec.
+  template <typename Pixel>
+  void PrepareCdefBlock(int block_width4x4, int block_height4x4, int row4x4,
+                        int column4x4, uint16_t* cdef_source,
+                        ptrdiff_t cdef_stride, bool y_plane);
   template <typename Pixel>
   void ApplyCdefForOneUnit(uint16_t* cdef_block, int index, int block_width4x4,
                            int block_height4x4, int row4x4_start,
                            int column4x4_start);
+  // Helper function used by ApplyCdefForOneSuperBlockRow to avoid some code
+  // duplication.
+  void ApplyCdefForOneSuperBlockRowHelper(int row4x4, int block_height4x4);
+  // Applies cdef filtering for the superblock row starting at |row4x4| with a
+  // height of 4*|sb4x4|.
+  void ApplyCdefForOneSuperBlockRow(int row4x4, int sb4x4, bool is_last_row);
   template <typename Pixel>
   void ApplyCdefForOneRowInWindow(int row, int column);
   template <typename Pixel>
-  bool ApplyCdefThreaded();
-  bool ApplyCdef();  // Sections 7.15 and 7.15.1.
+  void ApplyCdefThreaded();
+  void ApplyCdef();  // Sections 7.15 and 7.15.1.
 
-  bool ApplySuperRes();
+  // Functions for the SuperRes filter.
+
+  // Applies super resolution for the |buffers| for |rows[plane]| rows of each
+  // plane. If |in_place| is true, the line buffer will not be used and the
+  // SuperRes output will be written to a row above the input row. If |in_place|
+  // is false, the line buffer will be used to store a copy of the input and the
+  // output will be written to the same row as the input row.
+  template <bool in_place>
+  void ApplySuperRes(const std::array<uint8_t*, kMaxPlanes>& buffers,
+                     const std::array<int, kMaxPlanes>& strides,
+                     const std::array<int, kMaxPlanes>& rows,
+                     size_t line_buffer_offset);  // Section 7.16.
+  // Applies SuperRes for the superblock row starting at |row4x4| with a height
+  // of 4*|sb4x4|.
+  void ApplySuperResForOneSuperBlockRow(int row4x4, int sb4x4,
+                                        bool is_last_row);
+  void ApplySuperResThreaded();
+
+  // Functions for the Loop Restoration filter.
+
+  template <typename Pixel>
+  void ApplyLoopRestorationForOneRowInWindow(
+      const Pixel* src_buffer, Plane plane, int plane_height, int plane_width,
+      int y, int x, int row, int unit_row, int current_process_unit_height,
+      int plane_unit_size, int window_width,
+      Array2DView<Pixel>* loop_restored_window);
+  // Applies loop restoration for the superblock row starting at |row4x4_start|
+  // with a height of 4*|sb4x4|.
+  template <typename Pixel>
+  void ApplyLoopRestorationSingleThread(int row4x4_start, int sb4x4);
+  void ApplyLoopRestoration(int row4x4_start, int sb4x4);
+  template <typename Pixel>
+  void ApplyLoopRestorationThreaded();
   // Note for ApplyLoopRestoration():
   // First, we must differentiate loop restoration processing unit from loop
   // restoration unit.
@@ -313,48 +460,7 @@
   // then sizes of the first row of processing units are 64x56, 64x56, 12x56,
   // respectively. The second row is 64x64, 64x64, 12x64.
   // The third row is 64x20, 64x20, 12x20.
-  bool ApplyLoopRestoration();
-  template <typename Pixel>
-  bool ApplyLoopRestorationThreaded();
-  template <typename Pixel>
-  void ApplyLoopRestorationForOneRowInWindow(
-      uint8_t* cdef_buffer, ptrdiff_t cdef_buffer_stride,
-      uint8_t* deblock_buffer, ptrdiff_t deblock_buffer_stride, Plane plane,
-      int plane_height, int plane_width, int x, int y, int row, int unit_row,
-      int current_process_unit_height, int process_unit_width, int window_width,
-      int plane_unit_size, int num_horizontal_units);
-  template <typename Pixel>
-  void ApplyLoopRestorationForOneUnit(
-      uint8_t* cdef_buffer, ptrdiff_t cdef_buffer_stride,
-      uint8_t* deblock_buffer, ptrdiff_t deblock_buffer_stride, Plane plane,
-      int plane_height, int unit_id, LoopRestorationType type, int x, int y,
-      int row, int column, int current_process_unit_width,
-      int current_process_unit_height, int plane_process_unit_width,
-      int window_width);
-  static int GetIndex(int row4x4) { return DivideBy4(row4x4); }
-  static int GetShift(int row4x4, int column4x4) {
-    return ((row4x4 & 3) << 4) | column4x4;
-  }
-  int GetDeblockUnitId(int row_unit, int column_unit) const {
-    return row_unit * num_64x64_blocks_per_row_ + column_unit;
-  }
-  void HorizontalDeblockFilter(Plane plane, int row4x4_start,
-                               int column4x4_start, int unit_id);
-  void VerticalDeblockFilter(Plane plane, int row4x4_start, int column4x4_start,
-                             int unit_id);
-  // HorizontalDeblockFilter and VerticalDeblockFilter must have the correct
-  // signature.
-  static_assert(std::is_same<decltype(&PostFilter::HorizontalDeblockFilter),
-                             DeblockFilter>::value,
-                "");
-  static_assert(std::is_same<decltype(&PostFilter::VerticalDeblockFilter),
-                             DeblockFilter>::value,
-                "");
-  void InitDeblockFilterParams();  // Part of 7.14.4.
-  void GetDeblockFilterParams(uint8_t level, int* outer_thresh,
-                              int* inner_thresh, int* hev_thresh) const;
-  // Applies super resolution and writes result to input_buffer.
-  void FrameSuperRes(YuvBuffer* input_buffer);  // Section 7.16.
+  void ApplyLoopRestoration();
 
   const ObuFrameHeader& frame_header_;
   const LoopRestoration& loop_restoration_;
@@ -364,23 +470,27 @@
   const int width_;
   const int height_;
   const int8_t bitdepth_;
-  const int8_t subsampling_x_;
-  const int8_t subsampling_y_;
+  const int8_t subsampling_x_[kMaxPlanes];
+  const int8_t subsampling_y_[kMaxPlanes];
   const int8_t planes_;
   const int pixel_size_;
-  // This class does not take ownership of the masks/restoration_info, but it
-  // could change their values.
-  LoopFilterMask* const masks_;
-  uint8_t inner_thresh_[kMaxLoopFilterValue + 1] = {};
-  uint8_t outer_thresh_[kMaxLoopFilterValue + 1] = {};
-  uint8_t hev_thresh_[kMaxLoopFilterValue + 1] = {};
+  const uint8_t* const inner_thresh_;
+  const uint8_t* const outer_thresh_;
+  const bool needs_chroma_deblock_;
   // This stores the deblocking filter levels assuming that the delta is zero.
   // This will be used by all superblocks whose delta is zero (without having to
   // recompute them). The dimensions (in order) are: segment_id, level_index
   // (based on plane and direction), reference_frame and mode_id.
   uint8_t deblock_filter_levels_[kMaxSegments][kFrameLfCount]
                                 [kNumReferenceFrameTypes][2];
+  // Stores the SuperRes info for the frame.
+  struct {
+    int upscaled_width;
+    int initial_subpixel_x;
+    int step;
+  } super_res_info_[kMaxPlanes];
   const Array2D<int16_t>& cdef_index_;
+  const Array2D<TransformSize>& inter_transform_sizes_;
   // Pointer to the data buffer used for multi-threaded cdef or loop
   // restoration. The size of this buffer must be at least
   // |window_buffer_width_| * |window_buffer_height_| * |pixel_size_|.
@@ -389,39 +499,45 @@
   // nullptr as well.
   uint8_t* const threaded_window_buffer_;
   LoopRestorationInfo* const restoration_info_;
-  const int window_buffer_width_;
-  const int window_buffer_height_;
+  // Pointer to the line buffer used by ApplySuperRes(). If SuperRes is on, then
+  // the buffer will be large enough to hold one downscaled row +
+  // 2 * kSuperResHorizontalBorder + kSuperResHorizontalPadding.
+  uint8_t* const superres_line_buffer_;
   const BlockParametersHolder& block_parameters_;
   // Frame buffer to hold cdef filtered frame.
   YuvBuffer cdef_filtered_buffer_;
-  // Frame buffer to hold the copy of the buffer to be upscaled,
-  // allocated only when super res is required.
-  YuvBuffer super_res_buffer_;
-  // Input frame buffer. During ApplyFiltering(), it holds the (upscaled)
-  // deblocked frame.
-  // When ApplyFiltering() is done, it holds the final output of PostFilter.
-  YuvBuffer* const source_buffer_;
-  // Frame buffer pointer. It always points to (upscaled) cdef filtered frame.
-  // Set in ApplyCdef(). If cdef is not present, in ApplyLoopRestoration(),
-  // cdef_buffer_ is the same as source_buffer_.
-  YuvBuffer* cdef_buffer_ = nullptr;
+  // Input frame buffer.
+  YuvBuffer& frame_buffer_;
+  // A view into |frame_buffer_| that points to the input and output of the
+  // deblocking process.
+  uint8_t* source_buffer_[kMaxPlanes];
+  // A view into |frame_buffer_| that points to the output of the CDEF filtered
+  // planes (to facilitate in-place CDEF filtering).
+  uint8_t* cdef_buffer_[kMaxPlanes];
+  // A view into |frame_buffer_| that points to the planes after the SuperRes
+  // filter is applied (to facilitate in-place SuperRes).
+  uint8_t* superres_buffer_[kMaxPlanes];
+  // A view into |frame_buffer_| that points to the output of the Loop Restored
+  // planes (to facilitate in-place Loop Restoration).
+  uint8_t* loop_restoration_buffer_[kMaxPlanes];
+  // Buffer used to store the deblocked pixels that are necessary for loop
+  // restoration. This buffer will store 4 rows for every 64x64 block (4 rows
+  // for every 32x32 for chroma with subsampling). The indices of the rows that
+  // are stored are specified in |kDeblockedRowsForLoopRestoration|. First 4
+  // rows of this buffer are never populated and never used.
+  // This buffer is used only when both Cdef and Loop Restoration are on.
+  YuvBuffer& deblock_buffer_;
   const uint8_t do_post_filter_mask_;
-
   ThreadPool* const thread_pool_;
+  const int window_buffer_width_;
+  const int window_buffer_height_;
 
-  // A small buffer to hold input source image block for loop restoration.
-  // Its size is one processing unit size + borders.
-  // Self-guided filter needs an extra one-pixel border.
-  // Wiener filter needs extended border of three pixels.
-  // Therefore the size of the buffer is 70x70 pixels.
-  alignas(alignof(uint16_t)) uint8_t
-      block_buffer_[kRestorationProcessingUnitSizeWithBorders *
-                    kRestorationProcessingUnitSizeWithBorders *
-                    sizeof(uint16_t)];
+  // Tracks the progress of the post filters.
+  int progress_row_ = -1;
+
   // A block buffer to hold the input that is converted to uint16_t before
   // cdef filtering. Only used in single threaded case.
-  uint16_t cdef_block_[kRestorationProcessingUnitSizeWithBorders *
-                       kRestorationProcessingUnitSizeWithBorders * 3];
+  uint16_t cdef_block_[kCdefUnitSizeWithBorders * kCdefUnitSizeWithBorders * 3];
 
   template <int bitdepth, typename Pixel>
   friend class PostFilterSuperResTest;
@@ -430,171 +546,29 @@
   friend class PostFilterHelperFuncTest;
 };
 
-// This function takes the cdef filtered buffer and the deblocked buffer to
-// prepare a block as input for loop restoration.
-// In striped loop restoration:
-// The filtering needs to fetch the area of size (width + 6) x (height + 6),
-// in which (width + 6) x height area is from cdef filtered frame
-// (cdef_buffer). Top 3 rows and bottom 3 rows are from deblocked frame
-// (deblock_buffer).
-// Special cases are:
-// (1). when it is the top border, the top 3 rows are from cdef
-// filtered frame.
-// (2). when it is the bottom border, the bottom 3 rows are from cdef
-// filtered frame.
-// For the top 3 rows and bottom 3 rows, the top_row[0] is a copy of the
-// top_row[1]. The bottom_row[2] is a copy of the bottom_row[1]. If cdef is
-// not applied for this frame, cdef_buffer is the same as deblock_buffer.
-template <typename Pixel>
-void PrepareLoopRestorationBlock(const uint8_t* cdef_buffer,
-                                 ptrdiff_t cdef_stride,
-                                 const uint8_t* deblock_buffer,
-                                 ptrdiff_t deblock_stride, uint8_t* dest,
-                                 ptrdiff_t dest_stride, const int width,
-                                 const int height, const bool frame_top_border,
-                                 const bool frame_bottom_border) {
-  const auto* cdef_ptr = reinterpret_cast<const Pixel*>(cdef_buffer);
-  cdef_stride /= sizeof(Pixel);
-  const auto* deblock_ptr = reinterpret_cast<const Pixel*>(deblock_buffer);
-  deblock_stride /= sizeof(Pixel);
-  auto* dst = reinterpret_cast<Pixel*>(dest);
-  dest_stride /= sizeof(Pixel);
-  // Top 3 rows.
-  cdef_ptr -= (kRestorationBorder - 1) * cdef_stride + kRestorationBorder;
-  deblock_ptr -= (kRestorationBorder - 1) * deblock_stride + kRestorationBorder;
-  for (int i = 0; i < kRestorationBorder; ++i) {
-    if (frame_top_border) {
-      memcpy(dst, cdef_ptr, sizeof(Pixel) * (width + 2 * kRestorationBorder));
-    } else {
-      memcpy(dst, deblock_ptr,
-             sizeof(Pixel) * (width + 2 * kRestorationBorder));
-    }
-    if (i > 0) {
-      cdef_ptr += cdef_stride;
-      deblock_ptr += deblock_stride;
-    }
-    dst += dest_stride;
-  }
-  // Main body.
-  for (int i = 0; i < height; ++i) {
-    memcpy(dst, cdef_ptr, sizeof(Pixel) * (width + 2 * kRestorationBorder));
-    cdef_ptr += cdef_stride;
-    dst += dest_stride;
-  }
-  // Bottom 3 rows.
-  deblock_ptr += height * deblock_stride;
-  for (int i = 0; i < kRestorationBorder; ++i) {
-    if (frame_bottom_border) {
-      memcpy(dst, cdef_ptr, sizeof(Pixel) * (width + 2 * kRestorationBorder));
-    } else {
-      memcpy(dst, deblock_ptr,
-             sizeof(Pixel) * (width + 2 * kRestorationBorder));
-    }
-    if (i < kRestorationBorder - 2) {
-      cdef_ptr += cdef_stride;
-      deblock_ptr += deblock_stride;
-    }
-    dst += dest_stride;
-  }
-}
+extern template void PostFilter::ExtendFrame<uint8_t>(uint8_t* frame_start,
+                                                      int width, int height,
+                                                      ptrdiff_t stride,
+                                                      int left, int right,
+                                                      int top, int bottom);
+extern template void PostFilter::PrepareLoopRestorationBlock<uint8_t>(
+    const uint8_t* src_buffer, ptrdiff_t src_stride,
+    const uint8_t* deblock_buffer, ptrdiff_t deblock_stride, uint8_t* dst,
+    ptrdiff_t dst_stride, const int width, const int height,
+    const bool frame_top_border, const bool frame_bottom_border);
 
-template <typename Pixel>
-void CopyRows(const Pixel* src, const ptrdiff_t src_stride,
-              const int block_width, const int unit_width,
-              const bool is_frame_top, const bool is_frame_bottom,
-              const bool is_frame_left, const bool is_frame_right,
-              const bool copy_top, const int num_rows, uint16_t* dst,
-              const ptrdiff_t dst_stride) {
-  if (is_frame_top || is_frame_bottom) {
-    if (is_frame_bottom) dst -= kCdefBorder;
-    for (int y = 0; y < num_rows; ++y) {
-      Memset(dst, PostFilter::kCdefLargeValue, unit_width + 2 * kCdefBorder);
-      dst += dst_stride;
-    }
-  } else {
-    if (copy_top) {
-      src -= kCdefBorder * src_stride;
-      dst += kCdefBorder;
-    }
-    for (int y = 0; y < num_rows; ++y) {
-      for (int x = -kCdefBorder; x < 0; ++x) {
-        dst[x] = is_frame_left ? PostFilter::kCdefLargeValue : src[x];
-      }
-      for (int x = 0; x < block_width; ++x) {
-        dst[x] = src[x];
-      }
-      for (int x = block_width; x < unit_width + kCdefBorder; ++x) {
-        dst[x] = is_frame_right ? PostFilter::kCdefLargeValue : src[x];
-      }
-      dst += dst_stride;
-      src += src_stride;
-    }
-  }
-}
-
-// This function prepares the input source block for cdef filtering.
-// The input source block contains a 12x12 block, with the inner 8x8 as the
-// desired filter region.
-// It pads the block if the 12x12 block includes out of frame pixels with
-// a large value.
-// This achieves the required behavior defined in section 5.11.52 of the spec.
-template <typename Pixel>
-void PrepareCdefBlock(const YuvBuffer* const source_buffer, const int planes,
-                      const int subsampling_x, const int subsampling_y,
-                      const int frame_width, const int frame_height,
-                      const int block_width4x4, const int block_height4x4,
-                      const int row_64x64, const int column_64x64,
-                      uint16_t* cdef_source, const ptrdiff_t cdef_stride) {
-  for (int plane = kPlaneY; plane < planes; ++plane) {
-    uint16_t* cdef_src =
-        cdef_source + plane * kRestorationProcessingUnitSizeWithBorders *
-                          kRestorationProcessingUnitSizeWithBorders;
-    const int plane_subsampling_x = (plane == kPlaneY) ? 0 : subsampling_x;
-    const int plane_subsampling_y = (plane == kPlaneY) ? 0 : subsampling_y;
-    const int start_x = MultiplyBy4(column_64x64) >> plane_subsampling_x;
-    const int start_y = MultiplyBy4(row_64x64) >> plane_subsampling_y;
-    const int plane_width =
-        RightShiftWithRounding(frame_width, plane_subsampling_x);
-    const int plane_height =
-        RightShiftWithRounding(frame_height, plane_subsampling_y);
-    const int block_width = MultiplyBy4(block_width4x4) >> plane_subsampling_x;
-    const int block_height =
-        MultiplyBy4(block_height4x4) >> plane_subsampling_y;
-    // unit_width, unit_height are the same as block_width, block_height unless
-    // it reaches the frame boundary, where block_width < 64 or
-    // block_height < 64. unit_width, unit_height guarantee we build blocks on
-    // a multiple of 8.
-    const int unit_width =
-        Align(block_width, (plane_subsampling_x > 0) ? 4 : 8);
-    const int unit_height =
-        Align(block_height, (plane_subsampling_y > 0) ? 4 : 8);
-    const bool is_frame_left = column_64x64 == 0;
-    const bool is_frame_right = start_x + block_width >= plane_width;
-    const bool is_frame_top = row_64x64 == 0;
-    const bool is_frame_bottom = start_y + block_height >= plane_height;
-    const int src_stride = source_buffer->stride(plane) / sizeof(Pixel);
-    const Pixel* src_buffer =
-        reinterpret_cast<const Pixel*>(source_buffer->data(plane)) +
-        start_y * src_stride + start_x;
-    // Copy to the top 2 rows.
-    CopyRows(src_buffer, src_stride, block_width, unit_width, is_frame_top,
-             false, is_frame_left, is_frame_right, true, kCdefBorder, cdef_src,
-             cdef_stride);
-    cdef_src += kCdefBorder * cdef_stride + kCdefBorder;
-
-    // Copy the body.
-    CopyRows(src_buffer, src_stride, block_width, unit_width, false, false,
-             is_frame_left, is_frame_right, false, block_height, cdef_src,
-             cdef_stride);
-    src_buffer += block_height * src_stride;
-    cdef_src += block_height * cdef_stride;
-
-    // Copy to bottom rows.
-    CopyRows(src_buffer, src_stride, block_width, unit_width, false,
-             is_frame_bottom, is_frame_left, is_frame_right, false,
-             kCdefBorder + unit_height - block_height, cdef_src, cdef_stride);
-  }
-}
+#if LIBGAV1_MAX_BITDEPTH >= 10
+extern template void PostFilter::ExtendFrame<uint16_t>(uint16_t* frame_start,
+                                                       int width, int height,
+                                                       ptrdiff_t stride,
+                                                       int left, int right,
+                                                       int top, int bottom);
+extern template void PostFilter::PrepareLoopRestorationBlock<uint16_t>(
+    const uint16_t* src_buffer, ptrdiff_t src_stride,
+    const uint16_t* deblock_buffer, ptrdiff_t deblock_stride, uint16_t* dst,
+    ptrdiff_t dst_stride, const int width, const int height,
+    const bool frame_top_border, const bool frame_bottom_border);
+#endif
 
 }  // namespace libgav1
 
diff --git a/libgav1/src/post_filter/cdef.cc b/libgav1/src/post_filter/cdef.cc
new file mode 100644
index 0000000..9b6bb00
--- /dev/null
+++ b/libgav1/src/post_filter/cdef.cc
@@ -0,0 +1,571 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "src/post_filter.h"
+#include "src/utils/blocking_counter.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace {
+
+constexpr int kStep64x64 = 16;  // =64/4.
+constexpr int kCdefSkip = 8;
+
+constexpr uint8_t kCdefUvDirection[2][2][8] = {
+    {{0, 1, 2, 3, 4, 5, 6, 7}, {1, 2, 2, 2, 3, 4, 6, 0}},
+    {{7, 0, 2, 4, 5, 6, 6, 6}, {0, 1, 2, 3, 4, 5, 6, 7}}};
+
+template <typename Pixel>
+void CopyRowForCdef(const Pixel* src, int block_width, int unit_width,
+                    bool is_frame_left, bool is_frame_right,
+                    uint16_t* const dst) {
+  if (sizeof(src[0]) == sizeof(dst[0])) {
+    if (is_frame_left) {
+      Memset(dst - kCdefBorder, kCdefLargeValue, kCdefBorder);
+    } else {
+      memcpy(dst - kCdefBorder, src - kCdefBorder,
+             kCdefBorder * sizeof(dst[0]));
+    }
+    memcpy(dst, src, block_width * sizeof(dst[0]));
+    if (is_frame_right) {
+      Memset(dst + block_width, kCdefLargeValue,
+             unit_width + kCdefBorder - block_width);
+    } else {
+      memcpy(dst + block_width, src + block_width,
+             (unit_width + kCdefBorder - block_width) * sizeof(dst[0]));
+    }
+    return;
+  }
+  for (int x = -kCdefBorder; x < 0; ++x) {
+    dst[x] = is_frame_left ? static_cast<uint16_t>(kCdefLargeValue) : src[x];
+  }
+  for (int x = 0; x < block_width; ++x) {
+    dst[x] = src[x];
+  }
+  for (int x = block_width; x < unit_width + kCdefBorder; ++x) {
+    dst[x] = is_frame_right ? static_cast<uint16_t>(kCdefLargeValue) : src[x];
+  }
+}
+
+// For |height| rows, copy |width| pixels of size |pixel_size| from |src| to
+// |dst|.
+void CopyPixels(const uint8_t* src, int src_stride, uint8_t* dst,
+                int dst_stride, int width, int height, size_t pixel_size) {
+  int y = height;
+  do {
+    memcpy(dst, src, width * pixel_size);
+    src += src_stride;
+    dst += dst_stride;
+  } while (--y != 0);
+}
+
+}  // namespace
+
+uint8_t* PostFilter::GetCdefBufferAndStride(const int start_x,
+                                            const int start_y, const int plane,
+                                            const int window_buffer_plane_size,
+                                            int* cdef_stride) const {
+  if (thread_pool_ != nullptr) {
+    // write output to threaded_window_buffer.
+    *cdef_stride = window_buffer_width_ * pixel_size_;
+    const int column_window =
+        start_x % (window_buffer_width_ >> subsampling_x_[plane]);
+    const int row_window =
+        start_y % (window_buffer_height_ >> subsampling_y_[plane]);
+    return threaded_window_buffer_ + plane * window_buffer_plane_size +
+           row_window * (*cdef_stride) + column_window * pixel_size_;
+  }
+  // write output to |cdef_buffer_|.
+  *cdef_stride = frame_buffer_.stride(plane);
+  return cdef_buffer_[plane] + start_y * (*cdef_stride) + start_x * pixel_size_;
+}
+
+template <typename Pixel>
+void PostFilter::PrepareCdefBlock(int block_width4x4, int block_height4x4,
+                                  int row4x4, int column4x4,
+                                  uint16_t* cdef_source, ptrdiff_t cdef_stride,
+                                  const bool y_plane) {
+  assert(y_plane || planes_ == kMaxPlanes);
+  const int max_planes = y_plane ? 1 : kMaxPlanes;
+  const int8_t subsampling_x = y_plane ? 0 : subsampling_x_[kPlaneU];
+  const int8_t subsampling_y = y_plane ? 0 : subsampling_y_[kPlaneU];
+  const int start_x = MultiplyBy4(column4x4) >> subsampling_x;
+  const int start_y = MultiplyBy4(row4x4) >> subsampling_y;
+  const int plane_width = RightShiftWithRounding(width_, subsampling_x);
+  const int plane_height = RightShiftWithRounding(height_, subsampling_y);
+  const int block_width = MultiplyBy4(block_width4x4) >> subsampling_x;
+  const int block_height = MultiplyBy4(block_height4x4) >> subsampling_y;
+  // unit_width, unit_height are the same as block_width, block_height unless
+  // it reaches the frame boundary, where block_width < 64 or
+  // block_height < 64. unit_width, unit_height guarantee we build blocks on
+  // a multiple of 8.
+  const int unit_width = Align(block_width, 8 >> subsampling_x);
+  const int unit_height = Align(block_height, 8 >> subsampling_y);
+  const bool is_frame_left = column4x4 == 0;
+  const bool is_frame_right = start_x + block_width >= plane_width;
+  const bool is_frame_top = row4x4 == 0;
+  const bool is_frame_bottom = start_y + block_height >= plane_height;
+  const int y_offset = is_frame_top ? 0 : kCdefBorder;
+
+  for (int plane = y_plane ? kPlaneY : kPlaneU; plane < max_planes; ++plane) {
+    uint16_t* cdef_src = cdef_source + plane * kCdefUnitSizeWithBorders *
+                                           kCdefUnitSizeWithBorders;
+    const int src_stride = frame_buffer_.stride(plane) / sizeof(Pixel);
+    const Pixel* src_buffer =
+        reinterpret_cast<const Pixel*>(source_buffer_[plane]) +
+        (start_y - y_offset) * src_stride + start_x;
+
+    // All the copying code will use negative indices for populating the left
+    // border. So the starting point is set to kCdefBorder.
+    cdef_src += kCdefBorder;
+
+    // Copy the top 2 rows.
+    if (is_frame_top) {
+      for (int y = 0; y < kCdefBorder; ++y) {
+        Memset(cdef_src - kCdefBorder, kCdefLargeValue,
+               unit_width + 2 * kCdefBorder);
+        cdef_src += cdef_stride;
+      }
+    } else {
+      for (int y = 0; y < kCdefBorder; ++y) {
+        CopyRowForCdef(src_buffer, block_width, unit_width, is_frame_left,
+                       is_frame_right, cdef_src);
+        src_buffer += src_stride;
+        cdef_src += cdef_stride;
+      }
+    }
+
+    // Copy the body.
+    int y = block_height;
+    do {
+      CopyRowForCdef(src_buffer, block_width, unit_width, is_frame_left,
+                     is_frame_right, cdef_src);
+      cdef_src += cdef_stride;
+      src_buffer += src_stride;
+    } while (--y != 0);
+
+    // Copy the bottom 2 rows.
+    if (is_frame_bottom) {
+      do {
+        Memset(cdef_src - kCdefBorder, kCdefLargeValue,
+               unit_width + 2 * kCdefBorder);
+        cdef_src += cdef_stride;
+      } while (++y < kCdefBorder + unit_height - block_height);
+    } else {
+      do {
+        CopyRowForCdef(src_buffer, block_width, unit_width, is_frame_left,
+                       is_frame_right, cdef_src);
+        src_buffer += src_stride;
+        cdef_src += cdef_stride;
+      } while (++y < kCdefBorder + unit_height - block_height);
+    }
+  }
+}
+
+template <typename Pixel>
+void PostFilter::ApplyCdefForOneUnit(uint16_t* cdef_block, const int index,
+                                     const int block_width4x4,
+                                     const int block_height4x4,
+                                     const int row4x4_start,
+                                     const int column4x4_start) {
+  // Cdef operates in 8x8 blocks (4x4 for chroma with subsampling).
+  static constexpr int kStep = 8;
+  static constexpr int kStep4x4 = 2;
+
+  const int window_buffer_plane_size =
+      window_buffer_width_ * window_buffer_height_ * sizeof(Pixel);
+  int cdef_buffer_row_base_stride[kMaxPlanes];
+  int cdef_buffer_stride[kMaxPlanes];
+  uint8_t* cdef_buffer_row_base[kMaxPlanes];
+  int src_buffer_row_base_stride[kMaxPlanes];
+  const uint8_t* src_buffer_row_base[kMaxPlanes];
+  int column_step[kMaxPlanes];
+  assert(planes_ >= 1);
+  for (int plane = kPlaneY; plane < planes_; ++plane) {
+    const int start_y = MultiplyBy4(row4x4_start) >> subsampling_y_[plane];
+    const int start_x = MultiplyBy4(column4x4_start) >> subsampling_x_[plane];
+    cdef_buffer_row_base[plane] = GetCdefBufferAndStride(
+        start_x, start_y, plane, window_buffer_plane_size,
+        &cdef_buffer_stride[plane]);
+    cdef_buffer_row_base_stride[plane] =
+        cdef_buffer_stride[plane] * (kStep >> subsampling_y_[plane]);
+    src_buffer_row_base[plane] = source_buffer_[plane] +
+                                 start_y * frame_buffer_.stride(plane) +
+                                 start_x * sizeof(Pixel);
+    src_buffer_row_base_stride[plane] =
+        frame_buffer_.stride(plane) * (kStep >> subsampling_y_[plane]);
+    column_step[plane] = (kStep >> subsampling_x_[plane]) * sizeof(Pixel);
+  }
+
+  if (index == -1) {
+    for (int plane = kPlaneY; plane < planes_; ++plane) {
+      CopyPixels(src_buffer_row_base[plane], frame_buffer_.stride(plane),
+                 cdef_buffer_row_base[plane], cdef_buffer_stride[plane],
+                 MultiplyBy4(block_width4x4) >> subsampling_x_[plane],
+                 MultiplyBy4(block_height4x4) >> subsampling_y_[plane],
+                 sizeof(Pixel));
+    }
+    return;
+  }
+
+  PrepareCdefBlock<Pixel>(block_width4x4, block_height4x4, row4x4_start,
+                          column4x4_start, cdef_block, kCdefUnitSizeWithBorders,
+                          true);
+
+  // Stored direction used during the u/v pass.  If bit 3 is set, then block is
+  // a skip.
+  int direction_y[8 * 8];
+  int y_index = 0;
+
+  const uint8_t y_primary_strength =
+      frame_header_.cdef.y_primary_strength[index];
+  const uint8_t y_secondary_strength =
+      frame_header_.cdef.y_secondary_strength[index];
+  // y_strength_index is 0 for both primary and secondary strengths being
+  // non-zero, 1 for primary only, 2 for secondary only. This will be updated
+  // with y_primary_strength after variance is applied.
+  int y_strength_index = static_cast<int>(y_secondary_strength == 0);
+
+  const bool compute_direction_and_variance =
+      (y_primary_strength | frame_header_.cdef.uv_primary_strength[index]) != 0;
+  BlockParameters* const* bp_row0_base =
+      block_parameters_.Address(row4x4_start, column4x4_start);
+  BlockParameters* const* bp_row1_base =
+      bp_row0_base + block_parameters_.columns4x4();
+  const int bp_stride = MultiplyBy2(block_parameters_.columns4x4());
+  int row4x4 = row4x4_start;
+  do {
+    uint8_t* cdef_buffer_base = cdef_buffer_row_base[kPlaneY];
+    const uint8_t* src_buffer_base = src_buffer_row_base[kPlaneY];
+    BlockParameters* const* bp0 = bp_row0_base;
+    BlockParameters* const* bp1 = bp_row1_base;
+    int column4x4 = column4x4_start;
+    do {
+      const int block_width = kStep;
+      const int block_height = kStep;
+      const int cdef_stride = cdef_buffer_stride[kPlaneY];
+      uint8_t* const cdef_buffer = cdef_buffer_base;
+      const int src_stride = frame_buffer_.stride(kPlaneY);
+      const uint8_t* const src_buffer = src_buffer_base;
+
+      const bool skip = (*bp0)->skip && (*(bp0 + 1))->skip && (*bp1)->skip &&
+                        (*(bp1 + 1))->skip;
+
+      if (skip) {  // No cdef filtering.
+        direction_y[y_index] = kCdefSkip;
+        CopyPixels(src_buffer, src_stride, cdef_buffer, cdef_stride,
+                   block_width, block_height, sizeof(Pixel));
+      } else {
+        // Zero out residual skip flag.
+        direction_y[y_index] = 0;
+
+        int variance = 0;
+        if (compute_direction_and_variance) {
+          dsp_.cdef_direction(src_buffer, src_stride, &direction_y[y_index],
+                              &variance);
+        }
+        const int direction =
+            (y_primary_strength == 0) ? 0 : direction_y[y_index];
+        const int variance_strength =
+            ((variance >> 6) != 0) ? std::min(FloorLog2(variance >> 6), 12) : 0;
+        const uint8_t primary_strength =
+            (variance != 0)
+                ? (y_primary_strength * (4 + variance_strength) + 8) >> 4
+                : 0;
+
+        if ((primary_strength | y_secondary_strength) == 0) {
+          CopyPixels(src_buffer, src_stride, cdef_buffer, cdef_stride,
+                     block_width, block_height, sizeof(Pixel));
+        } else {
+          uint16_t* cdef_src =
+              cdef_block + kCdefBorder * kCdefUnitSizeWithBorders + kCdefBorder;
+          cdef_src +=
+              (MultiplyBy4(row4x4 - row4x4_start)) * kCdefUnitSizeWithBorders +
+              (MultiplyBy4(column4x4 - column4x4_start));
+          const int strength_index =
+              y_strength_index | (static_cast<int>(primary_strength == 0) << 1);
+          dsp_.cdef_filters[1][strength_index](
+              cdef_src, kCdefUnitSizeWithBorders, block_height,
+              primary_strength, y_secondary_strength,
+              frame_header_.cdef.damping, direction, cdef_buffer, cdef_stride);
+        }
+      }
+      cdef_buffer_base += column_step[kPlaneY];
+      src_buffer_base += column_step[kPlaneY];
+
+      bp0 += kStep4x4;
+      bp1 += kStep4x4;
+      column4x4 += kStep4x4;
+      y_index++;
+    } while (column4x4 < column4x4_start + block_width4x4);
+
+    cdef_buffer_row_base[kPlaneY] += cdef_buffer_row_base_stride[kPlaneY];
+    src_buffer_row_base[kPlaneY] += src_buffer_row_base_stride[kPlaneY];
+    bp_row0_base += bp_stride;
+    bp_row1_base += bp_stride;
+    row4x4 += kStep4x4;
+  } while (row4x4 < row4x4_start + block_height4x4);
+
+  if (planes_ == kMaxPlanesMonochrome) {
+    return;
+  }
+
+  const uint8_t uv_primary_strength =
+      frame_header_.cdef.uv_primary_strength[index];
+  const uint8_t uv_secondary_strength =
+      frame_header_.cdef.uv_secondary_strength[index];
+
+  if ((uv_primary_strength | uv_secondary_strength) == 0) {
+    for (int plane = kPlaneU; plane <= kPlaneV; ++plane) {
+      CopyPixels(src_buffer_row_base[plane], frame_buffer_.stride(plane),
+                 cdef_buffer_row_base[plane], cdef_buffer_stride[plane],
+                 MultiplyBy4(block_width4x4) >> subsampling_x_[plane],
+                 MultiplyBy4(block_height4x4) >> subsampling_y_[plane],
+                 sizeof(Pixel));
+    }
+    return;
+  }
+
+  PrepareCdefBlock<Pixel>(block_width4x4, block_height4x4, row4x4_start,
+                          column4x4_start, cdef_block, kCdefUnitSizeWithBorders,
+                          false);
+
+  // uv_strength_index is 0 for both primary and secondary strengths being
+  // non-zero, 1 for primary only, 2 for secondary only.
+  const int uv_strength_index =
+      (static_cast<int>(uv_primary_strength == 0) << 1) |
+      static_cast<int>(uv_secondary_strength == 0);
+  for (int plane = kPlaneU; plane <= kPlaneV; ++plane) {
+    const int8_t subsampling_x = subsampling_x_[plane];
+    const int8_t subsampling_y = subsampling_y_[plane];
+    const int block_width = kStep >> subsampling_x;
+    const int block_height = kStep >> subsampling_y;
+    int row4x4 = row4x4_start;
+
+    y_index = 0;
+    do {
+      uint8_t* cdef_buffer_base = cdef_buffer_row_base[plane];
+      const uint8_t* src_buffer_base = src_buffer_row_base[plane];
+      int column4x4 = column4x4_start;
+      do {
+        const int cdef_stride = cdef_buffer_stride[plane];
+        uint8_t* const cdef_buffer = cdef_buffer_base;
+        const int src_stride = frame_buffer_.stride(plane);
+        const uint8_t* const src_buffer = src_buffer_base;
+        const bool skip = (direction_y[y_index] & kCdefSkip) != 0;
+        int dual_cdef = 0;
+
+        if (skip) {  // No cdef filtering.
+          CopyPixels(src_buffer, src_stride, cdef_buffer, cdef_stride,
+                     block_width, block_height, sizeof(Pixel));
+        } else {
+          // Make sure block pair is not out of bounds.
+          if (column4x4 + (kStep4x4 * 2) <= column4x4_start + block_width4x4) {
+            // Enable dual processing if subsampling_x is 1.
+            dual_cdef = subsampling_x;
+          }
+
+          int direction = (uv_primary_strength == 0)
+                              ? 0
+                              : kCdefUvDirection[subsampling_x][subsampling_y]
+                                                [direction_y[y_index]];
+
+          if (dual_cdef != 0) {
+            if (uv_primary_strength &&
+                direction_y[y_index] != direction_y[y_index + 1]) {
+              // Disable dual processing if the second block of the pair does
+              // not have the same direction.
+              dual_cdef = 0;
+            }
+
+            // Disable dual processing if the second block of the pair is a
+            // skip.
+            if (direction_y[y_index + 1] == kCdefSkip) {
+              dual_cdef = 0;
+            }
+          }
+
+          uint16_t* cdef_src = cdef_block + plane * kCdefUnitSizeWithBorders *
+                                                kCdefUnitSizeWithBorders;
+          cdef_src += kCdefBorder * kCdefUnitSizeWithBorders + kCdefBorder;
+          cdef_src +=
+              (MultiplyBy4(row4x4 - row4x4_start) >> subsampling_y) *
+                  kCdefUnitSizeWithBorders +
+              (MultiplyBy4(column4x4 - column4x4_start) >> subsampling_x);
+          // Block width is 8 if either dual_cdef is true or subsampling_x == 0.
+          const int width_index = dual_cdef | (subsampling_x ^ 1);
+          dsp_.cdef_filters[width_index][uv_strength_index](
+              cdef_src, kCdefUnitSizeWithBorders, block_height,
+              uv_primary_strength, uv_secondary_strength,
+              frame_header_.cdef.damping - 1, direction, cdef_buffer,
+              cdef_stride);
+        }
+        // When dual_cdef is set, the above cdef_filter() will process 2 blocks,
+        // so adjust the pointers and indexes for 2 blocks.
+        cdef_buffer_base += column_step[plane] << dual_cdef;
+        src_buffer_base += column_step[plane] << dual_cdef;
+        column4x4 += kStep4x4 << dual_cdef;
+        y_index += 1 << dual_cdef;
+      } while (column4x4 < column4x4_start + block_width4x4);
+
+      cdef_buffer_row_base[plane] += cdef_buffer_row_base_stride[plane];
+      src_buffer_row_base[plane] += src_buffer_row_base_stride[plane];
+      row4x4 += kStep4x4;
+    } while (row4x4 < row4x4_start + block_height4x4);
+  }
+}
+
+void PostFilter::ApplyCdefForOneSuperBlockRowHelper(int row4x4,
+                                                    int block_height4x4) {
+  for (int column4x4 = 0; column4x4 < frame_header_.columns4x4;
+       column4x4 += kStep64x64) {
+    const int index = cdef_index_[DivideBy16(row4x4)][DivideBy16(column4x4)];
+    const int block_width4x4 =
+        std::min(kStep64x64, frame_header_.columns4x4 - column4x4);
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    if (bitdepth_ >= 10) {
+      ApplyCdefForOneUnit<uint16_t>(cdef_block_, index, block_width4x4,
+                                    block_height4x4, row4x4, column4x4);
+      continue;
+    }
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+    ApplyCdefForOneUnit<uint8_t>(cdef_block_, index, block_width4x4,
+                                 block_height4x4, row4x4, column4x4);
+  }
+}
+
+void PostFilter::ApplyCdefForOneSuperBlockRow(int row4x4_start, int sb4x4,
+                                              bool is_last_row) {
+  assert(row4x4_start >= 0);
+  assert(DoCdef());
+  for (int y = 0; y < sb4x4; y += kStep64x64) {
+    const int row4x4 = row4x4_start + y;
+    if (row4x4 >= frame_header_.rows4x4) return;
+
+    // Apply cdef for the last 8 rows of the previous superblock row.
+    // One exception: If the superblock size is 128x128 and is_last_row is true,
+    // then we simply apply cdef for the entire superblock row without any lag.
+    // In that case, apply cdef for the previous superblock row only during the
+    // first iteration (y == 0).
+    if (row4x4 > 0 && (!is_last_row || y == 0)) {
+      assert(row4x4 >= 16);
+      ApplyCdefForOneSuperBlockRowHelper(row4x4 - 2, 2);
+    }
+
+    // Apply cdef for the current superblock row. If this is the last superblock
+    // row we apply cdef for all the rows, otherwise we leave out the last 8
+    // rows.
+    const int block_height4x4 =
+        std::min(kStep64x64, frame_header_.rows4x4 - row4x4);
+    const int height4x4 = block_height4x4 - (is_last_row ? 0 : 2);
+    if (height4x4 > 0) {
+      ApplyCdefForOneSuperBlockRowHelper(row4x4, height4x4);
+    }
+  }
+}
+
+template <typename Pixel>
+void PostFilter::ApplyCdefForOneRowInWindow(const int row4x4,
+                                            const int column4x4_start) {
+  uint16_t cdef_block[kCdefUnitSizeWithBorders * kCdefUnitSizeWithBorders * 3];
+
+  for (int column4x4_64x64 = 0;
+       column4x4_64x64 < std::min(DivideBy4(window_buffer_width_),
+                                  frame_header_.columns4x4 - column4x4_start);
+       column4x4_64x64 += kStep64x64) {
+    const int column4x4 = column4x4_start + column4x4_64x64;
+    const int index = cdef_index_[DivideBy16(row4x4)][DivideBy16(column4x4)];
+    const int block_width4x4 =
+        std::min(kStep64x64, frame_header_.columns4x4 - column4x4);
+    const int block_height4x4 =
+        std::min(kStep64x64, frame_header_.rows4x4 - row4x4);
+
+    ApplyCdefForOneUnit<Pixel>(cdef_block, index, block_width4x4,
+                               block_height4x4, row4x4, column4x4);
+  }
+}
+
+// Each thread processes one row inside the window.
+// Y, U, V planes are processed together inside one thread.
+template <typename Pixel>
+void PostFilter::ApplyCdefThreaded() {
+  assert((window_buffer_height_ & 63) == 0);
+  const int num_workers = thread_pool_->num_threads();
+  const int window_buffer_plane_size =
+      window_buffer_width_ * window_buffer_height_;
+  const int window_buffer_height4x4 = DivideBy4(window_buffer_height_);
+  for (int row4x4 = 0; row4x4 < frame_header_.rows4x4;
+       row4x4 += window_buffer_height4x4) {
+    const int actual_window_height4x4 =
+        std::min(window_buffer_height4x4, frame_header_.rows4x4 - row4x4);
+    const int vertical_units_per_window =
+        DivideBy16(actual_window_height4x4 + 15);
+    for (int column4x4 = 0; column4x4 < frame_header_.columns4x4;
+         column4x4 += DivideBy4(window_buffer_width_)) {
+      const int jobs_for_threadpool =
+          vertical_units_per_window * num_workers / (num_workers + 1);
+      BlockingCounter pending_jobs(jobs_for_threadpool);
+      int job_count = 0;
+      for (int row64x64 = 0; row64x64 < actual_window_height4x4;
+           row64x64 += kStep64x64) {
+        if (job_count < jobs_for_threadpool) {
+          thread_pool_->Schedule(
+              [this, row4x4, column4x4, row64x64, &pending_jobs]() {
+                ApplyCdefForOneRowInWindow<Pixel>(row4x4 + row64x64, column4x4);
+                pending_jobs.Decrement();
+              });
+        } else {
+          ApplyCdefForOneRowInWindow<Pixel>(row4x4 + row64x64, column4x4);
+        }
+        ++job_count;
+      }
+      pending_jobs.Wait();
+
+      // Copy |threaded_window_buffer_| to |cdef_buffer_|.
+      for (int plane = kPlaneY; plane < planes_; ++plane) {
+        const ptrdiff_t src_stride =
+            frame_buffer_.stride(plane) / sizeof(Pixel);
+        const int plane_row = MultiplyBy4(row4x4) >> subsampling_y_[plane];
+        const int plane_column =
+            MultiplyBy4(column4x4) >> subsampling_x_[plane];
+        int copy_width = std::min(frame_header_.columns4x4 - column4x4,
+                                  DivideBy4(window_buffer_width_));
+        copy_width = MultiplyBy4(copy_width) >> subsampling_x_[plane];
+        int copy_height =
+            std::min(frame_header_.rows4x4 - row4x4, window_buffer_height4x4);
+        copy_height = MultiplyBy4(copy_height) >> subsampling_y_[plane];
+        CopyPlane<Pixel>(
+            reinterpret_cast<const Pixel*>(threaded_window_buffer_) +
+                plane * window_buffer_plane_size,
+            window_buffer_width_, copy_width, copy_height,
+            reinterpret_cast<Pixel*>(cdef_buffer_[plane]) +
+                plane_row * src_stride + plane_column,
+            src_stride);
+      }
+    }
+  }
+}
+
+void PostFilter::ApplyCdef() {
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  if (bitdepth_ >= 10) {
+    ApplyCdefThreaded<uint16_t>();
+    return;
+  }
+#endif
+  ApplyCdefThreaded<uint8_t>();
+}
+
+}  // namespace libgav1
diff --git a/libgav1/src/post_filter/deblock.cc b/libgav1/src/post_filter/deblock.cc
new file mode 100644
index 0000000..c4e0852
--- /dev/null
+++ b/libgav1/src/post_filter/deblock.cc
@@ -0,0 +1,567 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <atomic>
+
+#include "src/post_filter.h"
+#include "src/utils/blocking_counter.h"
+
+namespace libgav1 {
+namespace {
+
+constexpr uint8_t HevThresh(int level) { return DivideBy16(level); }
+
+// GetLoopFilterSize* functions depend on this exact ordering of the
+// LoopFilterSize enums.
+static_assert(dsp::kLoopFilterSize4 == 0, "");
+static_assert(dsp::kLoopFilterSize6 == 1, "");
+static_assert(dsp::kLoopFilterSize8 == 2, "");
+static_assert(dsp::kLoopFilterSize14 == 3, "");
+
+dsp::LoopFilterSize GetLoopFilterSizeY(int filter_length) {
+  // |filter_length| must be a power of 2.
+  assert((filter_length & (filter_length - 1)) == 0);
+  // This code is the branch free equivalent of:
+  //   if (filter_length == 4) return kLoopFilterSize4;
+  //   if (filter_length == 8) return kLoopFilterSize8;
+  //   return kLoopFilterSize14;
+  return static_cast<dsp::LoopFilterSize>(
+      MultiplyBy2(static_cast<int>(filter_length > 4)) +
+      static_cast<int>(filter_length > 8));
+}
+
+constexpr dsp::LoopFilterSize GetLoopFilterSizeUV(int filter_length) {
+  // For U & V planes, size is kLoopFilterSize4 if |filter_length| is 4,
+  // otherwise size is kLoopFilterSize6.
+  return static_cast<dsp::LoopFilterSize>(filter_length != 4);
+}
+
+bool NonBlockBorderNeedsFilter(const BlockParameters& bp, int filter_id,
+                               uint8_t* const level) {
+  if (bp.deblock_filter_level[filter_id] == 0 || (bp.skip && bp.is_inter)) {
+    return false;
+  }
+  *level = bp.deblock_filter_level[filter_id];
+  return true;
+}
+
+// 7.14.5.
+void ComputeDeblockFilterLevelsHelper(
+    const ObuFrameHeader& frame_header, int segment_id, int level_index,
+    const int8_t delta_lf[kFrameLfCount],
+    uint8_t deblock_filter_levels[kNumReferenceFrameTypes][2]) {
+  const int delta = delta_lf[frame_header.delta_lf.multi ? level_index : 0];
+  uint8_t level = Clip3(frame_header.loop_filter.level[level_index] + delta, 0,
+                        kMaxLoopFilterValue);
+  const auto feature = static_cast<SegmentFeature>(
+      kSegmentFeatureLoopFilterYVertical + level_index);
+  level =
+      Clip3(level + frame_header.segmentation.feature_data[segment_id][feature],
+            0, kMaxLoopFilterValue);
+  if (!frame_header.loop_filter.delta_enabled) {
+    static_assert(sizeof(deblock_filter_levels[0][0]) == 1, "");
+    memset(deblock_filter_levels, level, kNumReferenceFrameTypes * 2);
+    return;
+  }
+  assert(frame_header.loop_filter.delta_enabled);
+  const int shift = level >> 5;
+  deblock_filter_levels[kReferenceFrameIntra][0] = Clip3(
+      level +
+          LeftShift(frame_header.loop_filter.ref_deltas[kReferenceFrameIntra],
+                    shift),
+      0, kMaxLoopFilterValue);
+  // deblock_filter_levels[kReferenceFrameIntra][1] is never used. So it does
+  // not have to be populated.
+  for (int reference_frame = kReferenceFrameIntra + 1;
+       reference_frame < kNumReferenceFrameTypes; ++reference_frame) {
+    for (int mode_id = 0; mode_id < 2; ++mode_id) {
+      deblock_filter_levels[reference_frame][mode_id] = Clip3(
+          level +
+              LeftShift(frame_header.loop_filter.ref_deltas[reference_frame] +
+                            frame_header.loop_filter.mode_deltas[mode_id],
+                        shift),
+          0, kMaxLoopFilterValue);
+    }
+  }
+}
+
+}  // namespace
+
+void PostFilter::ComputeDeblockFilterLevels(
+    const int8_t delta_lf[kFrameLfCount],
+    uint8_t deblock_filter_levels[kMaxSegments][kFrameLfCount]
+                                 [kNumReferenceFrameTypes][2]) const {
+  if (!DoDeblock()) return;
+  for (int segment_id = 0;
+       segment_id < (frame_header_.segmentation.enabled ? kMaxSegments : 1);
+       ++segment_id) {
+    int level_index = 0;
+    for (; level_index < 2; ++level_index) {
+      ComputeDeblockFilterLevelsHelper(
+          frame_header_, segment_id, level_index, delta_lf,
+          deblock_filter_levels[segment_id][level_index]);
+    }
+    for (; level_index < kFrameLfCount; ++level_index) {
+      if (frame_header_.loop_filter.level[level_index] != 0) {
+        ComputeDeblockFilterLevelsHelper(
+            frame_header_, segment_id, level_index, delta_lf,
+            deblock_filter_levels[segment_id][level_index]);
+      }
+    }
+  }
+}
+
+bool PostFilter::GetHorizontalDeblockFilterEdgeInfo(int row4x4, int column4x4,
+                                                    uint8_t* level, int* step,
+                                                    int* filter_length) const {
+  *step = kTransformHeight[inter_transform_sizes_[row4x4][column4x4]];
+  if (row4x4 == 0) return false;
+
+  const BlockParameters* bp = block_parameters_.Find(row4x4, column4x4);
+  const int row4x4_prev = row4x4 - 1;
+  assert(row4x4_prev >= 0);
+  const BlockParameters* bp_prev =
+      block_parameters_.Find(row4x4_prev, column4x4);
+
+  if (bp == bp_prev) {
+    // Not a border.
+    if (!NonBlockBorderNeedsFilter(*bp, 1, level)) return false;
+  } else {
+    const uint8_t level_this = bp->deblock_filter_level[1];
+    *level = level_this;
+    if (level_this == 0) {
+      const uint8_t level_prev = bp_prev->deblock_filter_level[1];
+      if (level_prev == 0) return false;
+      *level = level_prev;
+    }
+  }
+  const int step_prev =
+      kTransformHeight[inter_transform_sizes_[row4x4_prev][column4x4]];
+  *filter_length = std::min(*step, step_prev);
+  return true;
+}
+
+void PostFilter::GetHorizontalDeblockFilterEdgeInfoUV(
+    int row4x4, int column4x4, uint8_t* level_u, uint8_t* level_v, int* step,
+    int* filter_length) const {
+  const int subsampling_x = subsampling_x_[kPlaneU];
+  const int subsampling_y = subsampling_y_[kPlaneU];
+  row4x4 = GetDeblockPosition(row4x4, subsampling_y);
+  column4x4 = GetDeblockPosition(column4x4, subsampling_x);
+  const BlockParameters* bp = block_parameters_.Find(row4x4, column4x4);
+  *level_u = 0;
+  *level_v = 0;
+  *step = kTransformHeight[bp->uv_transform_size];
+  if (row4x4 == subsampling_y) {
+    return;
+  }
+
+  bool need_filter_u = frame_header_.loop_filter.level[kPlaneU + 1] != 0;
+  bool need_filter_v = frame_header_.loop_filter.level[kPlaneV + 1] != 0;
+  assert(need_filter_u || need_filter_v);
+  const int filter_id_u =
+      kDeblockFilterLevelIndex[kPlaneU][kLoopFilterTypeHorizontal];
+  const int filter_id_v =
+      kDeblockFilterLevelIndex[kPlaneV][kLoopFilterTypeHorizontal];
+  const int row4x4_prev = row4x4 - (1 << subsampling_y);
+  assert(row4x4_prev >= 0);
+  const BlockParameters* bp_prev =
+      block_parameters_.Find(row4x4_prev, column4x4);
+
+  if (bp == bp_prev) {
+    // Not a border.
+    const bool skip = bp->skip && bp->is_inter;
+    need_filter_u =
+        need_filter_u && bp->deblock_filter_level[filter_id_u] != 0 && !skip;
+    need_filter_v =
+        need_filter_v && bp->deblock_filter_level[filter_id_v] != 0 && !skip;
+    if (!need_filter_u && !need_filter_v) return;
+    if (need_filter_u) *level_u = bp->deblock_filter_level[filter_id_u];
+    if (need_filter_v) *level_v = bp->deblock_filter_level[filter_id_v];
+    *filter_length = *step;
+    return;
+  }
+
+  // It is a border.
+  if (need_filter_u) {
+    const uint8_t level_u_this = bp->deblock_filter_level[filter_id_u];
+    *level_u = level_u_this;
+    if (level_u_this == 0) {
+      *level_u = bp_prev->deblock_filter_level[filter_id_u];
+    }
+  }
+  if (need_filter_v) {
+    const uint8_t level_v_this = bp->deblock_filter_level[filter_id_v];
+    *level_v = level_v_this;
+    if (level_v_this == 0) {
+      *level_v = bp_prev->deblock_filter_level[filter_id_v];
+    }
+  }
+  const int step_prev = kTransformHeight[bp_prev->uv_transform_size];
+  *filter_length = std::min(*step, step_prev);
+}
+
+bool PostFilter::GetVerticalDeblockFilterEdgeInfo(
+    int row4x4, int column4x4, BlockParameters* const* bp_ptr, uint8_t* level,
+    int* step, int* filter_length) const {
+  const BlockParameters* bp = *bp_ptr;
+  *step = kTransformWidth[inter_transform_sizes_[row4x4][column4x4]];
+  if (column4x4 == 0) return false;
+
+  const int filter_id = 0;
+  const int column4x4_prev = column4x4 - 1;
+  assert(column4x4_prev >= 0);
+  const BlockParameters* bp_prev = *(bp_ptr - 1);
+  if (bp == bp_prev) {
+    // Not a border.
+    if (!NonBlockBorderNeedsFilter(*bp, filter_id, level)) return false;
+  } else {
+    // It is a border.
+    const uint8_t level_this = bp->deblock_filter_level[filter_id];
+    *level = level_this;
+    if (level_this == 0) {
+      const uint8_t level_prev = bp_prev->deblock_filter_level[filter_id];
+      if (level_prev == 0) return false;
+      *level = level_prev;
+    }
+  }
+  const int step_prev =
+      kTransformWidth[inter_transform_sizes_[row4x4][column4x4_prev]];
+  *filter_length = std::min(*step, step_prev);
+  return true;
+}
+
+void PostFilter::GetVerticalDeblockFilterEdgeInfoUV(
+    int column4x4, BlockParameters* const* bp_ptr, uint8_t* level_u,
+    uint8_t* level_v, int* step, int* filter_length) const {
+  const int subsampling_x = subsampling_x_[kPlaneU];
+  column4x4 = GetDeblockPosition(column4x4, subsampling_x);
+  const BlockParameters* bp = *bp_ptr;
+  *level_u = 0;
+  *level_v = 0;
+  *step = kTransformWidth[bp->uv_transform_size];
+  if (column4x4 == subsampling_x) {
+    return;
+  }
+
+  bool need_filter_u = frame_header_.loop_filter.level[kPlaneU + 1] != 0;
+  bool need_filter_v = frame_header_.loop_filter.level[kPlaneV + 1] != 0;
+  assert(need_filter_u || need_filter_v);
+  const int filter_id_u =
+      kDeblockFilterLevelIndex[kPlaneU][kLoopFilterTypeVertical];
+  const int filter_id_v =
+      kDeblockFilterLevelIndex[kPlaneV][kLoopFilterTypeVertical];
+  const BlockParameters* bp_prev = *(bp_ptr - (1 << subsampling_x));
+
+  if (bp == bp_prev) {
+    // Not a border.
+    const bool skip = bp->skip && bp->is_inter;
+    need_filter_u =
+        need_filter_u && bp->deblock_filter_level[filter_id_u] != 0 && !skip;
+    need_filter_v =
+        need_filter_v && bp->deblock_filter_level[filter_id_v] != 0 && !skip;
+    if (!need_filter_u && !need_filter_v) return;
+    if (need_filter_u) *level_u = bp->deblock_filter_level[filter_id_u];
+    if (need_filter_v) *level_v = bp->deblock_filter_level[filter_id_v];
+    *filter_length = *step;
+    return;
+  }
+
+  // It is a border.
+  if (need_filter_u) {
+    const uint8_t level_u_this = bp->deblock_filter_level[filter_id_u];
+    *level_u = level_u_this;
+    if (level_u_this == 0) {
+      *level_u = bp_prev->deblock_filter_level[filter_id_u];
+    }
+  }
+  if (need_filter_v) {
+    const uint8_t level_v_this = bp->deblock_filter_level[filter_id_v];
+    *level_v = level_v_this;
+    if (level_v_this == 0) {
+      *level_v = bp_prev->deblock_filter_level[filter_id_v];
+    }
+  }
+  const int step_prev = kTransformWidth[bp_prev->uv_transform_size];
+  *filter_length = std::min(*step, step_prev);
+}
+
+void PostFilter::HorizontalDeblockFilter(int row4x4_start,
+                                         int column4x4_start) {
+  const int column_step = 1;
+  const size_t src_step = MultiplyBy4(pixel_size_);
+  const ptrdiff_t src_stride = frame_buffer_.stride(kPlaneY);
+  uint8_t* src = GetSourceBuffer(kPlaneY, row4x4_start, column4x4_start);
+  int row_step;
+  uint8_t level;
+  int filter_length;
+
+  for (int column4x4 = 0; column4x4 < kNum4x4InLoopFilterUnit &&
+                          MultiplyBy4(column4x4_start + column4x4) < width_;
+       column4x4 += column_step, src += src_step) {
+    uint8_t* src_row = src;
+    for (int row4x4 = 0; row4x4 < kNum4x4InLoopFilterUnit &&
+                         MultiplyBy4(row4x4_start + row4x4) < height_;
+         row4x4 += row_step) {
+      const bool need_filter = GetHorizontalDeblockFilterEdgeInfo(
+          row4x4_start + row4x4, column4x4_start + column4x4, &level, &row_step,
+          &filter_length);
+      if (need_filter) {
+        const dsp::LoopFilterSize size = GetLoopFilterSizeY(filter_length);
+        dsp_.loop_filters[size][kLoopFilterTypeHorizontal](
+            src_row, src_stride, outer_thresh_[level], inner_thresh_[level],
+            HevThresh(level));
+      }
+      // TODO(chengchen): use shifts instead of multiplication.
+      src_row += row_step * src_stride;
+      row_step = DivideBy4(row_step);
+    }
+  }
+
+  if (needs_chroma_deblock_) {
+    const int8_t subsampling_x = subsampling_x_[kPlaneU];
+    const int8_t subsampling_y = subsampling_y_[kPlaneU];
+    const int column_step = 1 << subsampling_x;
+    const ptrdiff_t src_stride_u = frame_buffer_.stride(kPlaneU);
+    const ptrdiff_t src_stride_v = frame_buffer_.stride(kPlaneV);
+    uint8_t* src_u = GetSourceBuffer(kPlaneU, row4x4_start, column4x4_start);
+    uint8_t* src_v = GetSourceBuffer(kPlaneV, row4x4_start, column4x4_start);
+    int row_step;
+    uint8_t level_u;
+    uint8_t level_v;
+    int filter_length;
+
+    for (int column4x4 = 0; column4x4 < kNum4x4InLoopFilterUnit &&
+                            MultiplyBy4(column4x4_start + column4x4) < width_;
+         column4x4 += column_step, src_u += src_step, src_v += src_step) {
+      uint8_t* src_row_u = src_u;
+      uint8_t* src_row_v = src_v;
+      for (int row4x4 = 0; row4x4 < kNum4x4InLoopFilterUnit &&
+                           MultiplyBy4(row4x4_start + row4x4) < height_;
+           row4x4 += row_step) {
+        GetHorizontalDeblockFilterEdgeInfoUV(
+            row4x4_start + row4x4, column4x4_start + column4x4, &level_u,
+            &level_v, &row_step, &filter_length);
+        if (level_u != 0) {
+          const dsp::LoopFilterSize size = GetLoopFilterSizeUV(filter_length);
+          dsp_.loop_filters[size][kLoopFilterTypeHorizontal](
+              src_row_u, src_stride_u, outer_thresh_[level_u],
+              inner_thresh_[level_u], HevThresh(level_u));
+        }
+        if (level_v != 0) {
+          const dsp::LoopFilterSize size = GetLoopFilterSizeUV(filter_length);
+          dsp_.loop_filters[size][kLoopFilterTypeHorizontal](
+              src_row_v, src_stride_v, outer_thresh_[level_v],
+              inner_thresh_[level_v], HevThresh(level_v));
+        }
+        src_row_u += row_step * src_stride_u;
+        src_row_v += row_step * src_stride_v;
+        row_step = DivideBy4(row_step << subsampling_y);
+      }
+    }
+  }
+}
+
+void PostFilter::VerticalDeblockFilter(int row4x4_start, int column4x4_start) {
+  const ptrdiff_t row_stride = MultiplyBy4(frame_buffer_.stride(kPlaneY));
+  const ptrdiff_t src_stride = frame_buffer_.stride(kPlaneY);
+  uint8_t* src = GetSourceBuffer(kPlaneY, row4x4_start, column4x4_start);
+  int column_step;
+  uint8_t level;
+  int filter_length;
+
+  BlockParameters* const* bp_row_base =
+      block_parameters_.Address(row4x4_start, column4x4_start);
+  const int bp_stride = block_parameters_.columns4x4();
+  for (int row4x4 = 0; row4x4 < kNum4x4InLoopFilterUnit &&
+                       MultiplyBy4(row4x4_start + row4x4) < height_;
+       ++row4x4, src += row_stride, bp_row_base += bp_stride) {
+    uint8_t* src_row = src;
+    BlockParameters* const* bp = bp_row_base;
+    for (int column4x4 = 0; column4x4 < kNum4x4InLoopFilterUnit &&
+                            MultiplyBy4(column4x4_start + column4x4) < width_;
+         column4x4 += column_step, bp += column_step) {
+      const bool need_filter = GetVerticalDeblockFilterEdgeInfo(
+          row4x4_start + row4x4, column4x4_start + column4x4, bp, &level,
+          &column_step, &filter_length);
+      if (need_filter) {
+        const dsp::LoopFilterSize size = GetLoopFilterSizeY(filter_length);
+        dsp_.loop_filters[size][kLoopFilterTypeVertical](
+            src_row, src_stride, outer_thresh_[level], inner_thresh_[level],
+            HevThresh(level));
+      }
+      src_row += column_step * pixel_size_;
+      column_step = DivideBy4(column_step);
+    }
+  }
+
+  if (needs_chroma_deblock_) {
+    const int8_t subsampling_x = subsampling_x_[kPlaneU];
+    const int8_t subsampling_y = subsampling_y_[kPlaneU];
+    const int row_step = 1 << subsampling_y;
+    uint8_t* src_u = GetSourceBuffer(kPlaneU, row4x4_start, column4x4_start);
+    uint8_t* src_v = GetSourceBuffer(kPlaneV, row4x4_start, column4x4_start);
+    const ptrdiff_t src_stride_u = frame_buffer_.stride(kPlaneU);
+    const ptrdiff_t src_stride_v = frame_buffer_.stride(kPlaneV);
+    const ptrdiff_t row_stride_u = MultiplyBy4(frame_buffer_.stride(kPlaneU));
+    const ptrdiff_t row_stride_v = MultiplyBy4(frame_buffer_.stride(kPlaneV));
+    const LoopFilterType type = kLoopFilterTypeVertical;
+    int column_step;
+    uint8_t level_u;
+    uint8_t level_v;
+    int filter_length;
+
+    BlockParameters* const* bp_row_base = block_parameters_.Address(
+        GetDeblockPosition(row4x4_start, subsampling_y),
+        GetDeblockPosition(column4x4_start, subsampling_x));
+    const int bp_stride = block_parameters_.columns4x4() * row_step;
+    for (int row4x4 = 0; row4x4 < kNum4x4InLoopFilterUnit &&
+                         MultiplyBy4(row4x4_start + row4x4) < height_;
+         row4x4 += row_step, src_u += row_stride_u, src_v += row_stride_v,
+             bp_row_base += bp_stride) {
+      uint8_t* src_row_u = src_u;
+      uint8_t* src_row_v = src_v;
+      BlockParameters* const* bp = bp_row_base;
+      for (int column4x4 = 0; column4x4 < kNum4x4InLoopFilterUnit &&
+                              MultiplyBy4(column4x4_start + column4x4) < width_;
+           column4x4 += column_step, bp += column_step) {
+        GetVerticalDeblockFilterEdgeInfoUV(column4x4_start + column4x4, bp,
+                                           &level_u, &level_v, &column_step,
+                                           &filter_length);
+        if (level_u != 0) {
+          const dsp::LoopFilterSize size = GetLoopFilterSizeUV(filter_length);
+          dsp_.loop_filters[size][type](
+              src_row_u, src_stride_u, outer_thresh_[level_u],
+              inner_thresh_[level_u], HevThresh(level_u));
+        }
+        if (level_v != 0) {
+          const dsp::LoopFilterSize size = GetLoopFilterSizeUV(filter_length);
+          dsp_.loop_filters[size][type](
+              src_row_v, src_stride_v, outer_thresh_[level_v],
+              inner_thresh_[level_v], HevThresh(level_v));
+        }
+        src_row_u += column_step * pixel_size_;
+        src_row_v += column_step * pixel_size_;
+        column_step = DivideBy4(column_step << subsampling_x);
+      }
+    }
+  }
+}
+
+void PostFilter::ApplyDeblockFilterForOneSuperBlockRow(int row4x4_start,
+                                                       int sb4x4) {
+  assert(row4x4_start >= 0);
+  assert(DoDeblock());
+  for (int y = 0; y < sb4x4; y += 16) {
+    const int row4x4 = row4x4_start + y;
+    if (row4x4 >= frame_header_.rows4x4) break;
+    int column4x4;
+    for (column4x4 = 0; column4x4 < frame_header_.columns4x4;
+         column4x4 += kNum4x4InLoopFilterUnit) {
+      // First apply vertical filtering
+      VerticalDeblockFilter(row4x4, column4x4);
+
+      // Delay one superblock to apply horizontal filtering.
+      if (column4x4 != 0) {
+        HorizontalDeblockFilter(row4x4, column4x4 - kNum4x4InLoopFilterUnit);
+      }
+    }
+    // Horizontal filtering for the last 64x64 block.
+    HorizontalDeblockFilter(row4x4, column4x4 - kNum4x4InLoopFilterUnit);
+  }
+}
+
+void PostFilter::DeblockFilterWorker(int jobs_per_plane,
+                                     const Plane* /*planes*/,
+                                     int /*num_planes*/,
+                                     std::atomic<int>* job_counter,
+                                     DeblockFilter deblock_filter) {
+  const int total_jobs = jobs_per_plane;
+  int job_index;
+  while ((job_index = job_counter->fetch_add(1, std::memory_order_relaxed)) <
+         total_jobs) {
+    const int row_unit = job_index % jobs_per_plane;
+    const int row4x4 = row_unit * kNum4x4InLoopFilterUnit;
+    for (int column4x4 = 0; column4x4 < frame_header_.columns4x4;
+         column4x4 += kNum4x4InLoopFilterUnit) {
+      (this->*deblock_filter)(row4x4, column4x4);
+    }
+  }
+}
+
+void PostFilter::ApplyDeblockFilterThreaded() {
+  const int jobs_per_plane = DivideBy16(frame_header_.rows4x4 + 15);
+  const int num_workers = thread_pool_->num_threads();
+  std::array<Plane, kMaxPlanes> planes;
+  planes[0] = kPlaneY;
+  int num_planes = 1;
+  for (int plane = kPlaneU; plane < planes_; ++plane) {
+    if (frame_header_.loop_filter.level[plane + 1] != 0) {
+      planes[num_planes++] = static_cast<Plane>(plane);
+    }
+  }
+  // The vertical filters are not dependent on each other. So simply schedule
+  // them for all possible rows.
+  //
+  // The horizontal filter for a row/column depends on the vertical filter being
+  // finished for the blocks to the top and to the right. To work around
+  // this synchronization, we simply wait for the vertical filter to finish for
+  // all rows. Now, the horizontal filters can also be scheduled
+  // unconditionally similar to the vertical filters.
+  //
+  // The only synchronization involved is to know when the each directional
+  // filter is complete for the entire frame.
+  for (const auto& type :
+       {kLoopFilterTypeVertical, kLoopFilterTypeHorizontal}) {
+    const DeblockFilter deblock_filter = deblock_filter_func_[type];
+    std::atomic<int> job_counter(0);
+    BlockingCounter pending_workers(num_workers);
+    for (int i = 0; i < num_workers; ++i) {
+      thread_pool_->Schedule([this, jobs_per_plane, &planes, num_planes,
+                              &job_counter, deblock_filter,
+                              &pending_workers]() {
+        DeblockFilterWorker(jobs_per_plane, planes.data(), num_planes,
+                            &job_counter, deblock_filter);
+        pending_workers.Decrement();
+      });
+    }
+    // Run the jobs on the current thread.
+    DeblockFilterWorker(jobs_per_plane, planes.data(), num_planes, &job_counter,
+                        deblock_filter);
+    // Wait for the threadpool jobs to finish.
+    pending_workers.Wait();
+  }
+}
+
+void PostFilter::ApplyDeblockFilter(LoopFilterType loop_filter_type,
+                                    int row4x4_start, int column4x4_start,
+                                    int column4x4_end, int sb4x4) {
+  assert(row4x4_start >= 0);
+  assert(DoDeblock());
+
+  column4x4_end = std::min(column4x4_end, frame_header_.columns4x4);
+  if (column4x4_start >= column4x4_end) return;
+
+  const DeblockFilter deblock_filter = deblock_filter_func_[loop_filter_type];
+  const int sb_height4x4 =
+      std::min(sb4x4, frame_header_.rows4x4 - row4x4_start);
+  for (int y = 0; y < sb_height4x4; y += kNum4x4InLoopFilterUnit) {
+    const int row4x4 = row4x4_start + y;
+    for (int column4x4 = column4x4_start; column4x4 < column4x4_end;
+         column4x4 += kNum4x4InLoopFilterUnit) {
+      (this->*deblock_filter)(row4x4, column4x4);
+    }
+  }
+}
+
+}  // namespace libgav1
diff --git a/libgav1/src/post_filter/deblock_thresholds.inc b/libgav1/src/post_filter/deblock_thresholds.inc
new file mode 100644
index 0000000..ca12aaa
--- /dev/null
+++ b/libgav1/src/post_filter/deblock_thresholds.inc
@@ -0,0 +1,85 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Thresholds for the deblocking filter. Precomputed values of part of Section
+// 7.14.4 for all possible values of sharpness.
+
+constexpr uint8_t kInnerThresh[8][kMaxLoopFilterValue + 1] = {
+    {1,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+     16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+     32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+     48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63},
+    {1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 8, 8, 8, 8,
+     8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
+     8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8},
+    {1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+     7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7},
+    {1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+     6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
+     6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6},
+    {1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
+     5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
+     5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5},
+    {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4,
+     4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+     4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4},
+    {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+     3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+     3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
+    {1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+     2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+     2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}};
+
+constexpr uint8_t kOuterThresh[8][kMaxLoopFilterValue + 1] = {
+    {5,   7,   10,  13,  16,  19,  22,  25,  28,  31,  34,  37,  40,
+     43,  46,  49,  52,  55,  58,  61,  64,  67,  70,  73,  76,  79,
+     82,  85,  88,  91,  94,  97,  100, 103, 106, 109, 112, 115, 118,
+     121, 124, 127, 130, 133, 136, 139, 142, 145, 148, 151, 154, 157,
+     160, 163, 166, 169, 172, 175, 178, 181, 184, 187, 190, 193},
+    {5,   7,   9,   11,  14,  16,  19,  21,  24,  26,  29,  31,  34,
+     36,  39,  41,  44,  46,  48,  50,  52,  54,  56,  58,  60,  62,
+     64,  66,  68,  70,  72,  74,  76,  78,  80,  82,  84,  86,  88,
+     90,  92,  94,  96,  98,  100, 102, 104, 106, 108, 110, 112, 114,
+     116, 118, 120, 122, 124, 126, 128, 130, 132, 134, 136, 138},
+    {5,   7,   9,   11,  14,  16,  19,  21,  24,  26,  29,  31,  34,
+     36,  39,  41,  43,  45,  47,  49,  51,  53,  55,  57,  59,  61,
+     63,  65,  67,  69,  71,  73,  75,  77,  79,  81,  83,  85,  87,
+     89,  91,  93,  95,  97,  99,  101, 103, 105, 107, 109, 111, 113,
+     115, 117, 119, 121, 123, 125, 127, 129, 131, 133, 135, 137},
+    {5,   7,   9,   11,  14,  16,  19,  21,  24,  26,  29,  31,  34,
+     36,  38,  40,  42,  44,  46,  48,  50,  52,  54,  56,  58,  60,
+     62,  64,  66,  68,  70,  72,  74,  76,  78,  80,  82,  84,  86,
+     88,  90,  92,  94,  96,  98,  100, 102, 104, 106, 108, 110, 112,
+     114, 116, 118, 120, 122, 124, 126, 128, 130, 132, 134, 136},
+    {5,   7,   9,   11,  14,  16,  19,  21,  24,  26,  29,  31,  33,
+     35,  37,  39,  41,  43,  45,  47,  49,  51,  53,  55,  57,  59,
+     61,  63,  65,  67,  69,  71,  73,  75,  77,  79,  81,  83,  85,
+     87,  89,  91,  93,  95,  97,  99,  101, 103, 105, 107, 109, 111,
+     113, 115, 117, 119, 121, 123, 125, 127, 129, 131, 133, 135},
+    {5,   7,   9,   11,  13,  15,  17,  19,  22,  24,  26,  28,  31,
+     33,  35,  37,  40,  42,  44,  46,  48,  50,  52,  54,  56,  58,
+     60,  62,  64,  66,  68,  70,  72,  74,  76,  78,  80,  82,  84,
+     86,  88,  90,  92,  94,  96,  98,  100, 102, 104, 106, 108, 110,
+     112, 114, 116, 118, 120, 122, 124, 126, 128, 130, 132, 134},
+    {5,   7,   9,   11,  13,  15,  17,  19,  22,  24,  26,  28,  31,
+     33,  35,  37,  39,  41,  43,  45,  47,  49,  51,  53,  55,  57,
+     59,  61,  63,  65,  67,  69,  71,  73,  75,  77,  79,  81,  83,
+     85,  87,  89,  91,  93,  95,  97,  99,  101, 103, 105, 107, 109,
+     111, 113, 115, 117, 119, 121, 123, 125, 127, 129, 131, 133},
+    {5,   7,   9,   11,  13,  15,  17,  19,  22,  24,  26,  28,  30,
+     32,  34,  36,  38,  40,  42,  44,  46,  48,  50,  52,  54,  56,
+     58,  60,  62,  64,  66,  68,  70,  72,  74,  76,  78,  80,  82,
+     84,  86,  88,  90,  92,  94,  96,  98,  100, 102, 104, 106, 108,
+     110, 112, 114, 116, 118, 120, 122, 124, 126, 128, 130, 132}};
diff --git a/libgav1/src/post_filter/loop_restoration.cc b/libgav1/src/post_filter/loop_restoration.cc
new file mode 100644
index 0000000..17670b9
--- /dev/null
+++ b/libgav1/src/post_filter/loop_restoration.cc
@@ -0,0 +1,373 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "src/post_filter.h"
+#include "src/utils/blocking_counter.h"
+
+namespace libgav1 {
+namespace {
+
+template <typename Pixel>
+void CopyTwoRows(const Pixel* src, const ptrdiff_t src_stride, Pixel** dst,
+                 const ptrdiff_t dst_stride, const int width) {
+  for (int i = 0; i < kRestorationVerticalBorder; ++i) {
+    memcpy(*dst, src, sizeof(Pixel) * width);
+    src += src_stride;
+    *dst += dst_stride;
+  }
+}
+
+}  // namespace
+
+// static
+template <typename Pixel>
+void PostFilter::PrepareLoopRestorationBlock(
+    const Pixel* src_buffer, const ptrdiff_t src_stride,
+    const Pixel* deblock_buffer, const ptrdiff_t deblock_stride, Pixel* dst,
+    const ptrdiff_t dst_stride, const int width, const int height,
+    const bool frame_top_border, const bool frame_bottom_border) {
+  src_buffer -=
+      kRestorationVerticalBorder * src_stride + kRestorationHorizontalBorder;
+  deblock_buffer -= kRestorationHorizontalBorder;
+  int h = height;
+  // Top 2 rows.
+  if (frame_top_border) {
+    h += kRestorationVerticalBorder;
+  } else {
+    CopyTwoRows<Pixel>(deblock_buffer, deblock_stride, &dst, dst_stride,
+                       width + 2 * kRestorationHorizontalBorder);
+    src_buffer += kRestorationVerticalBorder * src_stride;
+    // If |frame_top_border| is true, then we are in the first superblock row,
+    // so in that case, do not increment |deblock_buffer| since we don't store
+    // anything from the first superblock row into |deblock_buffer|.
+    deblock_buffer += 4 * deblock_stride;
+  }
+  if (frame_bottom_border) h += kRestorationVerticalBorder;
+  // Main body.
+  do {
+    memcpy(dst, src_buffer,
+           sizeof(Pixel) * (width + 2 * kRestorationHorizontalBorder));
+    src_buffer += src_stride;
+    dst += dst_stride;
+  } while (--h != 0);
+  // Bottom 2 rows.
+  if (!frame_bottom_border) {
+    deblock_buffer += kRestorationVerticalBorder * deblock_stride;
+    CopyTwoRows<Pixel>(deblock_buffer, deblock_stride, &dst, dst_stride,
+                       width + 2 * kRestorationHorizontalBorder);
+  }
+}
+
+template void PostFilter::PrepareLoopRestorationBlock<uint8_t>(
+    const uint8_t* src_buffer, ptrdiff_t src_stride,
+    const uint8_t* deblock_buffer, ptrdiff_t deblock_stride, uint8_t* dst,
+    ptrdiff_t dst_stride, const int width, const int height,
+    const bool frame_top_border, const bool frame_bottom_border);
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+template void PostFilter::PrepareLoopRestorationBlock<uint16_t>(
+    const uint16_t* src_buffer, ptrdiff_t src_stride,
+    const uint16_t* deblock_buffer, ptrdiff_t deblock_stride, uint16_t* dst,
+    ptrdiff_t dst_stride, const int width, const int height,
+    const bool frame_top_border, const bool frame_bottom_border);
+#endif
+
+template <typename Pixel>
+void PostFilter::ApplyLoopRestorationForOneRowInWindow(
+    const Pixel* src_buffer, const Plane plane, const int plane_height,
+    const int plane_width, const int y, const int x, const int row,
+    const int unit_row, const int current_process_unit_height,
+    const int plane_unit_size, const int window_width,
+    Array2DView<Pixel>* const loop_restored_window) {
+  const int num_horizontal_units =
+      restoration_info_->num_horizontal_units(static_cast<Plane>(plane));
+  const ptrdiff_t src_stride = frame_buffer_.stride(plane) / sizeof(Pixel);
+  const RestorationUnitInfo* const restoration_info =
+      restoration_info_->loop_restoration_info(static_cast<Plane>(plane),
+                                               unit_row * num_horizontal_units);
+  int unit_column = x / plane_unit_size;
+  src_buffer += (y + row) * src_stride + x;
+  int column = 0;
+  do {
+    const int unit_x = x + column;
+    const int unit_y = y + row;
+    const int current_process_unit_width =
+        std::min(plane_unit_size, plane_width - unit_x);
+    const Pixel* src = src_buffer + column;
+    unit_column = std::min(unit_column, num_horizontal_units - 1);
+    if (restoration_info[unit_column].type == kLoopRestorationTypeNone) {
+      const ptrdiff_t dst_stride = loop_restored_window->columns();
+      Pixel* dst = &(*loop_restored_window)[row][column];
+      for (int k = 0; k < current_process_unit_height; ++k) {
+        if (DoCdef()) {
+          memmove(dst, src, current_process_unit_width * sizeof(Pixel));
+        } else {
+          memcpy(dst, src, current_process_unit_width * sizeof(Pixel));
+        }
+        src += src_stride;
+        dst += dst_stride;
+      }
+    } else {
+      const ptrdiff_t block_buffer_stride = kRestorationUnitWidthWithBorders;
+      // The SIMD implementation of wiener filter over-reads 15 -
+      // |kRestorationHorizontalBorder| bytes, and the SIMD implementation of
+      // self-guided filter over-reads up to 7 bytes which happens when
+      // |current_process_unit_width| equals |kRestorationUnitWidth| - 7, and
+      // the radius of the first pass in sfg is 0. So add 8 extra bytes at the
+      // end of block_buffer for 8 bit.
+      Pixel
+          block_buffer[kRestorationUnitHeightWithBorders * block_buffer_stride +
+                       ((sizeof(Pixel) == 1) ? 15 - kRestorationHorizontalBorder
+                                             : 0)];
+      RestorationBuffer restoration_buffer;
+      const Pixel* source;
+      ptrdiff_t source_stride;
+      if (DoCdef()) {
+        const int deblock_buffer_units = 64 >> subsampling_y_[plane];
+        const auto* const deblock_buffer =
+            reinterpret_cast<const Pixel*>(deblock_buffer_.data(plane));
+        assert(deblock_buffer != nullptr);
+        const ptrdiff_t deblock_buffer_stride =
+            deblock_buffer_.stride(plane) / sizeof(Pixel);
+        const int deblock_unit_y =
+            std::max(MultiplyBy4(Ceil(unit_y, deblock_buffer_units)) - 4, 0);
+        const Pixel* const deblock_unit_buffer =
+            deblock_buffer + deblock_unit_y * deblock_buffer_stride + unit_x;
+        PrepareLoopRestorationBlock<Pixel>(
+            src, src_stride, deblock_unit_buffer, deblock_buffer_stride,
+            block_buffer, block_buffer_stride, current_process_unit_width,
+            current_process_unit_height, unit_y == 0,
+            unit_y + current_process_unit_height >= plane_height);
+        source = block_buffer +
+                 kRestorationVerticalBorder * block_buffer_stride +
+                 kRestorationHorizontalBorder;
+        source_stride = kRestorationUnitWidthWithBorders;
+      } else {
+        source = src;
+        source_stride = src_stride;
+      }
+      const LoopRestorationType type = restoration_info[unit_column].type;
+      assert(type == kLoopRestorationTypeSgrProj ||
+             type == kLoopRestorationTypeWiener);
+      const dsp::LoopRestorationFunc restoration_func =
+          dsp_.loop_restorations[type - 2];
+      restoration_func(source, &(*loop_restored_window)[row][column],
+                       restoration_info[unit_column], source_stride,
+                       loop_restored_window->columns(),
+                       current_process_unit_width, current_process_unit_height,
+                       &restoration_buffer);
+    }
+    ++unit_column;
+    column += plane_unit_size;
+  } while (column < window_width);
+}
+
+template <typename Pixel>
+void PostFilter::ApplyLoopRestorationSingleThread(const int row4x4_start,
+                                                  const int sb4x4) {
+  assert(row4x4_start >= 0);
+  assert(DoRestoration());
+  for (int plane = 0; plane < planes_; ++plane) {
+    if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
+      continue;
+    }
+    const ptrdiff_t stride = frame_buffer_.stride(plane) / sizeof(Pixel);
+    const int unit_height_offset =
+        kRestorationUnitOffset >> subsampling_y_[plane];
+    const int plane_height =
+        RightShiftWithRounding(height_, subsampling_y_[plane]);
+    const int plane_width =
+        RightShiftWithRounding(upscaled_width_, subsampling_x_[plane]);
+    const int num_vertical_units =
+        restoration_info_->num_vertical_units(static_cast<Plane>(plane));
+    const int plane_unit_size = loop_restoration_.unit_size[plane];
+    const int plane_process_unit_height =
+        kRestorationUnitHeight >> subsampling_y_[plane];
+    int y = (row4x4_start == 0)
+                ? 0
+                : (MultiplyBy4(row4x4_start) >> subsampling_y_[plane]) -
+                      unit_height_offset;
+    int expected_height = plane_process_unit_height -
+                          ((row4x4_start == 0) ? unit_height_offset : 0);
+    int current_process_unit_height;
+    for (int sb_y = 0; sb_y < sb4x4;
+         sb_y += 16, y += current_process_unit_height) {
+      if (y >= plane_height) break;
+      const int unit_row = std::min((y + unit_height_offset) / plane_unit_size,
+                                    num_vertical_units - 1);
+      current_process_unit_height = std::min(expected_height, plane_height - y);
+      expected_height = plane_process_unit_height;
+      Array2DView<Pixel> loop_restored_window(
+          current_process_unit_height, static_cast<int>(stride),
+          reinterpret_cast<Pixel*>(loop_restoration_buffer_[plane]) +
+              y * stride);
+      ApplyLoopRestorationForOneRowInWindow<Pixel>(
+          reinterpret_cast<Pixel*>(superres_buffer_[plane]),
+          static_cast<Plane>(plane), plane_height, plane_width, y, 0, 0,
+          unit_row, current_process_unit_height, plane_unit_size, plane_width,
+          &loop_restored_window);
+    }
+  }
+}
+
+// Multi-thread version of loop restoration, based on a moving window of size
+// |window_buffer_width_|x|window_buffer_height_|. Inside the moving window, we
+// create a filtering job for each row and each filtering job is submitted to
+// the thread pool. Each free thread takes one job from the thread pool and
+// completes filtering until all jobs are finished. This approach requires an
+// extra buffer (|threaded_window_buffer_|) to hold the filtering output, whose
+// size is the size of the window. It also needs block buffers (i.e.,
+// |block_buffer| in ApplyLoopRestorationForOneRowInWindow()) to store
+// intermediate results in loop restoration for each thread. After all units
+// inside the window are filtered, the output is written to the frame buffer.
+template <typename Pixel>
+void PostFilter::ApplyLoopRestorationThreaded() {
+  const int plane_process_unit_height[kMaxPlanes] = {
+      kRestorationUnitHeight, kRestorationUnitHeight >> subsampling_y_[kPlaneU],
+      kRestorationUnitHeight >> subsampling_y_[kPlaneV]};
+  Array2DView<Pixel> loop_restored_window;
+  if (!DoCdef()) {
+    loop_restored_window.Reset(
+        window_buffer_height_, window_buffer_width_,
+        reinterpret_cast<Pixel*>(threaded_window_buffer_));
+  }
+
+  for (int plane = kPlaneY; plane < planes_; ++plane) {
+    if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
+      continue;
+    }
+
+    const int unit_height_offset =
+        kRestorationUnitOffset >> subsampling_y_[plane];
+    auto* const src_buffer = reinterpret_cast<Pixel*>(superres_buffer_[plane]);
+    const ptrdiff_t src_stride = frame_buffer_.stride(plane) / sizeof(Pixel);
+    const int plane_unit_size = loop_restoration_.unit_size[plane];
+    const int num_vertical_units =
+        restoration_info_->num_vertical_units(static_cast<Plane>(plane));
+    const int plane_width =
+        RightShiftWithRounding(upscaled_width_, subsampling_x_[plane]);
+    const int plane_height =
+        RightShiftWithRounding(height_, subsampling_y_[plane]);
+    PostFilter::ExtendFrame<Pixel>(
+        src_buffer, plane_width, plane_height, src_stride,
+        kRestorationHorizontalBorder, kRestorationHorizontalBorder,
+        kRestorationVerticalBorder, kRestorationVerticalBorder);
+
+    const int num_workers = thread_pool_->num_threads();
+    for (int y = 0; y < plane_height; y += window_buffer_height_) {
+      const int actual_window_height =
+          std::min(window_buffer_height_ - ((y == 0) ? unit_height_offset : 0),
+                   plane_height - y);
+      int vertical_units_per_window =
+          (actual_window_height + plane_process_unit_height[plane] - 1) /
+          plane_process_unit_height[plane];
+      if (y == 0) {
+        // The first row of loop restoration processing units is not 64x64, but
+        // 64x56 (|unit_height_offset| = 8 rows less than other restoration
+        // processing units). For u/v with subsampling, the size is halved. To
+        // compute the number of vertical units per window, we need to take a
+        // special handling for it.
+        const int height_without_first_unit =
+            actual_window_height -
+            std::min(actual_window_height,
+                     plane_process_unit_height[plane] - unit_height_offset);
+        vertical_units_per_window =
+            (height_without_first_unit + plane_process_unit_height[plane] - 1) /
+                plane_process_unit_height[plane] +
+            1;
+      }
+      const int jobs_for_threadpool =
+          vertical_units_per_window * num_workers / (num_workers + 1);
+      for (int x = 0; x < plane_width; x += window_buffer_width_) {
+        const int actual_window_width =
+            std::min(window_buffer_width_, plane_width - x);
+        assert(jobs_for_threadpool < vertical_units_per_window);
+        if (DoCdef()) {
+          loop_restored_window.Reset(
+              actual_window_height, static_cast<int>(src_stride),
+              reinterpret_cast<Pixel*>(loop_restoration_buffer_[plane]) +
+                  y * src_stride + x);
+        }
+        BlockingCounter pending_jobs(jobs_for_threadpool);
+        int job_count = 0;
+        int current_process_unit_height;
+        for (int row = 0; row < actual_window_height;
+             row += current_process_unit_height) {
+          const int unit_y = y + row;
+          const int expected_height = plane_process_unit_height[plane] -
+                                      ((unit_y == 0) ? unit_height_offset : 0);
+          current_process_unit_height =
+              std::min(expected_height, plane_height - unit_y);
+          const int unit_row =
+              std::min((unit_y + unit_height_offset) / plane_unit_size,
+                       num_vertical_units - 1);
+
+          if (job_count < jobs_for_threadpool) {
+            thread_pool_->Schedule(
+                [this, src_buffer, plane, plane_height, plane_width, y, x, row,
+                 unit_row, current_process_unit_height, plane_unit_size,
+                 actual_window_width, &loop_restored_window, &pending_jobs]() {
+                  ApplyLoopRestorationForOneRowInWindow<Pixel>(
+                      src_buffer, static_cast<Plane>(plane), plane_height,
+                      plane_width, y, x, row, unit_row,
+                      current_process_unit_height, plane_unit_size,
+                      actual_window_width, &loop_restored_window);
+                  pending_jobs.Decrement();
+                });
+          } else {
+            ApplyLoopRestorationForOneRowInWindow<Pixel>(
+                src_buffer, static_cast<Plane>(plane), plane_height,
+                plane_width, y, x, row, unit_row, current_process_unit_height,
+                plane_unit_size, actual_window_width, &loop_restored_window);
+          }
+          ++job_count;
+        }
+        // Wait for all jobs of current window to finish.
+        pending_jobs.Wait();
+        if (!DoCdef()) {
+          // Copy |threaded_window_buffer_| to output frame.
+          CopyPlane<Pixel>(
+              reinterpret_cast<const Pixel*>(threaded_window_buffer_),
+              window_buffer_width_, actual_window_width, actual_window_height,
+              reinterpret_cast<Pixel*>(loop_restoration_buffer_[plane]) +
+                  y * src_stride + x,
+              src_stride);
+        }
+      }
+      if (y == 0) y -= unit_height_offset;
+    }
+  }
+}
+
+void PostFilter::ApplyLoopRestoration(const int row4x4_start, const int sb4x4) {
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  if (bitdepth_ >= 10) {
+    ApplyLoopRestorationSingleThread<uint16_t>(row4x4_start, sb4x4);
+    return;
+  }
+#endif
+  ApplyLoopRestorationSingleThread<uint8_t>(row4x4_start, sb4x4);
+}
+
+void PostFilter::ApplyLoopRestoration() {
+  assert(threaded_window_buffer_ != nullptr);
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  if (bitdepth_ >= 10) {
+    ApplyLoopRestorationThreaded<uint16_t>();
+    return;
+  }
+#endif
+  ApplyLoopRestorationThreaded<uint8_t>();
+}
+
+}  // namespace libgav1
diff --git a/libgav1/src/post_filter/post_filter.cc b/libgav1/src/post_filter/post_filter.cc
new file mode 100644
index 0000000..6d5ef31
--- /dev/null
+++ b/libgav1/src/post_filter/post_filter.cc
@@ -0,0 +1,435 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/post_filter.h"
+
+#include <algorithm>
+#include <atomic>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/array_2d.h"
+#include "src/utils/constants.h"
+#include "src/utils/memory.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+namespace {
+
+// Import all the constants in the anonymous namespace.
+#include "src/post_filter/deblock_thresholds.inc"
+
+// Row indices of deblocked pixels needed by loop restoration. This is used to
+// populate the |deblock_buffer_| when cdef is on. The first dimension is
+// subsampling_y.
+constexpr int kDeblockedRowsForLoopRestoration[2][4] = {{54, 55, 56, 57},
+                                                        {26, 27, 28, 29}};
+
+}  // namespace
+
+// The following example illustrates how ExtendFrame() extends a frame.
+// Suppose the frame width is 8 and height is 4, and left, right, top, and
+// bottom are all equal to 3.
+//
+// Before:
+//
+//       ABCDEFGH
+//       IJKLMNOP
+//       QRSTUVWX
+//       YZabcdef
+//
+// After:
+//
+//   AAA|ABCDEFGH|HHH  [3]
+//   AAA|ABCDEFGH|HHH
+//   AAA|ABCDEFGH|HHH
+//   ---+--------+---
+//   AAA|ABCDEFGH|HHH  [1]
+//   III|IJKLMNOP|PPP
+//   QQQ|QRSTUVWX|XXX
+//   YYY|YZabcdef|fff
+//   ---+--------+---
+//   YYY|YZabcdef|fff  [2]
+//   YYY|YZabcdef|fff
+//   YYY|YZabcdef|fff
+//
+// ExtendFrame() first extends the rows to the left and to the right[1]. Then
+// it copies the extended last row to the bottom borders[2]. Finally it copies
+// the extended first row to the top borders[3].
+// static
+template <typename Pixel>
+void PostFilter::ExtendFrame(Pixel* const frame_start, const int width,
+                             const int height, const ptrdiff_t stride,
+                             const int left, const int right, const int top,
+                             const int bottom) {
+  const Pixel* src = frame_start;
+  Pixel* dst = frame_start - left;
+  // Copy to left and right borders.
+  for (int y = 0; y < height; ++y) {
+    Memset(dst, src[0], left);
+    Memset(dst + left + width, src[width - 1], right);
+    src += stride;
+    dst += stride;
+  }
+  // Copy to bottom borders. For performance we copy |stride| pixels
+  // (including some padding pixels potentially) in each row, ending at the
+  // bottom right border pixel. In the diagram the asterisks indicate padding
+  // pixels.
+  //
+  // |<--- stride --->|
+  // **YYY|YZabcdef|fff <-- Copy from the extended last row.
+  // -----+--------+---
+  // **YYY|YZabcdef|fff
+  // **YYY|YZabcdef|fff
+  // **YYY|YZabcdef|fff <-- bottom right border pixel
+  assert(src == frame_start + height * stride);
+  dst = const_cast<Pixel*>(src) + width + right - stride;
+  src = dst - stride;
+  for (int y = 0; y < bottom; ++y) {
+    memcpy(dst, src, sizeof(Pixel) * stride);
+    dst += stride;
+  }
+  // Copy to top borders. For performance we copy |stride| pixels (including
+  // some padding pixels potentially) in each row, starting from the top left
+  // border pixel. In the diagram the asterisks indicate padding pixels.
+  //
+  // +-- top left border pixel
+  // |
+  // v
+  // AAA|ABCDEFGH|HHH**
+  // AAA|ABCDEFGH|HHH**
+  // AAA|ABCDEFGH|HHH**
+  // ---+--------+-----
+  // AAA|ABCDEFGH|HHH** <-- Copy from the extended first row.
+  // |<--- stride --->|
+  src = frame_start - left;
+  dst = frame_start - left - top * stride;
+  for (int y = 0; y < top; ++y) {
+    memcpy(dst, src, sizeof(Pixel) * stride);
+    dst += stride;
+  }
+}
+
+template void PostFilter::ExtendFrame<uint8_t>(uint8_t* const frame_start,
+                                               const int width,
+                                               const int height,
+                                               const ptrdiff_t stride,
+                                               const int left, const int right,
+                                               const int top, const int bottom);
+
+#if LIBGAV1_MAX_BITDEPTH >= 10
+template void PostFilter::ExtendFrame<uint16_t>(
+    uint16_t* const frame_start, const int width, const int height,
+    const ptrdiff_t stride, const int left, const int right, const int top,
+    const int bottom);
+#endif
+
+PostFilter::PostFilter(const ObuFrameHeader& frame_header,
+                       const ObuSequenceHeader& sequence_header,
+                       FrameScratchBuffer* const frame_scratch_buffer,
+                       YuvBuffer* const frame_buffer, const dsp::Dsp* dsp,
+                       int do_post_filter_mask)
+    : frame_header_(frame_header),
+      loop_restoration_(frame_header.loop_restoration),
+      dsp_(*dsp),
+      // Deblocking filter always uses 64x64 as step size.
+      num_64x64_blocks_per_row_(DivideBy64(frame_header.width + 63)),
+      upscaled_width_(frame_header.upscaled_width),
+      width_(frame_header.width),
+      height_(frame_header.height),
+      bitdepth_(sequence_header.color_config.bitdepth),
+      subsampling_x_{0, sequence_header.color_config.subsampling_x,
+                     sequence_header.color_config.subsampling_x},
+      subsampling_y_{0, sequence_header.color_config.subsampling_y,
+                     sequence_header.color_config.subsampling_y},
+      planes_(sequence_header.color_config.is_monochrome ? kMaxPlanesMonochrome
+                                                         : kMaxPlanes),
+      pixel_size_(static_cast<int>((bitdepth_ == 8) ? sizeof(uint8_t)
+                                                    : sizeof(uint16_t))),
+      inner_thresh_(kInnerThresh[frame_header.loop_filter.sharpness]),
+      outer_thresh_(kOuterThresh[frame_header.loop_filter.sharpness]),
+      needs_chroma_deblock_(frame_header.loop_filter.level[kPlaneU + 1] != 0 ||
+                            frame_header.loop_filter.level[kPlaneV + 1] != 0),
+      cdef_index_(frame_scratch_buffer->cdef_index),
+      inter_transform_sizes_(frame_scratch_buffer->inter_transform_sizes),
+      threaded_window_buffer_(
+          frame_scratch_buffer->threaded_window_buffer.get()),
+      restoration_info_(&frame_scratch_buffer->loop_restoration_info),
+      superres_line_buffer_(frame_scratch_buffer->superres_line_buffer.get()),
+      block_parameters_(frame_scratch_buffer->block_parameters_holder),
+      frame_buffer_(*frame_buffer),
+      deblock_buffer_(frame_scratch_buffer->deblock_buffer),
+      do_post_filter_mask_(do_post_filter_mask),
+      thread_pool_(
+          frame_scratch_buffer->threading_strategy.post_filter_thread_pool()),
+      window_buffer_width_(GetWindowBufferWidth(thread_pool_, frame_header)),
+      window_buffer_height_(GetWindowBufferHeight(thread_pool_, frame_header)) {
+  const int8_t zero_delta_lf[kFrameLfCount] = {};
+  ComputeDeblockFilterLevels(zero_delta_lf, deblock_filter_levels_);
+  if (DoSuperRes()) {
+    for (int plane = 0; plane < planes_; ++plane) {
+      const int downscaled_width =
+          RightShiftWithRounding(width_, subsampling_x_[plane]);
+      const int upscaled_width =
+          RightShiftWithRounding(upscaled_width_, subsampling_x_[plane]);
+      const int superres_width = downscaled_width << kSuperResScaleBits;
+      super_res_info_[plane].step =
+          (superres_width + upscaled_width / 2) / upscaled_width;
+      const int error =
+          super_res_info_[plane].step * upscaled_width - superres_width;
+      super_res_info_[plane].initial_subpixel_x =
+          ((-((upscaled_width - downscaled_width) << (kSuperResScaleBits - 1)) +
+            DivideBy2(upscaled_width)) /
+               upscaled_width +
+           (1 << (kSuperResExtraBits - 1)) - error / 2) &
+          kSuperResScaleMask;
+      super_res_info_[plane].upscaled_width = upscaled_width;
+    }
+  }
+  for (int plane = 0; plane < planes_; ++plane) {
+    loop_restoration_buffer_[plane] = frame_buffer_.data(plane);
+    cdef_buffer_[plane] = frame_buffer_.data(plane);
+    superres_buffer_[plane] = frame_buffer_.data(plane);
+    source_buffer_[plane] = frame_buffer_.data(plane);
+  }
+  // In single threaded mode, we apply SuperRes without making a copy of the
+  // input row by writing the output to one row to the top (we refer to this
+  // process as "in place superres" in our code).
+  const bool in_place_superres = DoSuperRes() && thread_pool_ == nullptr;
+  if (DoCdef() || DoRestoration() || in_place_superres) {
+    for (int plane = 0; plane < planes_; ++plane) {
+      int horizontal_shift = 0;
+      int vertical_shift = 0;
+      if (DoRestoration() &&
+          loop_restoration_.type[plane] != kLoopRestorationTypeNone) {
+        horizontal_shift += frame_buffer_.alignment();
+        if (!DoCdef()) {
+          vertical_shift += kRestorationVerticalBorder;
+        }
+        superres_buffer_[plane] +=
+            vertical_shift * frame_buffer_.stride(plane) +
+            horizontal_shift * pixel_size_;
+      }
+      if (in_place_superres) {
+        vertical_shift += kSuperResVerticalBorder;
+      }
+      cdef_buffer_[plane] += vertical_shift * frame_buffer_.stride(plane) +
+                             horizontal_shift * pixel_size_;
+      if (DoCdef()) {
+        horizontal_shift += frame_buffer_.alignment();
+        vertical_shift += kCdefBorder;
+      }
+      assert(horizontal_shift <= frame_buffer_.right_border(plane));
+      assert(vertical_shift <= frame_buffer_.bottom_border(plane));
+      source_buffer_[plane] += vertical_shift * frame_buffer_.stride(plane) +
+                               horizontal_shift * pixel_size_;
+    }
+  }
+}
+
+void PostFilter::ExtendFrameBoundary(uint8_t* const frame_start,
+                                     const int width, const int height,
+                                     const ptrdiff_t stride, const int left,
+                                     const int right, const int top,
+                                     const int bottom) const {
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  if (bitdepth_ >= 10) {
+    ExtendFrame<uint16_t>(reinterpret_cast<uint16_t*>(frame_start), width,
+                          height, stride / sizeof(uint16_t), left, right, top,
+                          bottom);
+    return;
+  }
+#endif
+  ExtendFrame<uint8_t>(frame_start, width, height, stride, left, right, top,
+                       bottom);
+}
+
+void PostFilter::ExtendBordersForReferenceFrame() {
+  if (frame_header_.refresh_frame_flags == 0) return;
+  for (int plane = kPlaneY; plane < planes_; ++plane) {
+    const int plane_width =
+        RightShiftWithRounding(upscaled_width_, subsampling_x_[plane]);
+    const int plane_height =
+        RightShiftWithRounding(height_, subsampling_y_[plane]);
+    assert(frame_buffer_.left_border(plane) >= kMinLeftBorderPixels &&
+           frame_buffer_.right_border(plane) >= kMinRightBorderPixels &&
+           frame_buffer_.top_border(plane) >= kMinTopBorderPixels &&
+           frame_buffer_.bottom_border(plane) >= kMinBottomBorderPixels);
+    // plane subsampling_x_ left_border
+    //   Y        N/A         64, 48
+    //  U,V        0          64, 48
+    //  U,V        1          32, 16
+    assert(frame_buffer_.left_border(plane) >= 16);
+    // The |left| argument to ExtendFrameBoundary() must be at least
+    // kMinLeftBorderPixels (13) for warp.
+    static_assert(16 >= kMinLeftBorderPixels, "");
+    ExtendFrameBoundary(
+        frame_buffer_.data(plane), plane_width, plane_height,
+        frame_buffer_.stride(plane), frame_buffer_.left_border(plane),
+        frame_buffer_.right_border(plane), frame_buffer_.top_border(plane),
+        frame_buffer_.bottom_border(plane));
+  }
+}
+
+void PostFilter::CopyDeblockedPixels(Plane plane, int row4x4) {
+  const ptrdiff_t src_stride = frame_buffer_.stride(plane);
+  const uint8_t* const src =
+      GetSourceBuffer(static_cast<Plane>(plane), row4x4, 0);
+  const ptrdiff_t dst_stride = deblock_buffer_.stride(plane);
+  const int row_offset = DivideBy4(row4x4);
+  uint8_t* dst = deblock_buffer_.data(plane) + dst_stride * row_offset;
+  const int num_pixels = SubsampledValue(MultiplyBy4(frame_header_.columns4x4),
+                                         subsampling_x_[plane]);
+  int last_valid_row = -1;
+  const int plane_height =
+      SubsampledValue(frame_header_.height, subsampling_y_[plane]);
+  for (int i = 0; i < 4; ++i) {
+    int row = kDeblockedRowsForLoopRestoration[subsampling_y_[plane]][i];
+    const int absolute_row =
+        (MultiplyBy4(row4x4) >> subsampling_y_[plane]) + row;
+    if (absolute_row >= plane_height) {
+      if (last_valid_row == -1) {
+        // We have run out of rows and there no valid row to copy. This will not
+        // be used by loop restoration, so we can simply break here. However,
+        // MSAN does not know that this is never used (since we sometimes apply
+        // superres to this row as well). So zero it out in case of MSAN.
+#if LIBGAV1_MSAN
+        if (DoSuperRes()) {
+          memset(dst, 0, num_pixels * pixel_size_);
+          dst += dst_stride;
+          continue;
+        }
+#endif
+        break;
+      }
+      // If we run out of rows, copy the last valid row (mimics the bottom
+      // border extension).
+      row = last_valid_row;
+    }
+    memcpy(dst, src + src_stride * row, num_pixels * pixel_size_);
+    last_valid_row = row;
+    dst += dst_stride;
+  }
+}
+
+void PostFilter::CopyBordersForOneSuperBlockRow(int row4x4, int sb4x4,
+                                                bool for_loop_restoration) {
+  // Number of rows to be subtracted from the start position described by
+  // row4x4. We always lag by 8 rows (to account for in-loop post filters).
+  const int row_offset = (row4x4 == 0) ? 0 : 8;
+  // Number of rows to be subtracted from the height described by sb4x4.
+  const int height_offset = (row4x4 == 0) ? 8 : 0;
+  // If cdef is off, then loop restoration needs 2 extra rows for the bottom
+  // border in each plane.
+  const int extra_rows = (for_loop_restoration && !DoCdef()) ? 2 : 0;
+  for (int plane = 0; plane < planes_; ++plane) {
+    const int plane_width =
+        RightShiftWithRounding(upscaled_width_, subsampling_x_[plane]);
+    const int plane_height =
+        RightShiftWithRounding(height_, subsampling_y_[plane]);
+    const int row = (MultiplyBy4(row4x4) - row_offset) >> subsampling_y_[plane];
+    assert(row >= 0);
+    if (row >= plane_height) break;
+    const int num_rows =
+        std::min(RightShiftWithRounding(MultiplyBy4(sb4x4) - height_offset,
+                                        subsampling_y_[plane]) +
+                     extra_rows,
+                 plane_height - row);
+    // We only need to track the progress of the Y plane since the progress of
+    // the U and V planes will be inferred from the progress of the Y plane.
+    if (!for_loop_restoration && plane == kPlaneY) {
+      progress_row_ = row + num_rows;
+    }
+    const bool copy_bottom = row + num_rows == plane_height;
+    const int stride = frame_buffer_.stride(plane);
+    uint8_t* const start = (for_loop_restoration ? superres_buffer_[plane]
+                                                 : frame_buffer_.data(plane)) +
+                           row * stride;
+    const int left_border = for_loop_restoration
+                                ? kRestorationHorizontalBorder
+                                : frame_buffer_.left_border(plane);
+    const int right_border = for_loop_restoration
+                                 ? kRestorationHorizontalBorder
+                                 : frame_buffer_.right_border(plane);
+    const int top_border =
+        (row == 0) ? (for_loop_restoration ? kRestorationVerticalBorder
+                                           : frame_buffer_.top_border(plane))
+                   : 0;
+    const int bottom_border =
+        copy_bottom
+            ? (for_loop_restoration ? kRestorationVerticalBorder
+                                    : frame_buffer_.bottom_border(plane))
+            : 0;
+    ExtendFrameBoundary(start, plane_width, num_rows, stride, left_border,
+                        right_border, top_border, bottom_border);
+  }
+}
+
+void PostFilter::ApplyFilteringThreaded() {
+  if (DoDeblock()) ApplyDeblockFilterThreaded();
+  if (DoCdef() && DoRestoration()) {
+    for (int row4x4 = 0; row4x4 < frame_header_.rows4x4;
+         row4x4 += kNum4x4InLoopFilterUnit) {
+      SetupDeblockBuffer(row4x4, kNum4x4InLoopFilterUnit);
+    }
+  }
+  if (DoCdef()) ApplyCdef();
+  if (DoSuperRes()) ApplySuperResThreaded();
+  if (DoRestoration()) ApplyLoopRestoration();
+  ExtendBordersForReferenceFrame();
+}
+
+int PostFilter::ApplyFilteringForOneSuperBlockRow(int row4x4, int sb4x4,
+                                                  bool is_last_row,
+                                                  bool do_deblock) {
+  if (row4x4 < 0) return -1;
+  if (DoDeblock() && do_deblock) {
+    ApplyDeblockFilterForOneSuperBlockRow(row4x4, sb4x4);
+  }
+  if (DoRestoration() && DoCdef()) {
+    SetupDeblockBuffer(row4x4, sb4x4);
+  }
+  if (DoCdef()) {
+    ApplyCdefForOneSuperBlockRow(row4x4, sb4x4, is_last_row);
+  }
+  if (DoSuperRes()) {
+    ApplySuperResForOneSuperBlockRow(row4x4, sb4x4, is_last_row);
+  }
+  if (DoRestoration()) {
+    CopyBordersForOneSuperBlockRow(row4x4, sb4x4, true);
+    ApplyLoopRestoration(row4x4, sb4x4);
+    if (is_last_row) {
+      // Loop restoration operates with a lag of 8 rows. So make sure to cover
+      // all the rows of the last superblock row.
+      CopyBordersForOneSuperBlockRow(row4x4 + sb4x4, 16, true);
+      ApplyLoopRestoration(row4x4 + sb4x4, 16);
+    }
+  }
+  if (frame_header_.refresh_frame_flags != 0 && DoBorderExtensionInLoop()) {
+    CopyBordersForOneSuperBlockRow(row4x4, sb4x4, false);
+    if (is_last_row) {
+      CopyBordersForOneSuperBlockRow(row4x4 + sb4x4, 16, false);
+    }
+  }
+  if (is_last_row && !DoBorderExtensionInLoop()) {
+    ExtendBordersForReferenceFrame();
+  }
+  return is_last_row ? height_ : progress_row_;
+}
+
+}  // namespace libgav1
diff --git a/libgav1/src/post_filter/super_res.cc b/libgav1/src/post_filter/super_res.cc
new file mode 100644
index 0000000..f6594f4
--- /dev/null
+++ b/libgav1/src/post_filter/super_res.cc
@@ -0,0 +1,232 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "src/post_filter.h"
+#include "src/utils/blocking_counter.h"
+
+namespace libgav1 {
+namespace {
+
+template <typename Pixel>
+void ExtendLine(uint8_t* const line_start, const int width, const int left,
+                const int right) {
+  auto* const start = reinterpret_cast<Pixel*>(line_start);
+  const Pixel* src = start;
+  Pixel* dst = start - left;
+  // Copy to left and right borders.
+  Memset(dst, src[0], left);
+  Memset(dst + (left + width), src[width - 1], right);
+}
+
+}  // namespace
+
+template <bool in_place>
+void PostFilter::ApplySuperRes(const std::array<uint8_t*, kMaxPlanes>& buffers,
+                               const std::array<int, kMaxPlanes>& strides,
+                               const std::array<int, kMaxPlanes>& rows,
+                               size_t line_buffer_offset) {
+  // Only used when |in_place| == false.
+  uint8_t* const line_buffer_start = superres_line_buffer_ +
+                                     line_buffer_offset +
+                                     kSuperResHorizontalBorder * pixel_size_;
+  for (int plane = kPlaneY; plane < planes_; ++plane) {
+    const int8_t subsampling_x = subsampling_x_[plane];
+    const int plane_width =
+        MultiplyBy4(frame_header_.columns4x4) >> subsampling_x;
+    uint8_t* input = buffers[plane];
+    const uint32_t input_stride = strides[plane];
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    if (bitdepth_ >= 10) {
+      for (int y = 0; y < rows[plane]; ++y, input += input_stride) {
+        if (!in_place) {
+          memcpy(line_buffer_start, input, plane_width * sizeof(uint16_t));
+        }
+        ExtendLine<uint16_t>(in_place ? input : line_buffer_start, plane_width,
+                             kSuperResHorizontalBorder,
+                             kSuperResHorizontalBorder);
+        dsp_.super_res_row(in_place ? input : line_buffer_start,
+                           super_res_info_[plane].upscaled_width,
+                           super_res_info_[plane].initial_subpixel_x,
+                           super_res_info_[plane].step,
+                           input - (in_place ? input_stride : 0));
+      }
+      continue;
+    }
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+    for (int y = 0; y < rows[plane]; ++y, input += input_stride) {
+      if (!in_place) {
+        memcpy(line_buffer_start, input, plane_width);
+      }
+      ExtendLine<uint8_t>(in_place ? input : line_buffer_start, plane_width,
+                          kSuperResHorizontalBorder, kSuperResHorizontalBorder);
+      dsp_.super_res_row(in_place ? input : line_buffer_start,
+                         super_res_info_[plane].upscaled_width,
+                         super_res_info_[plane].initial_subpixel_x,
+                         super_res_info_[plane].step,
+                         input - (in_place ? input_stride : 0));
+    }
+  }
+}
+
+// Used by post_filter_test.cc.
+template void PostFilter::ApplySuperRes<false>(
+    const std::array<uint8_t*, kMaxPlanes>& buffers,
+    const std::array<int, kMaxPlanes>& strides,
+    const std::array<int, kMaxPlanes>& rows, size_t line_buffer_offset);
+
+void PostFilter::ApplySuperResForOneSuperBlockRow(int row4x4_start, int sb4x4,
+                                                  bool is_last_row) {
+  assert(row4x4_start >= 0);
+  assert(DoSuperRes());
+  // If not doing cdef, then LR needs two rows of border with superres applied.
+  const int num_rows_extra = (DoCdef() || !DoRestoration()) ? 0 : 2;
+  std::array<uint8_t*, kMaxPlanes> buffers;
+  std::array<int, kMaxPlanes> strides;
+  std::array<int, kMaxPlanes> rows;
+  // Apply superres for the last 8-num_rows_extra rows of the previous
+  // superblock.
+  if (row4x4_start > 0) {
+    const int row4x4 = row4x4_start - 2;
+    for (int plane = 0; plane < planes_; ++plane) {
+      const int row =
+          (MultiplyBy4(row4x4) >> subsampling_y_[plane]) + num_rows_extra;
+      const ptrdiff_t row_offset = row * frame_buffer_.stride(plane);
+      buffers[plane] = cdef_buffer_[plane] + row_offset;
+      strides[plane] = frame_buffer_.stride(plane);
+      // Note that the |num_rows_extra| subtraction is done after the value is
+      // subsampled since we always need to work on |num_rows_extra| extra rows
+      // irrespective of the plane subsampling.
+      rows[plane] = (8 >> subsampling_y_[plane]) - num_rows_extra;
+    }
+    ApplySuperRes<true>(buffers, strides, rows, /*line_buffer_offset=*/0);
+  }
+  // Apply superres for the current superblock row (except for the last
+  // 8-num_rows_extra rows).
+  const int num_rows4x4 =
+      std::min(sb4x4, frame_header_.rows4x4 - row4x4_start) -
+      (is_last_row ? 0 : 2);
+  for (int plane = 0; plane < planes_; ++plane) {
+    const ptrdiff_t row_offset =
+        (MultiplyBy4(row4x4_start) >> subsampling_y_[plane]) *
+        frame_buffer_.stride(plane);
+    buffers[plane] = cdef_buffer_[plane] + row_offset;
+    strides[plane] = frame_buffer_.stride(plane);
+    // Note that the |num_rows_extra| subtraction is done after the value is
+    // subsampled since we always need to work on |num_rows_extra| extra rows
+    // irrespective of the plane subsampling.
+    rows[plane] = (MultiplyBy4(num_rows4x4) >> subsampling_y_[plane]) +
+                  (is_last_row ? 0 : num_rows_extra);
+  }
+  ApplySuperRes<true>(buffers, strides, rows, /*line_buffer_offset=*/0);
+}
+
+void PostFilter::ApplySuperResThreaded() {
+  const int num_threads = thread_pool_->num_threads() + 1;
+  // The number of rows4x4 that will be processed by each thread in the thread
+  // pool (other than the current thread).
+  const int thread_pool_rows4x4 = frame_header_.rows4x4 / num_threads;
+  // For the current thread, we round up to process all the remaining rows so
+  // that the current thread's job will potentially run the longest.
+  const int current_thread_rows4x4 =
+      frame_header_.rows4x4 - (thread_pool_rows4x4 * (num_threads - 1));
+  // The size of the line buffer required by each thread. In the multi-threaded
+  // case we are guaranteed to have a line buffer which can store |num_threads|
+  // rows at the same time.
+  const size_t line_buffer_size =
+      (MultiplyBy4(frame_header_.columns4x4) +
+       MultiplyBy2(kSuperResHorizontalBorder) + kSuperResHorizontalPadding) *
+      pixel_size_;
+  size_t line_buffer_offset = 0;
+  BlockingCounter pending_workers(num_threads - 1);
+  for (int i = 0, row4x4_start = 0; i < num_threads; ++i,
+           row4x4_start += thread_pool_rows4x4,
+           line_buffer_offset += line_buffer_size) {
+    std::array<uint8_t*, kMaxPlanes> buffers;
+    std::array<int, kMaxPlanes> strides;
+    std::array<int, kMaxPlanes> rows;
+    for (int plane = 0; plane < planes_; ++plane) {
+      strides[plane] = frame_buffer_.stride(plane);
+      buffers[plane] =
+          GetBufferOffset(cdef_buffer_[plane], strides[plane],
+                          static_cast<Plane>(plane), row4x4_start, 0);
+      if (i < num_threads - 1) {
+        rows[plane] = MultiplyBy4(thread_pool_rows4x4) >> subsampling_y_[plane];
+      } else {
+        rows[plane] =
+            MultiplyBy4(current_thread_rows4x4) >> subsampling_y_[plane];
+      }
+    }
+    if (i < num_threads - 1) {
+      thread_pool_->Schedule([this, buffers, strides, rows, line_buffer_offset,
+                              &pending_workers]() {
+        ApplySuperRes<false>(buffers, strides, rows, line_buffer_offset);
+        pending_workers.Decrement();
+      });
+    } else {
+      ApplySuperRes<false>(buffers, strides, rows, line_buffer_offset);
+    }
+  }
+  // Wait for the threadpool jobs to finish.
+  pending_workers.Wait();
+}
+
+// This function lives in this file so that it has access to ExtendLine<>.
+void PostFilter::SetupDeblockBuffer(int row4x4_start, int sb4x4) {
+  assert(row4x4_start >= 0);
+  assert(DoCdef());
+  assert(DoRestoration());
+  for (int sb_y = 0; sb_y < sb4x4; sb_y += 16) {
+    const int row4x4 = row4x4_start + sb_y;
+    for (int plane = 0; plane < planes_; ++plane) {
+      CopyDeblockedPixels(static_cast<Plane>(plane), row4x4);
+    }
+    const int row_offset_start = DivideBy4(row4x4);
+    if (DoSuperRes()) {
+      std::array<uint8_t*, kMaxPlanes> buffers = {
+          deblock_buffer_.data(kPlaneY) +
+              row_offset_start * deblock_buffer_.stride(kPlaneY),
+          deblock_buffer_.data(kPlaneU) +
+              row_offset_start * deblock_buffer_.stride(kPlaneU),
+          deblock_buffer_.data(kPlaneV) +
+              row_offset_start * deblock_buffer_.stride(kPlaneV)};
+      std::array<int, kMaxPlanes> strides = {deblock_buffer_.stride(kPlaneY),
+                                             deblock_buffer_.stride(kPlaneU),
+                                             deblock_buffer_.stride(kPlaneV)};
+      std::array<int, kMaxPlanes> rows = {4, 4, 4};
+      ApplySuperRes<false>(buffers, strides, rows,
+                           /*line_buffer_offset=*/0);
+    }
+    // Extend the left and right boundaries needed for loop restoration.
+    for (int plane = 0; plane < planes_; ++plane) {
+      uint8_t* src = deblock_buffer_.data(plane) +
+                     row_offset_start * deblock_buffer_.stride(plane);
+      const int plane_width =
+          RightShiftWithRounding(upscaled_width_, subsampling_x_[plane]);
+      for (int i = 0; i < 4; ++i) {
+#if LIBGAV1_MAX_BITDEPTH >= 10
+        if (bitdepth_ >= 10) {
+          ExtendLine<uint16_t>(src, plane_width, kRestorationHorizontalBorder,
+                               kRestorationHorizontalBorder);
+        } else  // NOLINT.
+#endif
+        {
+          ExtendLine<uint8_t>(src, plane_width, kRestorationHorizontalBorder,
+                              kRestorationHorizontalBorder);
+        }
+        src += deblock_buffer_.stride(plane);
+      }
+    }
+  }
+}
+
+}  // namespace libgav1
diff --git a/libgav1/src/prediction_mask.cc b/libgav1/src/prediction_mask.cc
index 1ca2d6e..ab4d849 100644
--- a/libgav1/src/prediction_mask.cc
+++ b/libgav1/src/prediction_mask.cc
@@ -15,6 +15,7 @@
 #include "src/prediction_mask.h"
 
 #include <algorithm>
+#include <array>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -23,6 +24,7 @@
 #include <memory>
 
 #include "src/utils/array_2d.h"
+#include "src/utils/bit_mask_set.h"
 #include "src/utils/common.h"
 #include "src/utils/constants.h"
 #include "src/utils/logging.h"
@@ -91,28 +93,34 @@
                                                {kWedgeOblique117, 2, 4},
                                                {kWedgeOblique117, 6, 4}}};
 
-constexpr uint8_t kWedgeFlipSignLookup[9][16] = {
-    {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock8x8
-    {1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock8x16
-    {1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1},  // kBlock8x32
-    {1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock16x8
-    {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock16x16
-    {1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock16x32
-    {1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1},  // kBlock32x8
-    {1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock32x16
-    {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1},  // kBlock32x32
+constexpr BitMaskSet kWedgeFlipSignMasks[9] = {
+    BitMaskSet(0xBBFF),  // kBlock8x8
+    BitMaskSet(0xBBEF),  // kBlock8x16
+    BitMaskSet(0xBAEF),  // kBlock8x32
+    BitMaskSet(0xBBEF),  // kBlock16x8
+    BitMaskSet(0xBBFF),  // kBlock16x16
+    BitMaskSet(0xBBEF),  // kBlock16x32
+    BitMaskSet(0xABEF),  // kBlock32x8
+    BitMaskSet(0xBBEF),  // kBlock32x16
+    BitMaskSet(0xBBFF)   // kBlock32x32
 };
 
-constexpr uint8_t kWedgeMasterObliqueOdd[kWedgeMaskMasterSize] = {
+// This table (and the one below) contains a few leading zeros and trailing 64s
+// to avoid some additional memcpys where it is actually used.
+constexpr uint8_t kWedgeMasterObliqueOdd[kWedgeMaskMasterSize * 3 / 2] = {
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
-    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  2,  6,  18,
-    37, 53, 60, 63, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
+    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  2,  6,  18, 37,
+    53, 60, 63, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64};
 
-constexpr uint8_t kWedgeMasterObliqueEven[kWedgeMaskMasterSize] = {
+constexpr uint8_t kWedgeMasterObliqueEven[kWedgeMaskMasterSize * 3 / 2] = {
+    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  4,  11, 27,
     46, 58, 62, 63, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
+    64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64};
 
 constexpr uint8_t kWedgeMasterVertical[kWedgeMaskMasterSize] = {
@@ -121,15 +129,6 @@
     43, 57, 62, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64,
     64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64};
 
-constexpr uint8_t kInterIntraWeights[kMaxSuperBlockSizeInPixels] = {
-    60, 58, 56, 54, 52, 50, 48, 47, 45, 44, 42, 41, 39, 38, 37, 35, 34, 33, 32,
-    31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 22, 21, 20, 19, 19, 18, 18, 17, 16,
-    16, 15, 15, 14, 14, 13, 13, 12, 12, 12, 11, 11, 10, 10, 10, 9,  9,  9,  8,
-    8,  8,  8,  7,  7,  7,  7,  6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  4,  4,
-    4,  4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  3,  3,  3,  3,  2,  2,  2,  2,
-    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  1,  1,  1,
-    1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1};
-
 int BlockShape(BlockSize block_size) {
   const int width = kNum4x4BlocksWide[block_size];
   const int height = kNum4x4BlocksHigh[block_size];
@@ -152,76 +151,35 @@
 
 }  // namespace
 
-void GenerateWedgeMask(uint8_t* const wedge_master_mask_data,
-                       uint8_t* const wedge_masks_data) {
+bool GenerateWedgeMask(WedgeMaskArray* const wedge_masks) {
   // Generate master masks.
-  Array2DView<uint8_t> master_mask(6, kMaxMaskBlockSize,
-                                   wedge_master_mask_data);
-  uint8_t* master_mask_row = master_mask[kWedgeVertical];
-  const int stride = kWedgeMaskMasterSize;
+  uint8_t master_mask[6][kWedgeMaskMasterSize][kWedgeMaskMasterSize];
   for (int y = 0; y < kWedgeMaskMasterSize; ++y) {
-    memcpy(master_mask_row, kWedgeMasterVertical, kWedgeMaskMasterSize);
-    master_mask_row += stride;
+    memcpy(master_mask[kWedgeVertical][y], kWedgeMasterVertical,
+           kWedgeMaskMasterSize);
   }
 
-  AlignedUniquePtr<uint8_t> wedge_master_oblique_even =
-      MakeAlignedUniquePtr<uint8_t>(
-          16, kWedgeMaskMasterSize + DivideBy2(kWedgeMaskMasterSize));
-  AlignedUniquePtr<uint8_t> wedge_master_oblique_odd =
-      MakeAlignedUniquePtr<uint8_t>(
-          16, kWedgeMaskMasterSize + DivideBy2(kWedgeMaskMasterSize));
-  if (wedge_master_oblique_even == nullptr ||
-      wedge_master_oblique_odd == nullptr) {
-    LIBGAV1_DLOG(ERROR, "Failed to allocate memory for master mask.");
-    return;
-  }
-  const int offset = DivideBy4(kWedgeMaskMasterSize);
-  memset(wedge_master_oblique_even.get(), 0, offset);
-  memcpy(wedge_master_oblique_even.get() + offset, kWedgeMasterObliqueEven,
-         kWedgeMaskMasterSize);
-  memset(wedge_master_oblique_even.get() + offset + kWedgeMaskMasterSize, 64,
-         offset);
-  memset(wedge_master_oblique_odd.get(), 0, offset - 1);
-  memcpy(wedge_master_oblique_odd.get() + offset - 1, kWedgeMasterObliqueOdd,
-         kWedgeMaskMasterSize);
-  memset(wedge_master_oblique_odd.get() + offset + kWedgeMaskMasterSize - 1, 64,
-         offset + 1);
-  master_mask_row = master_mask[kWedgeOblique63];
   for (int y = 0, shift = 0; y < kWedgeMaskMasterSize; y += 2, ++shift) {
-    memcpy(master_mask_row, wedge_master_oblique_even.get() + shift,
+    memcpy(master_mask[kWedgeOblique63][y], kWedgeMasterObliqueEven + shift,
            kWedgeMaskMasterSize);
-    master_mask_row += stride;
-    memcpy(master_mask_row, wedge_master_oblique_odd.get() + shift,
+    memcpy(master_mask[kWedgeOblique63][y + 1], kWedgeMasterObliqueOdd + shift,
            kWedgeMaskMasterSize);
-    master_mask_row += stride;
   }
 
-  uint8_t* const master_mask_horizontal = master_mask[kWedgeHorizontal];
-  uint8_t* master_mask_vertical = master_mask[kWedgeVertical];
-  uint8_t* const master_mask_oblique_27 = master_mask[kWedgeOblique27];
-  uint8_t* master_mask_oblique_63 = master_mask[kWedgeOblique63];
-  uint8_t* master_mask_oblique_117 = master_mask[kWedgeOblique117];
-  uint8_t* const master_mask_oblique_153 = master_mask[kWedgeOblique153];
   for (int y = 0; y < kWedgeMaskMasterSize; ++y) {
     for (int x = 0; x < kWedgeMaskMasterSize; ++x) {
-      const uint8_t mask_value = master_mask_oblique_63[x];
-      master_mask_horizontal[x * stride + y] = master_mask_vertical[x];
-      master_mask_oblique_27[x * stride + y] = mask_value;
-      master_mask_oblique_117[kWedgeMaskMasterSize - 1 - x] = 64 - mask_value;
-      master_mask_oblique_153[(kWedgeMaskMasterSize - 1 - x) * stride + y] =
+      const uint8_t mask_value = master_mask[kWedgeOblique63][y][x];
+      master_mask[kWedgeHorizontal][x][y] = master_mask[kWedgeVertical][y][x];
+      master_mask[kWedgeOblique27][x][y] = mask_value;
+      master_mask[kWedgeOblique117][y][kWedgeMaskMasterSize - 1 - x] =
+          64 - mask_value;
+      master_mask[kWedgeOblique153][(kWedgeMaskMasterSize - 1 - x)][y] =
           64 - mask_value;
     }
-    master_mask_vertical += stride;
-    master_mask_oblique_63 += stride;
-    master_mask_oblique_117 += stride;
   }
 
   // Generate wedge masks.
-  const int wedge_mask_stride_1 = kMaxMaskBlockSize;
-  const int wedge_mask_stride_2 = wedge_mask_stride_1 * 16;
-  const int wedge_mask_stride_3 = wedge_mask_stride_2 * 2;
   int block_size_index = 0;
-  int wedge_masks_offset = 0;
   for (int size = kBlock8x8; size <= kBlock32x32; ++size) {
     if (!kIsWedgeCompoundModeAllowed.Contains(size)) continue;
 
@@ -233,116 +191,46 @@
     assert(height <= 32);
 
     const auto block_size = static_cast<BlockSize>(size);
-    for (int index = 0; index < kWedgeDirectionTypes; ++index) {
-      const uint8_t direction = GetWedgeDirection(block_size, index);
+    for (int wedge_index = 0; wedge_index < kWedgeDirectionTypes;
+         ++wedge_index) {
+      const uint8_t direction = GetWedgeDirection(block_size, wedge_index);
       const uint8_t offset_x =
           DivideBy2(kWedgeMaskMasterSize) -
-          ((GetWedgeOffsetX(block_size, index) * width) >> 3);
+          ((GetWedgeOffsetX(block_size, wedge_index) * width) >> 3);
       const uint8_t offset_y =
           DivideBy2(kWedgeMaskMasterSize) -
-          ((GetWedgeOffsetY(block_size, index) * height) >> 3);
-      const uint8_t flip_sign = kWedgeFlipSignLookup[block_size_index][index];
+          ((GetWedgeOffsetY(block_size, wedge_index) * height) >> 3);
 
-      const int offset_1 = block_size_index * wedge_mask_stride_3 +
-                           flip_sign * wedge_mask_stride_2 +
-                           index * wedge_mask_stride_1;
-      const int offset_2 = block_size_index * wedge_mask_stride_3 +
-                           (1 - flip_sign) * wedge_mask_stride_2 +
-                           index * wedge_mask_stride_1;
+      // Allocate the 2d array.
+      for (int flip_sign = 0; flip_sign < 2; ++flip_sign) {
+        if (!((*wedge_masks)[block_size_index][flip_sign][wedge_index].Reset(
+                height, width, /*zero_initialize=*/false))) {
+          LIBGAV1_DLOG(ERROR, "Failed to allocate memory for wedge masks.");
+          return false;
+        }
+      }
 
-      uint8_t* wedge_masks_row = wedge_masks_data + offset_1;
-      uint8_t* wedge_masks_row_flip = wedge_masks_data + offset_2;
-      master_mask_row = &master_mask[direction][offset_y * stride + offset_x];
+      const auto flip_sign = static_cast<uint8_t>(
+          kWedgeFlipSignMasks[block_size_index].Contains(wedge_index));
+      uint8_t* wedge_masks_row =
+          (*wedge_masks)[block_size_index][flip_sign][wedge_index][0];
+      uint8_t* wedge_masks_row_flip =
+          (*wedge_masks)[block_size_index][1 - flip_sign][wedge_index][0];
+      uint8_t* master_mask_row = &master_mask[direction][offset_y][offset_x];
       for (int y = 0; y < height; ++y) {
         memcpy(wedge_masks_row, master_mask_row, width);
         for (int x = 0; x < width; ++x) {
-          // TODO(chengchen): sign flip may not be needed.
-          // Only need to return 64 - mask_value, when get mask.
           wedge_masks_row_flip[x] = 64 - wedge_masks_row[x];
         }
-        wedge_masks_row += stride;
-        wedge_masks_row_flip += stride;
-        master_mask_row += stride;
+        wedge_masks_row += width;
+        wedge_masks_row_flip += width;
+        master_mask_row += kWedgeMaskMasterSize;
       }
-      wedge_masks_offset += width * height;
     }
 
     block_size_index++;
   }
-}
-
-void GenerateWeightMask(const uint16_t* prediction_1, const ptrdiff_t stride_1,
-                        const uint16_t* prediction_2, const ptrdiff_t stride_2,
-                        const bool mask_is_inverse, const int width,
-                        const int height, const int bitdepth, uint8_t* mask,
-                        const ptrdiff_t mask_stride) {
-#if LIBGAV1_MAX_BITDEPTH == 12
-  const int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
-#else
-  constexpr int inter_post_round_bits = 4;
-#endif
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
-      const int rounding_bits = bitdepth - 8 + inter_post_round_bits;
-      const int difference = RightShiftWithRounding(
-          std::abs(prediction_1[x] - prediction_2[x]), rounding_bits);
-      const auto mask_value =
-          static_cast<uint8_t>(std::min(DivideBy16(difference) + 38, 64));
-      mask[x] = mask_is_inverse ? 64 - mask_value : mask_value;
-    }
-    prediction_1 += stride_1;
-    prediction_2 += stride_2;
-    mask += mask_stride;
-  }
-}
-
-void GenerateInterIntraMask(const int mode, const int width, const int height,
-                            uint8_t* const mask, const ptrdiff_t mask_stride) {
-  const int scale = kMaxSuperBlockSizeInPixels / std::max(width, height);
-  uint8_t* mask_row = mask;
-  const uint8_t* inter_intra_weight = kInterIntraWeights;
-  // TODO(chengchen): Reorder mode types if we have stats that which modes are
-  // used often.
-  if (mode == kInterIntraModeVertical) {
-    for (int y = 0; y < height; ++y) {
-      memset(mask_row, *inter_intra_weight, width);
-      mask_row += mask_stride;
-      inter_intra_weight += scale;
-    }
-  } else if (mode == kInterIntraModeHorizontal) {
-    for (int x = 0; x < width; ++x) {
-      mask_row[x] = *inter_intra_weight;
-      inter_intra_weight += scale;
-    }
-    mask_row += mask_stride;
-    for (int y = 1; y < height; ++y) {
-      memcpy(mask_row, mask, width);
-      mask_row += mask_stride;
-    }
-  } else if (mode == kInterIntraModeSmooth) {
-    uint8_t weight_row[64];
-    const int size = std::min(width, height);
-    for (int x = 0; x < width; ++x) {
-      weight_row[x] = *inter_intra_weight;
-      if (x < size) inter_intra_weight += scale;
-    }
-    int y;
-    for (y = 0; y < std::min(width, height); ++y) {
-      memcpy(mask_row, weight_row, y);
-      memset(mask_row + y, weight_row[y], width - y);
-      mask_row += mask_stride;
-    }
-    for (; y < height; ++y) {
-      memcpy(mask_row, weight_row, width);
-      mask_row += mask_stride;
-    }
-  } else {
-    assert(mode == kInterIntraModeDc);
-    for (int y = 0; y < height; ++y) {
-      memset(mask_row, 32, width);
-      mask_row += mask_stride;
-    }
-  }
+  return true;
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/prediction_mask.h b/libgav1/src/prediction_mask.h
index d324e00..0134a0d 100644
--- a/libgav1/src/prediction_mask.h
+++ b/libgav1/src/prediction_mask.h
@@ -33,20 +33,9 @@
 
 // This function generates wedge masks. It should be called only once for the
 // decoder. If the video is key frame only, we don't have to call this
-// function.
+// function. Returns true on success, false on allocation failure.
 // 7.11.3.11.
-void GenerateWedgeMask(uint8_t* wedge_master_mask_data,
-                       uint8_t* wedge_masks_data);
-
-// 7.11.3.12.
-void GenerateWeightMask(const uint16_t* prediction_1, ptrdiff_t stride_1,
-                        const uint16_t* prediction_2, ptrdiff_t stride_2,
-                        bool mask_is_inverse, int width, int height,
-                        int bitdepth, uint8_t* mask, ptrdiff_t mask_stride);
-
-// 7.11.3.13.
-void GenerateInterIntraMask(int mode, int width, int height, uint8_t* mask,
-                            ptrdiff_t mask_stride);
+bool GenerateWedgeMask(WedgeMaskArray* wedge_masks);
 
 }  // namespace libgav1
 #endif  // LIBGAV1_SRC_PREDICTION_MASK_H_
diff --git a/libgav1/src/quantizer.cc b/libgav1/src/quantizer.cc
index 68a3ad4..b26024d 100644
--- a/libgav1/src/quantizer.cc
+++ b/libgav1/src/quantizer.cc
@@ -141,9 +141,6 @@
 };
 // clang-format on
 
-// Converts bitdepth 8, 10, and 12 to array index 0, 1, and 2, respectively.
-inline int BitdepthToArrayIndex(int bitdepth) { return (bitdepth - 8) >> 1; }
-
 }  // namespace
 
 int GetQIndex(const Segmentation& segmentation, int index, int base_qindex) {
diff --git a/libgav1/src/quantizer.h b/libgav1/src/quantizer.h
index ba98f81..e555115 100644
--- a/libgav1/src/quantizer.h
+++ b/libgav1/src/quantizer.h
@@ -21,25 +21,10 @@
 
 #include "src/utils/constants.h"
 #include "src/utils/segmentation.h"
+#include "src/utils/types.h"
 
 namespace libgav1 {
 
-// Stores the quantization parameters of Section 5.9.12.
-struct QuantizerParameters {
-  // base_index is in the range [0, 255].
-  uint8_t base_index;
-  int8_t delta_dc[kMaxPlanes];
-  // delta_ac[kPlaneY] is always 0.
-  int8_t delta_ac[kMaxPlanes];
-  bool use_matrix;
-  // The |matrix_level| array is used only when |use_matrix| is true.
-  // matrix_level[plane] specifies the level in the quantizer matrix that
-  // should be used for decoding |plane|. The quantizer matrix has 15 levels,
-  // from 0 to 14. The range of matrix_level[plane] is [0, 15]. If
-  // matrix_level[plane] is 15, the quantizer matrix is not used.
-  int8_t matrix_level[kMaxPlanes];
-};
-
 // Implements the dequantization functions of Section 7.12.2.
 class Quantizer {
  public:
diff --git a/libgav1/src/quantizer_tables.inc b/libgav1/src/quantizer_tables.inc
index 9054652..b5a89a8 100644
--- a/libgav1/src/quantizer_tables.inc
+++ b/libgav1/src/quantizer_tables.inc
@@ -16,14 +16,14 @@
 // definitions from the quantizer functions.
 
 // Quantizer matrix is used only when level < 15.
-const int kNumQuantizerLevelsForQuantizerMatrix = 15;
-const int kQuantizerMatrixSize = 3344;
+constexpr int kNumQuantizerLevelsForQuantizerMatrix = 15;
+constexpr int kQuantizerMatrixSize = 3344;
 
-const uint16_t kQuantizerMatrixOffset[kNumTransformSizes] = {
+constexpr uint16_t kQuantizerMatrixOffset[kNumTransformSizes] = {
     0,    1360, 2704, 1392, 16,  1424, 2832, 2768, 1552, 80,
     1680, 1680, 3088, 2192, 336, 336,  2192, 336,  336};
 
-const uint8_t kQuantizerMatrix
+constexpr uint8_t kQuantizerMatrix
     [kNumQuantizerLevelsForQuantizerMatrix][kNumPlaneTypes]
     [kQuantizerMatrixSize] = {
         // Quantizer level 0.
@@ -43,20 +43,20 @@
              31, 32, 32, 33, 34, 41, 44, 54, 59, 72, 75, 83, 90, 97, 104, 112,
              31, 32, 33, 35, 36, 42, 45, 54, 59, 71, 74, 81, 86, 93, 100, 107,
              34, 33, 35, 39, 42, 47, 51, 58, 63, 74, 76, 81, 84, 90, 97, 105,
-             36, 34, 36, 42, 48, 54, 57, 64, 68, 79, 81, 88, 91, 96, 102,
-             105, 44, 41, 42, 47, 54, 63, 67, 75, 79, 90, 92, 95, 100, 102, 109,
-             112, 48, 44, 45, 51, 57, 67, 71, 80, 85, 96, 99, 107, 108, 111,
-             117, 120, 59, 54, 54, 58, 64, 75, 80, 92, 98, 110, 113, 115, 116,
-             122, 125, 130, 65, 59, 59, 63, 68, 79, 85, 98, 105, 118, 121, 127,
-             130, 134, 135, 140, 80, 72, 71, 74, 79, 90, 96, 110, 118, 134, 137,
-             140, 143, 144, 146, 152, 83, 75, 74, 76, 81, 92, 99, 113, 121, 137,
-             140, 151, 152, 155, 158, 165, 91, 83, 81, 81, 88, 95, 107, 115,
-             127, 140, 151, 159, 166, 169, 173, 179, 97, 90, 86, 84, 91, 100,
-             108, 116, 130, 143, 152, 166, 174, 182, 189, 193, 104, 97, 93, 90,
-             96, 102, 111, 122, 134, 144, 155, 169, 182, 191, 200, 210, 111,
-             104, 100, 97, 102, 109, 117, 125, 135, 146, 158, 173, 189, 200,
-             210, 220, 119, 112, 107, 105, 105, 112, 120, 130, 140, 152, 165,
-             179, 193, 210, 220, 231,
+             36, 34, 36, 42, 48, 54, 57, 64, 68, 79, 81, 88, 91, 96, 102, 105,
+             44, 41, 42, 47, 54, 63, 67, 75, 79, 90, 92, 95, 100, 102, 109, 112,
+             48, 44, 45, 51, 57, 67, 71, 80, 85, 96, 99, 107, 108, 111, 117,
+             120, 59, 54, 54, 58, 64, 75, 80, 92, 98, 110, 113, 115, 116, 122,
+             125, 130, 65, 59, 59, 63, 68, 79, 85, 98, 105, 118, 121, 127, 130,
+             134, 135, 140, 80, 72, 71, 74, 79, 90, 96, 110, 118, 134, 137, 140,
+             143, 144, 146, 152, 83, 75, 74, 76, 81, 92, 99, 113, 121, 137, 140,
+             151, 152, 155, 158, 165, 91, 83, 81, 81, 88, 95, 107, 115, 127,
+             140, 151, 159, 166, 169, 173, 179, 97, 90, 86, 84, 91, 100, 108,
+             116, 130, 143, 152, 166, 174, 182, 189, 193, 104, 97, 93, 90, 96,
+             102, 111, 122, 134, 144, 155, 169, 182, 191, 200, 210, 111, 104,
+             100, 97, 102, 109, 117, 125, 135, 146, 158, 173, 189, 200, 210,
+             220, 119, 112, 107, 105, 105, 112, 120, 130, 140, 152, 165, 179,
+             193, 210, 220, 231,
              // Size 32x32
              32, 31, 31, 31, 31, 32, 34, 35, 36, 39, 44, 46, 48, 54, 59, 62, 65,
              71, 80, 81, 83, 88, 91, 94, 97, 101, 104, 107, 111, 115, 119, 123,
diff --git a/libgav1/src/reconstruction.cc b/libgav1/src/reconstruction.cc
index 7772878..97de9f0 100644
--- a/libgav1/src/reconstruction.cc
+++ b/libgav1/src/reconstruction.cc
@@ -54,7 +54,7 @@
 void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
                  TransformSize tx_size, bool lossless, Residual* const buffer,
                  int start_x, int start_y, Array2DView<Pixel>* frame,
-                 int16_t non_zero_coeff_count) {
+                 int non_zero_coeff_count) {
   static_assert(sizeof(Residual) == 2 || sizeof(Residual) == 4, "");
   const int tx_width_log2 = kTransformWidthLog2[tx_size];
   const int tx_height_log2 = kTransformHeightLog2[tx_size];
@@ -87,13 +87,13 @@
 template void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
                           TransformSize tx_size, bool lossless, int16_t* buffer,
                           int start_x, int start_y, Array2DView<uint8_t>* frame,
-                          int16_t non_zero_coeff_count);
+                          int non_zero_coeff_count);
 #if LIBGAV1_MAX_BITDEPTH >= 10
 template void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
                           TransformSize tx_size, bool lossless, int32_t* buffer,
                           int start_x, int start_y,
                           Array2DView<uint16_t>* frame,
-                          int16_t non_zero_coeff_count);
+                          int non_zero_coeff_count);
 #endif
 
 }  // namespace libgav1
diff --git a/libgav1/src/reconstruction.h b/libgav1/src/reconstruction.h
index 3ef0381..6d5b115 100644
--- a/libgav1/src/reconstruction.h
+++ b/libgav1/src/reconstruction.h
@@ -35,19 +35,19 @@
 void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
                  TransformSize tx_size, bool lossless, Residual* buffer,
                  int start_x, int start_y, Array2DView<Pixel>* frame,
-                 int16_t non_zero_coeff_count);
+                 int non_zero_coeff_count);
 
 extern template void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
                                  TransformSize tx_size, bool lossless,
                                  int16_t* buffer, int start_x, int start_y,
                                  Array2DView<uint8_t>* frame,
-                                 int16_t non_zero_coeff_count);
+                                 int non_zero_coeff_count);
 #if LIBGAV1_MAX_BITDEPTH >= 10
 extern template void Reconstruct(const dsp::Dsp& dsp, TransformType tx_type,
                                  TransformSize tx_size, bool lossless,
                                  int32_t* buffer, int start_x, int start_y,
                                  Array2DView<uint16_t>* frame,
-                                 int16_t non_zero_coeff_count);
+                                 int non_zero_coeff_count);
 #endif
 
 }  // namespace libgav1
diff --git a/libgav1/src/residual_buffer_pool.h b/libgav1/src/residual_buffer_pool.h
index 497579d..f7bc75d 100644
--- a/libgav1/src/residual_buffer_pool.h
+++ b/libgav1/src/residual_buffer_pool.h
@@ -52,7 +52,7 @@
   }
 
   // Adds the |non_zero_coeff_count| and the |tx_type| to the back of the queue.
-  void Push(int16_t non_zero_coeff_count, TransformType tx_type) {
+  void Push(int non_zero_coeff_count, TransformType tx_type) {
     assert(back_ < max_size_);
     non_zero_coeff_count_[back_] = non_zero_coeff_count;
     tx_type_[back_++] = tx_type;
diff --git a/libgav1/src/scan_tables.inc b/libgav1/src/scan_tables.inc
index a535feb..f7c9231 100644
--- a/libgav1/src/scan_tables.inc
+++ b/libgav1/src/scan_tables.inc
@@ -14,58 +14,58 @@
 
 // This file contains all the scan order tables.
 
-const uint16_t kDefaultScan4x4[16] = {0, 1,  4,  8,  5, 2,  3,  6,
-                                      9, 12, 13, 10, 7, 11, 14, 15};
+constexpr uint16_t kDefaultScan4x4[16] = {0, 1,  4,  8,  5, 2,  3,  6,
+                                          9, 12, 13, 10, 7, 11, 14, 15};
 
-const uint16_t kColumnScan4x4[16] = {0, 4, 8,  12, 1, 5, 9,  13,
-                                     2, 6, 10, 14, 3, 7, 11, 15};
+constexpr uint16_t kColumnScan4x4[16] = {0, 4, 8,  12, 1, 5, 9,  13,
+                                         2, 6, 10, 14, 3, 7, 11, 15};
 
-const uint16_t kRowScan4x4[16] = {0, 1, 2,  3,  4,  5,  6,  7,
-                                  8, 9, 10, 11, 12, 13, 14, 15};
+constexpr uint16_t kRowScan4x4[16] = {0, 1, 2,  3,  4,  5,  6,  7,
+                                      8, 9, 10, 11, 12, 13, 14, 15};
 
-const uint16_t kDefaultScan4x8[32] = {
+constexpr uint16_t kDefaultScan4x8[32] = {
     0,  1,  4,  2,  5,  8,  3,  6,  9,  12, 7,  10, 13, 16, 11, 14,
     17, 20, 15, 18, 21, 24, 19, 22, 25, 28, 23, 26, 29, 27, 30, 31};
 
-const uint16_t kColumnScan4x8[32] = {0,  4,  8,  12, 16, 20, 24, 28, 1,  5,  9,
-                                     13, 17, 21, 25, 29, 2,  6,  10, 14, 18, 22,
-                                     26, 30, 3,  7,  11, 15, 19, 23, 27, 31};
+constexpr uint16_t kColumnScan4x8[32] = {
+    0, 4, 8,  12, 16, 20, 24, 28, 1, 5, 9,  13, 17, 21, 25, 29,
+    2, 6, 10, 14, 18, 22, 26, 30, 3, 7, 11, 15, 19, 23, 27, 31};
 
-const uint16_t kRowScan4x8[32] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
-                                  11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
-                                  22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
+constexpr uint16_t kRowScan4x8[32] = {
+    0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+    16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
 
-const uint16_t kDefaultScan8x4[32] = {
+constexpr uint16_t kDefaultScan8x4[32] = {
     0,  8, 1,  16, 9,  2, 24, 17, 10, 3, 25, 18, 11, 4,  26, 19,
     12, 5, 27, 20, 13, 6, 28, 21, 14, 7, 29, 22, 15, 30, 23, 31};
 
-const uint16_t kColumnScan8x4[32] = {0,  8,  16, 24, 1,  9,  17, 25, 2,  10, 18,
-                                     26, 3,  11, 19, 27, 4,  12, 20, 28, 5,  13,
-                                     21, 29, 6,  14, 22, 30, 7,  15, 23, 31};
+constexpr uint16_t kColumnScan8x4[32] = {
+    0, 8,  16, 24, 1, 9,  17, 25, 2, 10, 18, 26, 3, 11, 19, 27,
+    4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31};
 
-const uint16_t kRowScan8x4[32] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
-                                  11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
-                                  22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
+constexpr uint16_t kRowScan8x4[32] = {
+    0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+    16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
 
-const uint16_t kDefaultScan8x8[64] = {
+constexpr uint16_t kDefaultScan8x8[64] = {
     0,  1,  8,  16, 9,  2,  3,  10, 17, 24, 32, 25, 18, 11, 4,  5,
     12, 19, 26, 33, 40, 48, 41, 34, 27, 20, 13, 6,  7,  14, 21, 28,
     35, 42, 49, 56, 57, 50, 43, 36, 29, 22, 15, 23, 30, 37, 44, 51,
     58, 59, 52, 45, 38, 31, 39, 46, 53, 60, 61, 54, 47, 55, 62, 63};
 
-const uint16_t kColumnScan8x8[64] = {
+constexpr uint16_t kColumnScan8x8[64] = {
     0, 8,  16, 24, 32, 40, 48, 56, 1, 9,  17, 25, 33, 41, 49, 57,
     2, 10, 18, 26, 34, 42, 50, 58, 3, 11, 19, 27, 35, 43, 51, 59,
     4, 12, 20, 28, 36, 44, 52, 60, 5, 13, 21, 29, 37, 45, 53, 61,
     6, 14, 22, 30, 38, 46, 54, 62, 7, 15, 23, 31, 39, 47, 55, 63};
 
-const uint16_t kRowScan8x8[64] = {
+constexpr uint16_t kRowScan8x8[64] = {
     0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
     16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
     32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
     48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
 
-const uint16_t kDefaultScan8x16[128] = {
+constexpr uint16_t kDefaultScan8x16[128] = {
     0,   1,   8,   2,   9,   16,  3,   10,  17,  24,  4,   11,  18,  25,  32,
     5,   12,  19,  26,  33,  40,  6,   13,  20,  27,  34,  41,  48,  7,   14,
     21,  28,  35,  42,  49,  56,  15,  22,  29,  36,  43,  50,  57,  64,  23,
@@ -76,7 +76,7 @@
     114, 121, 87,  94,  101, 108, 115, 122, 95,  102, 109, 116, 123, 103, 110,
     117, 124, 111, 118, 125, 119, 126, 127};
 
-const uint16_t kColumnScan8x16[128] = {
+constexpr uint16_t kColumnScan8x16[128] = {
     0, 8,  16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96,  104, 112, 120,
     1, 9,  17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97,  105, 113, 121,
     2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98,  106, 114, 122,
@@ -86,7 +86,7 @@
     6, 14, 22, 30, 38, 46, 54, 62, 70, 78, 86, 94, 102, 110, 118, 126,
     7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127};
 
-const uint16_t kRowScan8x16[128] = {
+constexpr uint16_t kRowScan8x16[128] = {
     0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,
     15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,
     30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
@@ -97,7 +97,7 @@
     105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
     120, 121, 122, 123, 124, 125, 126, 127};
 
-const uint16_t kDefaultScan16x8[128] = {
+constexpr uint16_t kDefaultScan16x8[128] = {
     0,  16,  1,   32, 17,  2,   48,  33,  18, 3,  64,  49,  34,  19,  4,   80,
     65, 50,  35,  20, 5,   96,  81,  66,  51, 36, 21,  6,   112, 97,  82,  67,
     52, 37,  22,  7,  113, 98,  83,  68,  53, 38, 23,  8,   114, 99,  84,  69,
@@ -107,7 +107,7 @@
     60, 45,  30,  15, 121, 106, 91,  76,  61, 46, 31,  122, 107, 92,  77,  62,
     47, 123, 108, 93, 78,  63,  124, 109, 94, 79, 125, 110, 95,  126, 111, 127};
 
-const uint16_t kColumnScan16x8[128] = {
+constexpr uint16_t kColumnScan16x8[128] = {
     0,  16, 32, 48, 64, 80, 96,  112, 1,  17, 33, 49, 65, 81, 97,  113,
     2,  18, 34, 50, 66, 82, 98,  114, 3,  19, 35, 51, 67, 83, 99,  115,
     4,  20, 36, 52, 68, 84, 100, 116, 5,  21, 37, 53, 69, 85, 101, 117,
@@ -117,7 +117,7 @@
     12, 28, 44, 60, 76, 92, 108, 124, 13, 29, 45, 61, 77, 93, 109, 125,
     14, 30, 46, 62, 78, 94, 110, 126, 15, 31, 47, 63, 79, 95, 111, 127};
 
-const uint16_t kRowScan16x8[128] = {
+constexpr uint16_t kRowScan16x8[128] = {
     0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,
     15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,
     30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
@@ -128,7 +128,7 @@
     105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
     120, 121, 122, 123, 124, 125, 126, 127};
 
-const uint16_t kDefaultScan16x16[256] = {
+constexpr uint16_t kDefaultScan16x16[256] = {
     0,   1,   16,  32,  17,  2,   3,   18,  33,  48,  64,  49,  34,  19,  4,
     5,   20,  35,  50,  65,  80,  96,  81,  66,  51,  36,  21,  6,   7,   22,
     37,  52,  67,  82,  97,  112, 128, 113, 98,  83,  68,  53,  38,  23,  8,
@@ -148,7 +148,7 @@
     250, 251, 236, 221, 206, 191, 207, 222, 237, 252, 253, 238, 223, 239, 254,
     255};
 
-const uint16_t kColumnScan16x16[256] = {
+constexpr uint16_t kColumnScan16x16[256] = {
     0,  16, 32, 48, 64, 80, 96,  112, 128, 144, 160, 176, 192, 208, 224, 240,
     1,  17, 33, 49, 65, 81, 97,  113, 129, 145, 161, 177, 193, 209, 225, 241,
     2,  18, 34, 50, 66, 82, 98,  114, 130, 146, 162, 178, 194, 210, 226, 242,
@@ -166,7 +166,7 @@
     14, 30, 46, 62, 78, 94, 110, 126, 142, 158, 174, 190, 206, 222, 238, 254,
     15, 31, 47, 63, 79, 95, 111, 127, 143, 159, 175, 191, 207, 223, 239, 255};
 
-const uint16_t kRowScan16x16[256] = {
+constexpr uint16_t kRowScan16x16[256] = {
     0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,
     15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,
     30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
@@ -186,7 +186,7 @@
     240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254,
     255};
 
-const uint16_t kDefaultScan16x32[512] = {
+constexpr uint16_t kDefaultScan16x32[512] = {
     0,   1,   16,  2,   17,  32,  3,   18,  33,  48,  4,   19,  34,  49,  64,
     5,   20,  35,  50,  65,  80,  6,   21,  36,  51,  66,  81,  96,  7,   22,
     37,  52,  67,  82,  97,  112, 8,   23,  38,  53,  68,  83,  98,  113, 128,
@@ -223,7 +223,7 @@
     491, 506, 447, 462, 477, 492, 507, 463, 478, 493, 508, 479, 494, 509, 495,
     510, 511};
 
-const uint16_t kDefaultScan32x16[512] = {
+constexpr uint16_t kDefaultScan32x16[512] = {
     0,   32,  1,   64,  33,  2,   96,  65,  34,  3,   128, 97,  66,  35,  4,
     160, 129, 98,  67,  36,  5,   192, 161, 130, 99,  68,  37,  6,   224, 193,
     162, 131, 100, 69,  38,  7,   256, 225, 194, 163, 132, 101, 70,  39,  8,
@@ -260,7 +260,7 @@
     382, 351, 507, 476, 445, 414, 383, 508, 477, 446, 415, 509, 478, 447, 510,
     479, 511};
 
-const uint16_t kDefaultScan32x32[1024] = {
+constexpr uint16_t kDefaultScan32x32[1024] = {
     0,    1,    32,   64,   33,   2,   3,    34,   65,   96,   128,  97,  66,
     35,   4,    5,    36,   67,   98,  129,  160,  192,  161,  130,  99,  68,
     37,   6,    7,    38,   69,   100, 131,  162,  193,  224,  256,  225, 194,
@@ -341,43 +341,43 @@
     862,  831,  863,  894,  925,  956, 987,  1018, 1019, 988,  957,  926, 895,
     927,  958,  989,  1020, 1021, 990, 959,  991,  1022, 1023};
 
-const uint16_t kDefaultScan4x16[64] = {
+constexpr uint16_t kDefaultScan4x16[64] = {
     0,  1,  4,  2,  5,  8,  3,  6,  9,  12, 7,  10, 13, 16, 11, 14,
     17, 20, 15, 18, 21, 24, 19, 22, 25, 28, 23, 26, 29, 32, 27, 30,
     33, 36, 31, 34, 37, 40, 35, 38, 41, 44, 39, 42, 45, 48, 43, 46,
     49, 52, 47, 50, 53, 56, 51, 54, 57, 60, 55, 58, 61, 59, 62, 63};
 
-const uint16_t kColumnScan4x16[64] = {
+constexpr uint16_t kColumnScan4x16[64] = {
     0, 4, 8,  12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60,
     1, 5, 9,  13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61,
     2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62,
     3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47, 51, 55, 59, 63};
 
-const uint16_t kRowScan4x16[64] = {
+constexpr uint16_t kRowScan4x16[64] = {
     0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
     16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
     32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
     48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
 
-const uint16_t kDefaultScan16x4[64] = {
+constexpr uint16_t kDefaultScan16x4[64] = {
     0,  16, 1,  32, 17, 2,  48, 33, 18, 3,  49, 34, 19, 4,  50, 35,
     20, 5,  51, 36, 21, 6,  52, 37, 22, 7,  53, 38, 23, 8,  54, 39,
     24, 9,  55, 40, 25, 10, 56, 41, 26, 11, 57, 42, 27, 12, 58, 43,
     28, 13, 59, 44, 29, 14, 60, 45, 30, 15, 61, 46, 31, 62, 47, 63};
 
-const uint16_t kColumnScan16x4[64] = {
+constexpr uint16_t kColumnScan16x4[64] = {
     0,  16, 32, 48, 1,  17, 33, 49, 2,  18, 34, 50, 3,  19, 35, 51,
     4,  20, 36, 52, 5,  21, 37, 53, 6,  22, 38, 54, 7,  23, 39, 55,
     8,  24, 40, 56, 9,  25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59,
     12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63};
 
-const uint16_t kRowScan16x4[64] = {
+constexpr uint16_t kRowScan16x4[64] = {
     0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
     16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
     32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
     48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
 
-const uint16_t kDefaultScan8x32[256] = {
+constexpr uint16_t kDefaultScan8x32[256] = {
     0,   1,   8,   2,   9,   16,  3,   10,  17,  24,  4,   11,  18,  25,  32,
     5,   12,  19,  26,  33,  40,  6,   13,  20,  27,  34,  41,  48,  7,   14,
     21,  28,  35,  42,  49,  56,  15,  22,  29,  36,  43,  50,  57,  64,  23,
@@ -397,7 +397,7 @@
     250, 223, 230, 237, 244, 251, 231, 238, 245, 252, 239, 246, 253, 247, 254,
     255};
 
-const uint16_t kDefaultScan32x8[256] = {
+constexpr uint16_t kDefaultScan32x8[256] = {
     0,   32,  1,   64,  33,  2,   96,  65,  34,  3,   128, 97,  66,  35,  4,
     160, 129, 98,  67,  36,  5,   192, 161, 130, 99,  68,  37,  6,   224, 193,
     162, 131, 100, 69,  38,  7,   225, 194, 163, 132, 101, 70,  39,  8,   226,
diff --git a/libgav1/src/status_code.cc b/libgav1/src/status_code.cc
new file mode 100644
index 0000000..34def08
--- /dev/null
+++ b/libgav1/src/status_code.cc
@@ -0,0 +1,57 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/gav1/status_code.h"
+
+extern "C" {
+
+const char* Libgav1GetErrorString(Libgav1StatusCode status) {
+  switch (status) {
+    case kLibgav1StatusOk:
+      return "Success.";
+    case kLibgav1StatusUnknownError:
+      return "Unknown error.";
+    case kLibgav1StatusInvalidArgument:
+      return "Invalid function argument.";
+    case kLibgav1StatusOutOfMemory:
+      return "Memory allocation failure.";
+    case kLibgav1StatusResourceExhausted:
+      return "Ran out of a resource (other than memory).";
+    case kLibgav1StatusNotInitialized:
+      return "The object is not initialized.";
+    case kLibgav1StatusAlready:
+      return "An operation that can only be performed once has already been "
+             "performed.";
+    case kLibgav1StatusUnimplemented:
+      return "Not implemented.";
+    case kLibgav1StatusInternalError:
+      return "Internal error in libgav1.";
+    case kLibgav1StatusBitstreamError:
+      return "The bitstream is not encoded correctly or violates a bitstream "
+             "conformance requirement.";
+    case kLibgav1StatusTryAgain:
+      return "The operation is not allowed at the moment. Try again later.";
+    case kLibgav1StatusNothingToDequeue:
+      return "There are no enqueued frames, so there is nothing to dequeue. "
+             "Try enqueuing a frame before trying to dequeue again.";
+    // This switch statement does not have a default case. This way the compiler
+    // will warn if we neglect to update this function after adding a new value
+    // to the Libgav1StatusCode enum type.
+    case kLibgav1StatusReservedForFutureExpansionUseDefaultInSwitchInstead_:
+      break;
+  }
+  return "Unrecognized status code.";
+}
+
+}  // extern "C"
diff --git a/libgav1/src/status_code.h b/libgav1/src/status_code.h
deleted file mode 100644
index dad0548..0000000
--- a/libgav1/src/status_code.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright 2019 The libgav1 Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBGAV1_SRC_STATUS_CODE_H_
-#define LIBGAV1_SRC_STATUS_CODE_H_
-
-// All the declarations in this file are part of the public ABI. This file may
-// be included by both C and C++ files.
-
-// The Libgav1StatusCode enum type: A libgav1 function may return
-// Libgav1StatusCode to indicate success or the reason for failure.
-typedef enum {
-  // Success.
-  kLibgav1StatusOk = 0,
-
-  // An unknown error. Used as the default error status if error detail is not
-  // available.
-  kLibgav1StatusUnknownError = -1,
-
-  // An invalid function argument.
-  kLibgav1StatusInvalidArgument = -2,
-
-  // Memory allocation failure.
-  kLibgav1StatusOutOfMemory = -3,
-
-  // Ran out of a resource (other than memory).
-  kLibgav1StatusResourceExhausted = -4,
-
-  // The object is not initialized.
-  kLibgav1StatusNotInitialized = -5,
-
-  // An operation that can only be performed once has already been performed.
-  kLibgav1StatusAlready = -6,
-
-  // Not implemented, or not supported.
-  kLibgav1StatusUnimplemented = -7,
-
-  // An internal error in libgav1. Usually this indicates a programming error.
-  kLibgav1StatusInternalError = -8,
-
-  // The bitstream is not encoded correctly or violates a bitstream conformance
-  // requirement.
-  kLibgav1StatusBitstreamError = -9,
-
-  // An extra enumerator to prevent people from writing code that fails to
-  // compile when a new status code is added.
-  //
-  // Do not reference this enumerator. In particular, if you write code that
-  // switches on Libgav1StatusCode, add a default: case instead of a case that
-  // mentions this enumerator.
-  //
-  // Do not depend on the value (currently -1000) listed here. It may change in
-  // the future.
-  kLibgav1StatusReservedForFutureExpansionUseDefaultInSwitchInstead_ = -1000
-} Libgav1StatusCode;
-
-#if defined(__cplusplus)
-// Declare type aliases for C++.
-namespace libgav1 {
-
-using StatusCode = Libgav1StatusCode;
-
-}  // namespace libgav1
-#endif  // defined(__cplusplus)
-
-#endif  // LIBGAV1_SRC_STATUS_CODE_H_
diff --git a/libgav1/src/symbol_decoder_context.cc b/libgav1/src/symbol_decoder_context.cc
index 557002c..159f25c 100644
--- a/libgav1/src/symbol_decoder_context.cc
+++ b/libgav1/src/symbol_decoder_context.cc
@@ -98,7 +98,8 @@
 
 }  // namespace
 
-#define CDF_COPY(source, destination) \
+#define CDF_COPY(source, destination)                       \
+  static_assert(sizeof(source) == sizeof(destination), ""); \
   memcpy(destination, source, sizeof(source))
 
 void SymbolDecoderContext::Initialize(int base_quantizer_index) {
@@ -191,7 +192,7 @@
 
 #undef CDF_COPY
 
-// These macros set the last element in the inner-most dimesion of the array to
+// These macros set the last element in the inner-most dimension of the array to
 // zero.
 #define RESET_COUNTER_1D(array)                              \
   do {                                                       \
diff --git a/libgav1/src/symbol_decoder_context_cdfs.inc b/libgav1/src/symbol_decoder_context_cdfs.inc
index da18e6a..7f8f2c2 100644
--- a/libgav1/src/symbol_decoder_context_cdfs.inc
+++ b/libgav1/src/symbol_decoder_context_cdfs.inc
@@ -15,7 +15,7 @@
 // This file is just a convenience to separate out all the CDF constant
 // definitions from the symbol decoder context functions.
 
-const uint16_t kDefaultPartitionCdf
+constexpr uint16_t kDefaultPartitionCdf
     [kBlockWidthCount][kPartitionContexts][kMaxPartitionTypes + 1] = {
         // width 8
         {{13636, 7258, 2376, 0, 0},
@@ -43,30 +43,32 @@
          {27339, 26092, 25646, 741, 541, 237, 186, 0, 0},
          {32057, 31802, 31596, 320, 230, 151, 104, 0, 0}}};
 
-const uint16_t kDefaultSegmentIdCdf[kSegmentIdContexts][kMaxSegments + 1] = {
-    {27146, 24875, 16675, 14535, 4959, 4395, 235, 0, 0},
-    {18494, 14538, 10211, 7833, 2788, 1917, 424, 0, 0},
-    {5241, 4281, 4045, 3878, 371, 121, 89, 0, 0}};
+constexpr uint16_t kDefaultSegmentIdCdf[kSegmentIdContexts][kMaxSegments + 1] =
+    {{27146, 24875, 16675, 14535, 4959, 4395, 235, 0, 0},
+     {18494, 14538, 10211, 7833, 2788, 1917, 424, 0, 0},
+     {5241, 4281, 4045, 3878, 371, 121, 89, 0, 0}};
 
-const uint16_t kDefaultUsePredictedSegmentIdCdf[kUsePredictedSegmentIdContexts]
-                                               [kBooleanFieldCdfSize] = {
-                                                   {16384, 0, 0},
-                                                   {16384, 0, 0},
-                                                   {16384, 0, 0}};
+constexpr uint16_t
+    kDefaultUsePredictedSegmentIdCdf[kUsePredictedSegmentIdContexts]
+                                    [kBooleanFieldCdfSize] = {{16384, 0, 0},
+                                                              {16384, 0, 0},
+                                                              {16384, 0, 0}};
 
-const uint16_t kDefaultSkipCdf[kSkipContexts][kBooleanFieldCdfSize] = {
+constexpr uint16_t kDefaultSkipCdf[kSkipContexts][kBooleanFieldCdfSize] = {
     {1097, 0, 0}, {16253, 0, 0}, {28192, 0, 0}};
 
-const uint16_t kDefaultSkipModeCdf[kSkipModeContexts][kBooleanFieldCdfSize] = {
-    {147, 0, 0}, {12060, 0, 0}, {24641, 0, 0}};
+constexpr uint16_t
+    kDefaultSkipModeCdf[kSkipModeContexts][kBooleanFieldCdfSize] = {
+        {147, 0, 0}, {12060, 0, 0}, {24641, 0, 0}};
 
 // This constant is also used for DeltaLf and DeltaLfMulti.
-const uint16_t kDefaultDeltaQCdf[kDeltaSymbolCount + 1] = {4608, 648, 91, 0,
-                                                            0};
+constexpr uint16_t kDefaultDeltaQCdf[kDeltaSymbolCount + 1] = {4608, 648, 91, 0,
+                                                               0};
 
-const uint16_t kDefaultIntraBlockCopyCdf[kBooleanFieldCdfSize] = {2237, 0, 0};
+constexpr uint16_t kDefaultIntraBlockCopyCdf[kBooleanFieldCdfSize] = {2237, 0,
+                                                                      0};
 
-const uint16_t
+constexpr uint16_t
     kDefaultIntraFrameYModeCdf[kIntraModeContexts][kIntraModeContexts]
                               [kIntraPredictionModesY + 1] = {
                                   {{17180, 15741, 13430, 12550, 12086, 11658,
@@ -120,17 +122,18 @@
                                    {25150, 24480, 22909, 22259, 17382, 14111,
                                     9865, 3992, 3588, 1413, 966, 175, 0, 0}}};
 
-const uint16_t kDefaultYModeCdf[kYModeContexts][kIntraPredictionModesY + 1] = {
-    {9967, 9279, 8475, 8012, 7167, 6645, 6162, 5350, 4823, 3540, 3083, 2419, 0,
-     0},
-    {14095, 12923, 10137, 9450, 8818, 8119, 7241, 5404, 4616, 3067, 2784, 1916,
-     0, 0},
-    {12998, 11789, 9372, 8829, 8527, 8114, 7632, 5695, 4938, 3408, 3038, 2109,
-     0, 0},
-    {12613, 11467, 9930, 9590, 9507, 9235, 9065, 7964, 7416, 6193, 5752, 4719,
-     0, 0}};
+constexpr uint16_t
+    kDefaultYModeCdf[kYModeContexts][kIntraPredictionModesY + 1] = {
+        {9967, 9279, 8475, 8012, 7167, 6645, 6162, 5350, 4823, 3540, 3083, 2419,
+         0, 0},
+        {14095, 12923, 10137, 9450, 8818, 8119, 7241, 5404, 4616, 3067, 2784,
+         1916, 0, 0},
+        {12998, 11789, 9372, 8829, 8527, 8114, 7632, 5695, 4938, 3408, 3038,
+         2109, 0, 0},
+        {12613, 11467, 9930, 9590, 9507, 9235, 9065, 7964, 7416, 6193, 5752,
+         4719, 0, 0}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultAngleDeltaCdf[kDirectionalIntraModes][kAngleDeltaSymbolCount + 1] =
         {{30588, 27736, 25201, 9992, 5779, 2551, 0, 0},
          {30467, 27160, 23967, 9281, 5794, 2438, 0, 0},
@@ -141,7 +144,7 @@
          {30528, 21672, 17315, 12427, 10207, 3851, 0, 0},
          {29163, 22340, 20309, 15092, 11524, 2113, 0, 0}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultUVModeCdf[kBooleanSymbolCount][kIntraPredictionModesY]
                      [kIntraPredictionModesUV + 1] = {
                          // CFL not allowed.
@@ -157,8 +160,8 @@
                            4525, 1667, 1024, 405, 0, 0},
                           {20943, 19179, 19091, 19048, 17720, 3555, 3467, 3310,
                            3057, 1607, 1327, 218, 0, 0},
-                          {18593, 18369, 16160, 15947, 15050, 14993, 4217,
-                           2568, 2523, 931, 426, 101, 0, 0},
+                          {18593, 18369, 16160, 15947, 15050, 14993, 4217, 2568,
+                           2523, 931, 426, 101, 0, 0},
                           {19883, 19730, 17790, 17178, 17095, 17020, 16592,
                            3640, 3501, 2125, 807, 307, 0, 0},
                           {20742, 19107, 18894, 17463, 17278, 17042, 16773,
@@ -199,11 +202,11 @@
                           {29624, 27681, 25386, 25264, 25175, 25078, 24967,
                            24704, 24536, 23520, 22893, 22247, 3720, 0, 0}}};
 
-const uint16_t kDefaultCflAlphaSignsCdf[kCflAlphaSignsSymbolCount + 1] = {
+constexpr uint16_t kDefaultCflAlphaSignsCdf[kCflAlphaSignsSymbolCount + 1] = {
     31350, 30645, 19428, 14363, 5796, 4425, 474, 0, 0};
 
-const uint16_t kDefaultCflAlphaCdf[kCflAlphaContexts][kCflAlphaSymbolCount +
-                                                      1] = {
+constexpr uint16_t kDefaultCflAlphaCdf[kCflAlphaContexts][kCflAlphaSymbolCount +
+                                                          1] = {
     {25131, 12049, 1367, 287, 111, 80, 76, 72, 68, 64, 60, 56, 52, 48, 44, 0,
      0},
     {18403, 9165, 4633, 1600, 601, 373, 281, 195, 148, 121, 100, 96, 92, 88, 84,
@@ -216,32 +219,34 @@
     {18030, 11090, 6989, 4867, 3744, 2466, 1788, 925, 624, 355, 248, 174, 146,
      112, 108, 0, 0}};
 
-const uint16_t kDefaultUseFilterIntraCdf[kMaxBlockSizes][kBooleanFieldCdfSize] =
-    {{28147, 0, 0}, {26025, 0, 0}, {19998, 0, 0}, {26875, 0, 0}, {24902, 0, 0},
-     {20217, 0, 0}, {12539, 0, 0}, {22400, 0, 0}, {23374, 0, 0}, {20360, 0, 0},
-     {18467, 0, 0}, {16384, 0, 0}, {14667, 0, 0}, {20012, 0, 0}, {10425, 0, 0},
-     {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0},
-     {16384, 0, 0}, {16384, 0, 0}};
+constexpr uint16_t
+    kDefaultUseFilterIntraCdf[kMaxBlockSizes][kBooleanFieldCdfSize] = {
+        {28147, 0, 0}, {26025, 0, 0}, {19998, 0, 0}, {26875, 0, 0},
+        {24902, 0, 0}, {20217, 0, 0}, {12539, 0, 0}, {22400, 0, 0},
+        {23374, 0, 0}, {20360, 0, 0}, {18467, 0, 0}, {16384, 0, 0},
+        {14667, 0, 0}, {20012, 0, 0}, {10425, 0, 0}, {16384, 0, 0},
+        {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0},
+        {16384, 0, 0}, {16384, 0, 0}};
 
-const uint16_t kDefaultFilterIntraModeCdf[kNumFilterIntraPredictors + 1] = {
+constexpr uint16_t kDefaultFilterIntraModeCdf[kNumFilterIntraPredictors + 1] = {
     23819, 19992, 15557, 3210, 0, 0};
 
-const uint16_t
+constexpr uint16_t
     kDefaultTxDepthCdf[4][kTxDepthContexts][kMaxTxDepthSymbolCount + 1] = {
         {{12800, 0, 0}, {12800, 0, 0}, {8448, 0, 0}},
         {{20496, 2596, 0, 0}, {20496, 2596, 0, 0}, {14091, 1920, 0, 0}},
         {{19782, 17588, 0, 0}, {19782, 17588, 0, 0}, {8466, 7166, 0, 0}},
         {{26986, 21293, 0, 0}, {26986, 21293, 0, 0}, {15965, 10009, 0, 0}}};
 
-const uint16_t kDefaultTxSplitCdf[kTxSplitContexts][kBooleanFieldCdfSize] = {
-    {4187, 0, 0},  {8922, 0, 0},  {11921, 0, 0}, {8453, 0, 0},  {14572, 0, 0},
-    {20635, 0, 0}, {13977, 0, 0}, {21881, 0, 0}, {21763, 0, 0}, {5589, 0, 0},
-    {12764, 0, 0}, {21487, 0, 0}, {6219, 0, 0},  {13460, 0, 0}, {18544, 0, 0},
-    {4753, 0, 0},  {11222, 0, 0}, {18368, 0, 0}, {4603, 0, 0},  {10367, 0, 0},
-    {16680, 0, 0}};
+constexpr uint16_t kDefaultTxSplitCdf[kTxSplitContexts][kBooleanFieldCdfSize] =
+    {{4187, 0, 0},  {8922, 0, 0},  {11921, 0, 0}, {8453, 0, 0},  {14572, 0, 0},
+     {20635, 0, 0}, {13977, 0, 0}, {21881, 0, 0}, {21763, 0, 0}, {5589, 0, 0},
+     {12764, 0, 0}, {21487, 0, 0}, {6219, 0, 0},  {13460, 0, 0}, {18544, 0, 0},
+     {4753, 0, 0},  {11222, 0, 0}, {18368, 0, 0}, {4603, 0, 0},  {10367, 0, 0},
+     {16680, 0, 0}};
 
 /* clang-format off */
-const uint16_t kDefaultAllZeroCdf[kCoefficientQuantizerContexts]
+constexpr uint16_t kDefaultAllZeroCdf[kCoefficientQuantizerContexts]
                                  [kNumSquareTransformSizes][kAllZeroContexts]
                                  [kBooleanFieldCdfSize] = {
   {
@@ -315,7 +320,7 @@
 };
 /* clang-format on */
 
-const uint16_t
+constexpr uint16_t
     kDefaultInterTxTypeCdf[3][kNumExtendedTransformSizes][kNumTransformTypes +
                                                           1] = {
         {{28310, 27208, 25073, 23059, 19438, 17979, 15231, 12502, 11264, 9920,
@@ -334,7 +339,7 @@
          {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
         {{16384, 0, 0}, {28601, 0, 0}, {30770, 0, 0}, {32020, 0, 0}}};
 
-const uint16_t kDefaultIntraTxTypeCdf
+constexpr uint16_t kDefaultIntraTxTypeCdf
     [2][kNumExtendedTransformSizes][kIntraPredictionModesY]
     [kNumTransformTypes + 1] = {
         {{{31233, 24733, 23307, 20017, 9301, 4943, 0, 0},
@@ -403,25 +408,26 @@
           {32685, 27153, 20767, 15540, 0, 0},
           {30800, 27212, 20745, 14221, 0, 0}}}};
 
-const uint16_t kDefaultEobPt16Cdf[kCoefficientQuantizerContexts][kNumPlaneTypes]
-                                 [kEobPtContexts][kEobPt16SymbolCount + 1] = {
-                                     {{{31928, 31729, 30788, 27873, 0, 0},
-                                       {32398, 32097, 30885, 28297, 0, 0}},
-                                      {{29521, 27818, 23080, 18205, 0, 0},
-                                       {30864, 29414, 25005, 18121, 0, 0}}},
-                                     {{{30643, 30217, 27603, 23822, 0, 0},
-                                       {32255, 32003, 30909, 26429, 0, 0}},
-                                      {{25131, 23270, 18509, 13660, 0, 0},
-                                       {30271, 28672, 23902, 15775, 0, 0}}},
-                                     {{{28752, 27871, 23887, 17800, 0, 0},
-                                       {32052, 31663, 30122, 22712, 0, 0}},
-                                      {{21629, 19498, 14527, 9202, 0, 0},
-                                       {29576, 27736, 22471, 13013, 0, 0}}},
-                                     {{{26060, 23810, 18022, 10635, 0, 0},
-                                       {31546, 30694, 27985, 17358, 0, 0}},
-                                      {{13193, 11002, 6724, 3059, 0, 0},
-                                       {25471, 22001, 13495, 4574, 0, 0}}}};
-const uint16_t
+constexpr uint16_t kDefaultEobPt16Cdf[kCoefficientQuantizerContexts]
+                                     [kNumPlaneTypes][kEobPtContexts]
+                                     [kEobPt16SymbolCount + 1] = {
+                                         {{{31928, 31729, 30788, 27873, 0, 0},
+                                           {32398, 32097, 30885, 28297, 0, 0}},
+                                          {{29521, 27818, 23080, 18205, 0, 0},
+                                           {30864, 29414, 25005, 18121, 0, 0}}},
+                                         {{{30643, 30217, 27603, 23822, 0, 0},
+                                           {32255, 32003, 30909, 26429, 0, 0}},
+                                          {{25131, 23270, 18509, 13660, 0, 0},
+                                           {30271, 28672, 23902, 15775, 0, 0}}},
+                                         {{{28752, 27871, 23887, 17800, 0, 0},
+                                           {32052, 31663, 30122, 22712, 0, 0}},
+                                          {{21629, 19498, 14527, 9202, 0, 0},
+                                           {29576, 27736, 22471, 13013, 0, 0}}},
+                                         {{{26060, 23810, 18022, 10635, 0, 0},
+                                           {31546, 30694, 27985, 17358, 0, 0}},
+                                          {{13193, 11002, 6724, 3059, 0, 0},
+                                           {25471, 22001, 13495, 4574, 0, 0}}}};
+constexpr uint16_t
     kDefaultEobPt32Cdf[kCoefficientQuantizerContexts][kNumPlaneTypes]
                       [kEobPtContexts][kEobPt32SymbolCount + 1] = {
                           {{{32368, 32248, 31791, 30666, 26226, 0, 0},
@@ -440,7 +446,7 @@
                             {31612, 31066, 29093, 23494, 12229, 0, 0}},
                            {{10682, 8486, 5758, 2998, 1025, 0, 0},
                             {25069, 21871, 11877, 5842, 1140, 0, 0}}}};
-const uint16_t
+constexpr uint16_t
     kDefaultEobPt64Cdf[kCoefficientQuantizerContexts][kNumPlaneTypes]
                       [kEobPtContexts][kEobPt64SymbolCount + 1] = {
                           {{{32439, 32270, 31667, 30984, 29503, 25010, 0, 0},
@@ -459,7 +465,7 @@
                             {31479, 30448, 28797, 24842, 18615, 8477, 0, 0}},
                            {{8556, 7060, 4500, 2733, 1461, 719, 0, 0},
                             {24042, 20390, 13359, 6318, 2730, 306, 0, 0}}}};
-const uint16_t kDefaultEobPt128Cdf
+constexpr uint16_t kDefaultEobPt128Cdf
     [kCoefficientQuantizerContexts][kNumPlaneTypes][kEobPtContexts]
     [kEobPt128SymbolCount + 1] = {
         {{{32549, 32286, 31628, 30677, 29088, 26740, 20182, 0, 0},
@@ -479,7 +485,7 @@
          {{8455, 6706, 4383, 2661, 1551, 870, 423, 0, 0},
           {23603, 19486, 11618, 2482, 874, 197, 56, 0, 0}}}};
 
-const uint16_t kDefaultEobPt256Cdf
+constexpr uint16_t kDefaultEobPt256Cdf
     [kCoefficientQuantizerContexts][kNumPlaneTypes][kEobPtContexts]
     [kEobPt256SymbolCount + 1] = {
         {{{32458, 32184, 30881, 29179, 26600, 24157, 21416, 17116, 0, 0},
@@ -499,7 +505,7 @@
          {{9658, 8171, 5628, 3874, 2601, 1841, 1376, 674, 0, 0},
           {22770, 15107, 7590, 4671, 1460, 730, 365, 73, 0, 0}}}};
 
-const uint16_t kDefaultEobPt512Cdf
+constexpr uint16_t kDefaultEobPt512Cdf
     [kCoefficientQuantizerContexts][kNumPlaneTypes][kEobPt512SymbolCount + 1] =
         {{{32127, 31785, 29061, 27338, 22534, 17810, 13980, 9356, 6707, 0, 0},
           {27673, 26322, 22772, 19414, 16751, 14782, 11849, 6639, 3628, 0, 0}},
@@ -510,7 +516,7 @@
          {{26841, 24959, 21845, 18171, 13329, 8633, 4312, 1626, 708, 0, 0},
           {11675, 9725, 7026, 5110, 3671, 3052, 2695, 1948, 812, 0, 0}}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultEobPt1024Cdf[kCoefficientQuantizerContexts][kNumPlaneTypes]
                         [kEobPt1024SymbolCount + 1] = {
                             {{32375, 32347, 32017, 31145, 29608, 26416, 19423,
@@ -531,7 +537,7 @@
                               2961, 198, 0, 0}}};
 
 /* clang-format off */
-const uint16_t kDefaultEobExtraCdf[kCoefficientQuantizerContexts]
+constexpr uint16_t kDefaultEobExtraCdf[kCoefficientQuantizerContexts]
                                   [kNumSquareTransformSizes][kNumPlaneTypes]
                                   [kEobExtraContexts][kBooleanFieldCdfSize] = {
   {
@@ -704,7 +710,7 @@
   }
 };
 
-const uint16_t kDefaultCoeffBaseEobCdf[kCoefficientQuantizerContexts]
+constexpr uint16_t kDefaultCoeffBaseEobCdf[kCoefficientQuantizerContexts]
                                       [kNumSquareTransformSizes][kNumPlaneTypes]
                                       [kCoeffBaseEobContexts]
                                       [kCoeffBaseEobSymbolCount + 1] = {
@@ -839,7 +845,7 @@
 };
 /* clang-format on */
 
-const uint16_t kDefaultCoeffBaseCdf
+constexpr uint16_t kDefaultCoeffBaseCdf
     [kCoefficientQuantizerContexts][kNumSquareTransformSizes][kNumPlaneTypes]
     [kCoeffBaseContexts][kCoeffBaseSymbolCount + 1] = {
         {{{{28734, 23838, 20041, 0, 0}, {14686, 3027, 891, 0, 0},
@@ -1683,7 +1689,7 @@
            {24576, 16384, 8192, 0, 0}, {24576, 16384, 8192, 0, 0},
            {24576, 16384, 8192, 0, 0}, {24576, 16384, 8192, 0, 0}}}}};
 
-const uint16_t kDefaultCoeffBaseRangeCdf
+constexpr uint16_t kDefaultCoeffBaseRangeCdf
     [kCoefficientQuantizerContexts][kNumSquareTransformSizes][kNumPlaneTypes]
     [kCoeffBaseRangeContexts][kCoeffBaseRangeSymbolCount + 1] = {
         {{{{18470, 12050, 8594, 0, 0},  {20232, 13167, 8979, 0, 0},
@@ -2128,7 +2134,7 @@
            {24576, 16384, 8192, 0, 0}}}}};
 
 /* clang-format off */
-const uint16_t kDefaultDcSignCdf[kCoefficientQuantizerContexts][kNumPlaneTypes]
+constexpr uint16_t kDefaultDcSignCdf[kCoefficientQuantizerContexts][kNumPlaneTypes]
                                 [kDcSignContexts][kBooleanFieldCdfSize] = {
   {{{16768, 0, 0}, {19712, 0, 0}, {13952, 0, 0}}, {{17536, 0, 0}, {19840, 0, 0},
     {15488, 0, 0}}},
@@ -2140,14 +2146,14 @@
     {15488, 0, 0}}}
 };
 /* clang-format on */
-const uint16_t kDefaultRestorationTypeCdf[kCoeffBaseRangeSymbolCount + 1] = {
-    23355, 10187, 0, 0};
+constexpr uint16_t kDefaultRestorationTypeCdf[kRestorationTypeSymbolCount + 1] =
+    {23355, 10187, 0, 0};
 
-const uint16_t kDefaultUseWienerCdf[kBooleanFieldCdfSize] = {21198, 0, 0};
+constexpr uint16_t kDefaultUseWienerCdf[kBooleanFieldCdfSize] = {21198, 0, 0};
 
-const uint16_t kDefaultUseSgrProjCdf[kBooleanFieldCdfSize] = {15913, 0, 0};
+constexpr uint16_t kDefaultUseSgrProjCdf[kBooleanFieldCdfSize] = {15913, 0, 0};
 
-const uint16_t
+constexpr uint16_t
     kDefaultHasPaletteYCdf[kPaletteBlockSizeContexts][kPaletteYModeContexts]
                           [kBooleanFieldCdfSize] = {
                               {{1092, 0, 0}, {29349, 0, 0}, {31507, 0, 0}},
@@ -2158,7 +2164,7 @@
                               {{503, 0, 0}, {28753, 0, 0}, {31247, 0, 0}},
                               {{318, 0, 0}, {24822, 0, 0}, {32639, 0, 0}}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultPaletteYSizeCdf[kPaletteBlockSizeContexts]
                            [kPaletteSizeSymbolCount + 1] = {
                                {24816, 19768, 14619, 11290, 7241, 3527, 0, 0},
@@ -2169,11 +2175,11 @@
                                {23057, 17880, 15845, 11716, 7107, 4893, 0, 0},
                                {17828, 11971, 11090, 8582, 5735, 3769, 0, 0}};
 
-const uint16_t kDefaultHasPaletteUVCdf[kPaletteUVModeContexts]
-                                      [kBooleanFieldCdfSize] = {{307, 0, 0},
-                                                                {11280, 0, 0}};
+constexpr uint16_t kDefaultHasPaletteUVCdf[kPaletteUVModeContexts]
+                                          [kBooleanFieldCdfSize] = {
+                                              {307, 0, 0}, {11280, 0, 0}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultPaletteUVSizeCdf[kPaletteBlockSizeContexts]
                             [kPaletteSizeSymbolCount + 1] = {
                                 {24055, 12789, 5640, 3159, 1437, 496, 0, 0},
@@ -2185,7 +2191,7 @@
                                 {31499, 27333, 22335, 13805, 11068, 6903, 0,
                                  0}};
 
-const uint16_t kDefaultPaletteColorIndexCdf
+constexpr uint16_t kDefaultPaletteColorIndexCdf
     [kNumPlaneTypes][kPaletteSizeSymbolCount][kPaletteColorIndexContexts]
     [kPaletteColorIndexSymbolCount + 1] = {
         {{{4058, 0, 0},
@@ -2259,26 +2265,26 @@
           {14803, 12684, 10536, 8794, 6494, 4366, 2378, 0, 0},
           {1578, 1439, 1252, 1089, 943, 742, 446, 0, 0}}}};
 
-const uint16_t kDefaultIsInterCdf[kIsInterContexts][kBooleanFieldCdfSize] = {
-    {31962, 0, 0}, {16106, 0, 0}, {12582, 0, 0}, {6230, 0, 0}};
+constexpr uint16_t kDefaultIsInterCdf[kIsInterContexts][kBooleanFieldCdfSize] =
+    {{31962, 0, 0}, {16106, 0, 0}, {12582, 0, 0}, {6230, 0, 0}};
 
-const uint16_t kDefaultUseCompoundReferenceCdf[kUseCompoundReferenceContexts]
-                                              [kBooleanFieldCdfSize] = {
-                                                  {5940, 0, 0},
-                                                  {8733, 0, 0},
-                                                  {20737, 0, 0},
-                                                  {22128, 0, 0},
-                                                  {29867, 0, 0}};
+constexpr uint16_t
+    kDefaultUseCompoundReferenceCdf[kUseCompoundReferenceContexts]
+                                   [kBooleanFieldCdfSize] = {{5940, 0, 0},
+                                                             {8733, 0, 0},
+                                                             {20737, 0, 0},
+                                                             {22128, 0, 0},
+                                                             {29867, 0, 0}};
 
-const uint16_t kDefaultCompoundReferenceTypeCdf[kCompoundReferenceTypeContexts]
-                                               [kBooleanFieldCdfSize] = {
-                                                   {31570, 0, 0},
-                                                   {30698, 0, 0},
-                                                   {23602, 0, 0},
-                                                   {25269, 0, 0},
-                                                   {10293, 0, 0}};
+constexpr uint16_t
+    kDefaultCompoundReferenceTypeCdf[kCompoundReferenceTypeContexts]
+                                    [kBooleanFieldCdfSize] = {{31570, 0, 0},
+                                                              {30698, 0, 0},
+                                                              {23602, 0, 0},
+                                                              {25269, 0, 0},
+                                                              {10293, 0, 0}};
 
-const uint16_t kDefaultCompoundReferenceCdf
+constexpr uint16_t kDefaultCompoundReferenceCdf
     [kNumCompoundReferenceTypes][kReferenceContexts][3][kBooleanFieldCdfSize] =
         {{{{27484, 0, 0}, {28903, 0, 0}, {29640, 0, 0}},
           {{9616, 0, 0}, {18595, 0, 0}, {17498, 0, 0}},
@@ -2287,7 +2293,7 @@
           {{12877, 0, 0}, {10327, 0, 0}, {17608, 0, 0}},
           {{2037, 0, 0}, {1709, 0, 0}, {5224, 0, 0}}}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultCompoundBackwardReferenceCdf[kReferenceContexts][2]
                                         [kBooleanFieldCdfSize] = {
                                             {{30533, 0, 0}, {31345, 0, 0}},
@@ -2295,7 +2301,7 @@
                                             {{2162, 0, 0}, {2279, 0, 0}}};
 
 /* clang-format off */
-const uint16_t kDefaultSingleReferenceCdf[kReferenceContexts][6]
+constexpr uint16_t kDefaultSingleReferenceCdf[kReferenceContexts][6]
                                          [kBooleanFieldCdfSize] = {
   {{27871, 0, 0}, {31213, 0, 0}, {28532, 0, 0}, {24118, 0, 0}, {31864, 0, 0},
    {31324, 0, 0}},
@@ -2305,7 +2311,7 @@
    {2464, 0, 0}}};
 /* clang-format on */
 
-const uint16_t kDefaultCompoundPredictionModeCdf
+constexpr uint16_t kDefaultCompoundPredictionModeCdf
     [kCompoundPredictionModeContexts][kNumCompoundInterPredictionModes + 1] = {
         {25008, 18945, 16960, 15127, 13612, 12102, 5877, 0, 0},
         {22038, 13316, 11623, 10019, 8729, 7637, 4044, 0, 0},
@@ -2316,35 +2322,35 @@
         {15643, 8495, 6954, 5276, 4554, 4064, 2176, 0, 0},
         {19722, 9554, 8263, 6826, 5333, 4326, 3438, 0, 0}};
 
-const uint16_t kDefaultNewMvCdf[kNewMvContexts][kBooleanFieldCdfSize] = {
+constexpr uint16_t kDefaultNewMvCdf[kNewMvContexts][kBooleanFieldCdfSize] = {
     {8733, 0, 0},  {16138, 0, 0}, {17429, 0, 0},
     {24382, 0, 0}, {20546, 0, 0}, {28092, 0, 0}};
 
-const uint16_t kDefaultZeroMvCdf[kZeroMvContexts][kBooleanFieldCdfSize] = {
+constexpr uint16_t kDefaultZeroMvCdf[kZeroMvContexts][kBooleanFieldCdfSize] = {
     {30593, 0, 0}, {31714, 0, 0}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultReferenceMvCdf[kReferenceMvContexts][kBooleanFieldCdfSize] = {
         {8794, 0, 0}, {8580, 0, 0}, {14920, 0, 0},
         {4146, 0, 0}, {8456, 0, 0}, {12845, 0, 0}};
 
 // This is called drl_mode in the spec where DRL stands for Dynamic Reference
 // List.
-const uint16_t
+constexpr uint16_t
     kDefaultRefMvIndexCdf[kRefMvIndexContexts][kBooleanFieldCdfSize] = {
         {19664, 0, 0}, {8208, 0, 0}, {13823, 0, 0}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultIsInterIntraCdf[kInterIntraContexts][kBooleanFieldCdfSize] = {
         {5881, 0, 0}, {5171, 0, 0}, {2531, 0, 0}};
 
-const uint16_t kDefaultInterIntraModeCdf[kInterIntraContexts]
-                                        [kNumInterIntraModes + 1] = {
-                                            {30893, 21686, 5436, 0, 0},
-                                            {30295, 22772, 6380, 0, 0},
-                                            {28530, 21231, 6842, 0, 0}};
+constexpr uint16_t kDefaultInterIntraModeCdf[kInterIntraContexts]
+                                            [kNumInterIntraModes + 1] = {
+                                                {30893, 21686, 5436, 0, 0},
+                                                {30295, 22772, 6380, 0, 0},
+                                                {28530, 21231, 6842, 0, 0}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultIsWedgeInterIntraCdf[kMaxBlockSizes][kBooleanFieldCdfSize] = {
         {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0},
         {12732, 0, 0}, {7811, 0, 0},  {16384, 0, 0}, {16384, 0, 0},
@@ -2353,7 +2359,7 @@
         {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0},
         {16384, 0, 0}, {16384, 0, 0}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultWedgeIndexCdf[kMaxBlockSizes][kWedgeIndexSymbolCount + 1] = {
         {30720, 28672, 26624, 24576, 22528, 20480, 18432, 16384, 14336, 12288,
          10240, 8192, 6144, 4096, 2048, 0, 0},
@@ -2400,40 +2406,38 @@
         {30720, 28672, 26624, 24576, 22528, 20480, 18432, 16384, 14336, 12288,
          10240, 8192, 6144, 4096, 2048, 0, 0}};
 
-const uint16_t kDefaultUseObmcCdf[kMaxBlockSizes][kBooleanFieldCdfSize] = {
+constexpr uint16_t kDefaultUseObmcCdf[kMaxBlockSizes][kBooleanFieldCdfSize] = {
     {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0}, {22331, 0, 0},
     {23397, 0, 0}, {9104, 0, 0},  {16384, 0, 0}, {23467, 0, 0}, {15336, 0, 0},
     {18345, 0, 0}, {8760, 0, 0},  {11867, 0, 0}, {17626, 0, 0}, {6951, 0, 0},
     {9945, 0, 0},  {5889, 0, 0},  {10685, 0, 0}, {2640, 0, 0},  {1754, 0, 0},
     {1208, 0, 0},  {130, 0, 0}};
 
-const uint16_t kDefaultMotionModeCdf[kMaxBlockSizes][kNumMotionModes + 1] = {
-    {21845, 10923, 0, 0}, {21845, 10923, 0, 0}, {21845, 10923, 0, 0},
-    {21845, 10923, 0, 0}, {25117, 8008, 0, 0},  {28030, 8003, 0, 0},
-    {3969, 1378, 0, 0},   {21845, 10923, 0, 0}, {27377, 7240, 0, 0},
-    {13349, 5958, 0, 0},  {27645, 9162, 0, 0},  {3795, 1174, 0, 0},
-    {6337, 1994, 0, 0},   {21162, 8460, 0, 0},  {6508, 3652, 0, 0},
-    {12408, 4706, 0, 0},  {3026, 1565, 0, 0},   {11089, 5938, 0, 0},
-    {3252, 2067, 0, 0},   {3870, 2371, 0, 0},   {1890, 1433, 0, 0},
-    {261, 210, 0, 0}};
+constexpr uint16_t kDefaultMotionModeCdf[kMaxBlockSizes][kNumMotionModes + 1] =
+    {{21845, 10923, 0, 0}, {21845, 10923, 0, 0}, {21845, 10923, 0, 0},
+     {21845, 10923, 0, 0}, {25117, 8008, 0, 0},  {28030, 8003, 0, 0},
+     {3969, 1378, 0, 0},   {21845, 10923, 0, 0}, {27377, 7240, 0, 0},
+     {13349, 5958, 0, 0},  {27645, 9162, 0, 0},  {3795, 1174, 0, 0},
+     {6337, 1994, 0, 0},   {21162, 8460, 0, 0},  {6508, 3652, 0, 0},
+     {12408, 4706, 0, 0},  {3026, 1565, 0, 0},   {11089, 5938, 0, 0},
+     {3252, 2067, 0, 0},   {3870, 2371, 0, 0},   {1890, 1433, 0, 0},
+     {261, 210, 0, 0}};
 
-const uint16_t
+constexpr uint16_t
     kDefaultIsExplicitCompoundTypeCdf[kIsExplicitCompoundTypeContexts]
                                      [kBooleanFieldCdfSize] = {
                                          {6161, 0, 0},  {9877, 0, 0},
                                          {13928, 0, 0}, {8174, 0, 0},
                                          {12834, 0, 0}, {10094, 0, 0}};
 
-const uint16_t kDefaultIsCompoundTypeAverageCdf[kIsCompoundTypeAverageContexts]
-                                               [kBooleanFieldCdfSize] = {
-                                                   {14524, 0, 0},
-                                                   {19903, 0, 0},
-                                                   {25715, 0, 0},
-                                                   {19509, 0, 0},
-                                                   {23434, 0, 0},
-                                                   {28124, 0, 0}};
+constexpr uint16_t
+    kDefaultIsCompoundTypeAverageCdf[kIsCompoundTypeAverageContexts]
+                                    [kBooleanFieldCdfSize] = {
+                                        {14524, 0, 0}, {19903, 0, 0},
+                                        {25715, 0, 0}, {19509, 0, 0},
+                                        {23434, 0, 0}, {28124, 0, 0}};
 
-const uint16_t kDefaultCompoundTypeCdf
+constexpr uint16_t kDefaultCompoundTypeCdf
     [kMaxBlockSizes][kNumExplicitCompoundPredictionTypes + 1] = {
         {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0},
         {9337, 0, 0},  {19597, 0, 0}, {20948, 0, 0}, {16384, 0, 0},
@@ -2442,7 +2446,7 @@
         {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0}, {16384, 0, 0},
         {16384, 0, 0}, {16384, 0, 0}};
 
-const uint16_t kDefaultInterpolationFilterCdf
+constexpr uint16_t kDefaultInterpolationFilterCdf
     [kInterpolationFilterContexts][kNumExplicitInterpolationFilters + 1] = {
         {833, 48, 0, 0},      {27200, 49, 0, 0},    {32346, 29830, 0, 0},
         {4524, 160, 0, 0},    {1562, 815, 0, 0},    {27906, 647, 0, 0},
@@ -2451,29 +2455,30 @@
         {1746, 759, 0, 0},    {29805, 675, 0, 0},   {32167, 31825, 0, 0},
         {17799, 11370, 0, 0}};
 
-const uint16_t kDefaultMvJointCdf[kNumMvJointTypes + 1] = {28672, 21504, 13440,
-                                                           0, 0};
+constexpr uint16_t kDefaultMvJointCdf[kNumMvJointTypes + 1] = {28672, 21504,
+                                                               13440, 0, 0};
 
-const uint16_t kDefaultMvSignCdf[kBooleanFieldCdfSize] = {16384, 0, 0};
+constexpr uint16_t kDefaultMvSignCdf[kBooleanFieldCdfSize] = {16384, 0, 0};
 
-const uint16_t kDefaultMvClassCdf[kMvClassSymbolCount + 1] = {
+constexpr uint16_t kDefaultMvClassCdf[kMvClassSymbolCount + 1] = {
     4096, 1792, 910, 448, 217, 112, 28, 11, 6, 1, 0};
 
-const uint16_t kDefaultMvClass0BitCdf[kBooleanFieldCdfSize] = {5120, 0, 0};
+constexpr uint16_t kDefaultMvClass0BitCdf[kBooleanFieldCdfSize] = {5120, 0, 0};
 
-const uint16_t kDefaultMvClass0FractionCdf[kBooleanSymbolCount]
-                                          [kMvFractionSymbolCount + 1] = {
-                                              {16384, 8192, 6144, 0, 0},
-                                              {20480, 11520, 8640, 0, 0}};
+constexpr uint16_t kDefaultMvClass0FractionCdf[kBooleanSymbolCount]
+                                              [kMvFractionSymbolCount + 1] = {
+                                                  {16384, 8192, 6144, 0, 0},
+                                                  {20480, 11520, 8640, 0, 0}};
 
-const uint16_t kDefaultMvClass0HighPrecisionCdf[kBooleanFieldCdfSize] = {12288,
-                                                                         0, 0};
+constexpr uint16_t kDefaultMvClass0HighPrecisionCdf[kBooleanFieldCdfSize] = {
+    12288, 0, 0};
 
-const uint16_t kDefaultMvBitCdf[kMvBitSymbolCount][kBooleanFieldCdfSize] = {
+constexpr uint16_t kDefaultMvBitCdf[kMvBitSymbolCount][kBooleanFieldCdfSize] = {
     {15360, 0, 0}, {14848, 0, 0}, {13824, 0, 0}, {12288, 0, 0}, {10240, 0, 0},
     {8192, 0, 0},  {4096, 0, 0},  {2816, 0, 0},  {2816, 0, 0},  {2048, 0, 0}};
 
-const uint16_t kDefaultMvFractionCdf[kMvFractionSymbolCount + 1] = {
+constexpr uint16_t kDefaultMvFractionCdf[kMvFractionSymbolCount + 1] = {
     24576, 15360, 11520, 0, 0};
 
-const uint16_t kDefaultMvHighPrecisionCdf[kBooleanFieldCdfSize] = {16384, 0, 0};
+constexpr uint16_t kDefaultMvHighPrecisionCdf[kBooleanFieldCdfSize] = {16384, 0,
+                                                                       0};
diff --git a/libgav1/src/threading_strategy.cc b/libgav1/src/threading_strategy.cc
index 5e98da8..2864c34 100644
--- a/libgav1/src/threading_strategy.cc
+++ b/libgav1/src/threading_strategy.cc
@@ -16,15 +16,60 @@
 
 #include <algorithm>
 #include <cassert>
+#include <memory>
 
+#include "src/frame_scratch_buffer.h"
 #include "src/utils/constants.h"
 #include "src/utils/logging.h"
+#include "src/utils/vector.h"
 
 namespace libgav1 {
+namespace {
+
+#if !defined(LIBGAV1_FRAME_PARALLEL_THRESHOLD_MULTIPLIER)
+constexpr int kFrameParallelThresholdMultiplier = 4;
+#else
+constexpr int kFrameParallelThresholdMultiplier =
+    LIBGAV1_FRAME_PARALLEL_THRESHOLD_MULTIPLIER;
+#endif
+
+// Computes the number of frame threads to be used based on the following
+// heuristic:
+//   * If |thread_count| == 1, return 0.
+//   * If |thread_count| <= |tile_count| * 4, return 0.
+//   * Otherwise, return the largest value of i which satisfies the following
+//     condition: i + i * tile_columns <= thread_count. This ensures that there
+//     are at least |tile_columns| worker threads for each frame thread.
+//   * This function will never return 1 or a value > |thread_count|.
+//
+//  This heuristic is based empirical performance data. The in-frame threading
+//  model (combination of tile multithreading, superblock row multithreading and
+//  post filter multithreading) performs better than the frame parallel model
+//  until we reach the threshold of |thread_count| > |tile_count| *
+//  kFrameParallelThresholdMultiplier.
+//
+//  It is a function of |tile_count| since tile threading and superblock row
+//  multithreading will scale only as a factor of |tile_count|. The threshold 4
+//  is arrived at based on empirical data. The general idea is that superblock
+//  row multithreading plateaus at 4 * |tile_count| because in most practical
+//  cases there aren't more than that many superblock rows and columns available
+//  to work on in parallel.
+int ComputeFrameThreadCount(int thread_count, int tile_count,
+                            int tile_columns) {
+  assert(thread_count > 0);
+  if (thread_count == 1) return 0;
+  return (thread_count <= tile_count * kFrameParallelThresholdMultiplier)
+             ? 0
+             : std::max(2, thread_count / (1 + tile_columns));
+}
+
+}  // namespace
 
 bool ThreadingStrategy::Reset(const ObuFrameHeader& frame_header,
                               int thread_count) {
   assert(thread_count > 0);
+  frame_parallel_ = false;
+
   if (thread_count == 1) {
     thread_pool_.reset(nullptr);
     tile_thread_count_ = 0;
@@ -34,7 +79,7 @@
 
   // We do work in the current thread, so it is sufficient to create
   // |thread_count|-1 threads in the threadpool.
-  thread_count = std::min(thread_count - 1, static_cast<int>(kMaxThreads));
+  thread_count = std::min(thread_count, static_cast<int>(kMaxThreads)) - 1;
 
   if (thread_pool_ == nullptr || thread_pool_->num_threads() != thread_count) {
     thread_pool_ = ThreadPool::Create("libgav1", thread_count);
@@ -87,7 +132,7 @@
     thread_count -= 2;
     if (thread_count <= 0) break;
   }
-#else   // !defined(__ANDROID__)
+#else  // !defined(__ANDROID__)
   // Assign the remaining threads to each Tile.
   for (int i = 0; i < tile_count; ++i) {
     const int count = thread_count / tile_count +
@@ -103,4 +148,75 @@
   return true;
 }
 
+bool ThreadingStrategy::Reset(int thread_count) {
+  assert(thread_count > 0);
+  frame_parallel_ = true;
+
+  // In frame parallel mode, we simply access the underlying |thread_pool_|
+  // directly. So ensure all the other threadpool getter functions return
+  // nullptr. Also, superblock row multithreading is always disabled in frame
+  // parallel mode.
+  tile_thread_count_ = 0;
+  max_tile_index_for_row_threads_ = 0;
+
+  if (thread_pool_ == nullptr || thread_pool_->num_threads() != thread_count) {
+    thread_pool_ = ThreadPool::Create("libgav1-fp", thread_count);
+    if (thread_pool_ == nullptr) {
+      LIBGAV1_DLOG(ERROR, "Failed to create a thread pool with %d threads.",
+                   thread_count);
+      return false;
+    }
+  }
+  return true;
+}
+
+bool InitializeThreadPoolsForFrameParallel(
+    int thread_count, int tile_count, int tile_columns,
+    std::unique_ptr<ThreadPool>* const frame_thread_pool,
+    FrameScratchBufferPool* const frame_scratch_buffer_pool) {
+  assert(*frame_thread_pool == nullptr);
+  thread_count = std::min(thread_count, static_cast<int>(kMaxThreads));
+  const int frame_threads =
+      ComputeFrameThreadCount(thread_count, tile_count, tile_columns);
+  if (frame_threads == 0) return true;
+  *frame_thread_pool = ThreadPool::Create(frame_threads);
+  if (*frame_thread_pool == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Failed to create frame thread pool with %d threads.",
+                 frame_threads);
+    return false;
+  }
+  int remaining_threads = thread_count - frame_threads;
+  if (remaining_threads == 0) return true;
+  int threads_per_frame = remaining_threads / frame_threads;
+  const int extra_threads = remaining_threads % frame_threads;
+  Vector<std::unique_ptr<FrameScratchBuffer>> frame_scratch_buffers;
+  if (!frame_scratch_buffers.reserve(frame_threads)) return false;
+  // Create the tile thread pools.
+  for (int i = 0; i < frame_threads && remaining_threads > 0; ++i) {
+    std::unique_ptr<FrameScratchBuffer> frame_scratch_buffer =
+        frame_scratch_buffer_pool->Get();
+    if (frame_scratch_buffer == nullptr) {
+      return false;
+    }
+    // If the number of tile threads cannot be divided equally amongst all the
+    // frame threads, assign one extra thread to the first |extra_threads| frame
+    // threads.
+    const int current_frame_thread_count =
+        threads_per_frame + static_cast<int>(i < extra_threads);
+    if (!frame_scratch_buffer->threading_strategy.Reset(
+            current_frame_thread_count)) {
+      return false;
+    }
+    remaining_threads -= current_frame_thread_count;
+    frame_scratch_buffers.push_back_unchecked(std::move(frame_scratch_buffer));
+  }
+  // We release the frame scratch buffers in reverse order so that the extra
+  // threads are allocated to buffers in the top of the stack.
+  for (int i = static_cast<int>(frame_scratch_buffers.size()) - 1; i >= 0;
+       --i) {
+    frame_scratch_buffer_pool->Release(std::move(frame_scratch_buffers[i]));
+  }
+  return true;
+}
+
 }  // namespace libgav1
diff --git a/libgav1/src/threading_strategy.h b/libgav1/src/threading_strategy.h
index 3f354f0..84b3589 100644
--- a/libgav1/src/threading_strategy.h
+++ b/libgav1/src/threading_strategy.h
@@ -25,6 +25,8 @@
 
 namespace libgav1 {
 
+class FrameScratchBufferPool;
+
 // This class allocates and manages the worker threads among thread pools used
 // for multi-threaded decoding.
 class ThreadingStrategy {
@@ -36,18 +38,28 @@
   ThreadingStrategy& operator=(const ThreadingStrategy&) = delete;
 
   // Creates or re-allocates the thread pools based on the |frame_header| and
-  // |thread_count|. This function is idempotent if the |frame_header| and
-  // |thread_count| doesn't change between calls (it will only create new
-  // threads on the first call and do nothing on the subsequent calls). This
-  // function also starts the worker threads whenever it creates new thread
-  // pools.
+  // |thread_count|. This function is used only in non frame-parallel mode. This
+  // function is idempotent if the |frame_header| and |thread_count| don't
+  // change between calls (it will only create new threads on the first call and
+  // do nothing on the subsequent calls). This function also starts the worker
+  // threads whenever it creates new thread pools.
   // The following strategy is used to allocate threads:
   //   * One thread is allocated for decoding each Tile.
   //   * Any remaining threads are allocated for superblock row multi-threading
   //     within each of the tile in a round robin fashion.
+  // Note: During the lifetime of a ThreadingStrategy object, only one of the
+  // Reset() variants will be used.
   LIBGAV1_MUST_USE_RESULT bool Reset(const ObuFrameHeader& frame_header,
                                      int thread_count);
 
+  // Creates or re-allocates a thread pool with |thread_count| threads. This
+  // function is used only in frame parallel mode. This function is idempotent
+  // if the |thread_count| doesn't change between calls (it will only create new
+  // threads on the first call and do nothing on the subsequent calls).
+  // Note: During the lifetime of a ThreadingStrategy object, only one of the
+  // Reset() variants will be used.
+  LIBGAV1_MUST_USE_RESULT bool Reset(int thread_count);
+
   // Returns a pointer to the ThreadPool that is to be used for Tile
   // multi-threading.
   ThreadPool* tile_thread_pool() const {
@@ -56,8 +68,14 @@
 
   int tile_thread_count() const { return tile_thread_count_; }
 
+  // Returns a pointer to the underlying ThreadPool.
+  // Note: Valid only when |frame_parallel_| is true. This is used for
+  // facilitating in-frame multi-threading in that case.
+  ThreadPool* thread_pool() const { return thread_pool_.get(); }
+
   // Returns a pointer to the ThreadPool that is to be used within the Tile at
   // index |tile_index| for superblock row multi-threading.
+  // Note: Valid only when |frame_parallel_| is false.
   ThreadPool* row_thread_pool(int tile_index) const {
     return tile_index < max_tile_index_for_row_threads_ ? thread_pool_.get()
                                                         : nullptr;
@@ -65,14 +83,49 @@
 
   // Returns a pointer to the ThreadPool that is to be used for post filter
   // multi-threading.
-  ThreadPool* post_filter_thread_pool() const { return thread_pool_.get(); }
+  // Note: Valid only when |frame_parallel_| is false.
+  ThreadPool* post_filter_thread_pool() const {
+    return frame_parallel_ ? nullptr : thread_pool_.get();
+  }
+
+  // Returns a pointer to the ThreadPool that is to be used for film grain
+  // synthesis and blending.
+  // Note: Valid only when |frame_parallel_| is false.
+  ThreadPool* film_grain_thread_pool() const { return thread_pool_.get(); }
 
  private:
   std::unique_ptr<ThreadPool> thread_pool_;
-  int tile_thread_count_;
-  int max_tile_index_for_row_threads_;
+  int tile_thread_count_ = 0;
+  int max_tile_index_for_row_threads_ = 0;
+  bool frame_parallel_ = false;
 };
 
+// Initializes the |frame_thread_pool| and the necessary worker threadpools (the
+// threading_strategy objects in each of the frame scratch buffer in
+// |frame_scratch_buffer_pool|) as follows:
+//  * frame_threads = ComputeFrameThreadCount();
+//  * For more details on how frame_threads is computed, see the function
+//    comment in ComputeFrameThreadCount().
+//  * |frame_thread_pool| is created with |frame_threads| threads.
+//  * divide the remaining number of threads into each frame thread and
+//    initialize a frame_scratch_buffer.threading_strategy for each frame
+//    thread.
+//  When this function is called, |frame_scratch_buffer_pool| must be empty. If
+//  this function returns true, it means the initialization was successful and
+//  one of the following is true:
+//    * |frame_thread_pool| has been successfully initialized and
+//      |frame_scratch_buffer_pool| has been successfully populated with
+//      |frame_threads| buffers to be used by each frame thread. The total
+//      number of threads that this function creates will always be equal to
+//      |thread_count|.
+//    * |frame_thread_pool| is nullptr. |frame_scratch_buffer_pool| is not
+//      modified. This means that frame threading will not be used and the
+//      decoder will continue to operate normally in non frame parallel mode.
+LIBGAV1_MUST_USE_RESULT bool InitializeThreadPoolsForFrameParallel(
+    int thread_count, int tile_count, int tile_columns,
+    std::unique_ptr<ThreadPool>* frame_thread_pool,
+    FrameScratchBufferPool* frame_scratch_buffer_pool);
+
 }  // namespace libgav1
 
 #endif  // LIBGAV1_SRC_THREADING_STRATEGY_H_
diff --git a/libgav1/src/tile.h b/libgav1/src/tile.h
index 2533452..065ef70 100644
--- a/libgav1/src/tile.h
+++ b/libgav1/src/tile.h
@@ -28,17 +28,18 @@
 #include <vector>
 
 #include "src/buffer_pool.h"
-#include "src/decoder_scratch_buffer.h"
+#include "src/decoder_state.h"
 #include "src/dsp/common.h"
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
-#include "src/loop_filter_mask.h"
+#include "src/frame_scratch_buffer.h"
 #include "src/loop_restoration_info.h"
 #include "src/obu_parser.h"
 #include "src/post_filter.h"
 #include "src/quantizer.h"
 #include "src/residual_buffer_pool.h"
 #include "src/symbol_decoder_context.h"
+#include "src/tile_scratch_buffer.h"
 #include "src/utils/array_2d.h"
 #include "src/utils/block_parameters_holder.h"
 #include "src/utils/blocking_counter.h"
@@ -67,24 +68,24 @@
 
 class Tile : public Allocable {
  public:
-  Tile(int tile_number, const uint8_t* data, size_t size,
-       const ObuSequenceHeader& sequence_header,
-       const ObuFrameHeader& frame_header, RefCountedBuffer* current_frame,
-       const std::array<bool, kNumReferenceFrameTypes>&
-           reference_frame_sign_bias,
-       const std::array<RefCountedBufferPtr, kNumReferenceFrameTypes>&
-           reference_frames,
-       Array2D<TemporalMotionVector>* motion_field_mv,
-       const std::array<uint8_t, kNumReferenceFrameTypes>& reference_order_hint,
-       const std::array<uint8_t, kWedgeMaskSize>& wedge_masks,
-       const SymbolDecoderContext& symbol_decoder_context,
-       SymbolDecoderContext* saved_symbol_decoder_context,
-       const SegmentationMap* prev_segment_ids, PostFilter* post_filter,
-       BlockParametersHolder* block_parameters, Array2D<int16_t>* cdef_index,
-       Array2D<TransformSize>* inter_transform_sizes, const dsp::Dsp* dsp,
-       ThreadPool* thread_pool, ResidualBufferPool* residual_buffer_pool,
-       DecoderScratchBufferPool* decoder_scratch_buffer_pool,
-       BlockingCounterWithStatus* pending_tiles);
+  static std::unique_ptr<Tile> Create(
+      int tile_number, const uint8_t* const data, size_t size,
+      const ObuSequenceHeader& sequence_header,
+      const ObuFrameHeader& frame_header, RefCountedBuffer* const current_frame,
+      const DecoderState& state, FrameScratchBuffer* const frame_scratch_buffer,
+      const WedgeMaskArray& wedge_masks,
+      SymbolDecoderContext* const saved_symbol_decoder_context,
+      const SegmentationMap* prev_segment_ids, PostFilter* const post_filter,
+      const dsp::Dsp* const dsp, ThreadPool* const thread_pool,
+      BlockingCounterWithStatus* const pending_tiles, bool frame_parallel,
+      bool use_intra_prediction_buffer) {
+    std::unique_ptr<Tile> tile(new (std::nothrow) Tile(
+        tile_number, data, size, sequence_header, frame_header, current_frame,
+        state, frame_scratch_buffer, wedge_masks, saved_symbol_decoder_context,
+        prev_segment_ids, post_filter, dsp, thread_pool, pending_tiles,
+        frame_parallel, use_intra_prediction_buffer));
+    return (tile != nullptr && tile->Init()) ? std::move(tile) : nullptr;
+  }
 
   // Move only.
   Tile(Tile&& tile) noexcept;
@@ -94,11 +95,80 @@
 
   struct Block;  // Defined after this class.
 
-  bool Decode(bool is_main_thread);  // 5.11.2.
+  // Parses the entire tile.
+  bool Parse();
+  // Decodes the entire tile. |superblock_row_progress| and
+  // |superblock_row_progress_condvar| are arrays of size equal to the number of
+  // superblock rows in the frame. Increments |superblock_row_progress[i]| after
+  // each superblock row at index |i| is decoded. If the count reaches the
+  // number of tile columns, then it notifies
+  // |superblock_row_progress_condvar[i]|.
+  bool Decode(std::mutex* mutex, int* superblock_row_progress,
+              std::condition_variable* superblock_row_progress_condvar);
+  // Parses and decodes the entire tile. Depending on the configuration of this
+  // Tile, this function may do multithreaded decoding.
+  bool ParseAndDecode();  // 5.11.2.
+  // Processes all the columns of the superblock row at |row4x4| that are within
+  // this Tile. If |save_symbol_decoder_context| is true, then
+  // SaveSymbolDecoderContext() is invoked for the last superblock row.
+  template <ProcessingMode processing_mode, bool save_symbol_decoder_context>
+  bool ProcessSuperBlockRow(int row4x4, TileScratchBuffer* scratch_buffer);
+
   const ObuSequenceHeader& sequence_header() const { return sequence_header_; }
   const ObuFrameHeader& frame_header() const { return frame_header_; }
   const RefCountedBuffer& current_frame() const { return current_frame_; }
-  bool IsInside(int row4x4, int column4x4) const;  // 5.11.51.
+  const TemporalMotionField& motion_field() const { return motion_field_; }
+  const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias()
+      const {
+    return reference_frame_sign_bias_;
+  }
+
+  bool IsRow4x4Inside(int row4x4) const {
+    return row4x4 >= row4x4_start_ && row4x4 < row4x4_end_;
+  }
+
+  // 5.11.51.
+  bool IsInside(int row4x4, int column4x4) const {
+    return IsRow4x4Inside(row4x4) && column4x4 >= column4x4_start_ &&
+           column4x4 < column4x4_end_;
+  }
+
+  bool IsLeftInside(int column4x4) const {
+    // We use "larger than" as the condition. Don't pass in the left column
+    // offset column4x4 - 1.
+    assert(column4x4 <= column4x4_end_);
+    return column4x4 > column4x4_start_;
+  }
+
+  bool IsTopInside(int row4x4) const {
+    // We use "larger than" as the condition. Don't pass in the top row offset
+    // row4x4 - 1.
+    assert(row4x4 <= row4x4_end_);
+    return row4x4 > row4x4_start_;
+  }
+
+  bool IsTopLeftInside(int row4x4, int column4x4) const {
+    // We use "larger than" as the condition. Don't pass in the top row offset
+    // row4x4 - 1 or the left column offset column4x4 - 1.
+    assert(row4x4 <= row4x4_end_);
+    assert(column4x4 <= column4x4_end_);
+    return row4x4 > row4x4_start_ && column4x4 > column4x4_start_;
+  }
+
+  bool IsBottomRightInside(int row4x4, int column4x4) const {
+    assert(row4x4 >= row4x4_start_);
+    assert(column4x4 >= column4x4_start_);
+    return row4x4 < row4x4_end_ && column4x4 < column4x4_end_;
+  }
+
+  BlockParameters** BlockParametersAddress(int row4x4, int column4x4) const {
+    return block_parameters_holder_.Address(row4x4, column4x4);
+  }
+
+  int BlockParametersStride() const {
+    return block_parameters_holder_.columns4x4();
+  }
+
   // Returns true if Parameters() can be called with |row| and |column| as
   // inputs, false otherwise.
   bool HasParameters(int row, int column) const {
@@ -107,11 +177,26 @@
   const BlockParameters& Parameters(int row, int column) const {
     return *block_parameters_holder_.Find(row, column);
   }
+
   int number() const { return number_; }
   int superblock_rows() const { return superblock_rows_; }
   int superblock_columns() const { return superblock_columns_; }
+  int row4x4_start() const { return row4x4_start_; }
+  int column4x4_start() const { return column4x4_start_; }
+  int column4x4_end() const { return column4x4_end_; }
 
  private:
+  Tile(int tile_number, const uint8_t* data, size_t size,
+       const ObuSequenceHeader& sequence_header,
+       const ObuFrameHeader& frame_header, RefCountedBuffer* current_frame,
+       const DecoderState& state, FrameScratchBuffer* frame_scratch_buffer,
+       const WedgeMaskArray& wedge_masks,
+       SymbolDecoderContext* saved_symbol_decoder_context,
+       const SegmentationMap* prev_segment_ids, PostFilter* post_filter,
+       const dsp::Dsp* dsp, ThreadPool* thread_pool,
+       BlockingCounterWithStatus* pending_tiles, bool frame_parallel,
+       bool use_intra_prediction_buffer);
+
   // Stores the transform tree state when reading variable size transform trees
   // and when applying the transform tree. When applying the transform tree,
   // |depth| is not used.
@@ -163,13 +248,18 @@
   //    every transform block.
   using ResidualPtr = uint8_t*;
 
-  // Performs member initializations that may fail. Called by Decode().
+  // Performs member initializations that may fail. Helper function used by
+  // Create().
   LIBGAV1_MUST_USE_RESULT bool Init();
 
+  // Saves the symbol decoder context of this tile into
+  // |saved_symbol_decoder_context_| if necessary.
+  void SaveSymbolDecoderContext();
+
   // Entry point for multi-threaded decoding. This function performs the same
-  // functionality as Decode(). The current thread does the "parse" step while
-  // the worker threads do the "decode" step.
-  bool ThreadedDecode();
+  // functionality as ParseAndDecode(). The current thread does the "parse" step
+  // while the worker threads do the "decode" step.
+  bool ThreadedParseAndDecode();
 
   // Returns whether or not the prerequisites for decoding the superblock at
   // |row_index| and |column_index| are satisfied. |threading_.mutex| must be
@@ -186,6 +276,12 @@
   // the right of this superblock (if it is allowed).
   void DecodeSuperBlock(int row_index, int column_index, int block_width4x4);
 
+  // If |use_intra_prediction_buffer_| is true, then this function copies the
+  // last row of the superblockrow starting at |row4x4| into the
+  // |intra_prediction_buffer_| (which may be used by the intra prediction
+  // process for the next superblock row).
+  void PopulateIntraPredictionBuffer(int row4x4);
+
   uint16_t* GetPartitionCdf(int row4x4, int column4x4, BlockSize block_size);
   bool ReadPartition(int row4x4, int column4x4, BlockSize block_size,
                      bool has_rows, bool has_columns, Partition* partition);
@@ -194,39 +290,30 @@
   // the blocks in the right order.
   bool ProcessPartition(
       int row4x4_start, int column4x4_start, ParameterTree* root,
-      DecoderScratchBuffer* scratch_buffer,
+      TileScratchBuffer* scratch_buffer,
       ResidualPtr* residual);  // Iterative implementation of 5.11.4.
   bool ProcessBlock(int row4x4, int column4x4, BlockSize block_size,
-                    ParameterTree* tree, DecoderScratchBuffer* scratch_buffer,
+                    ParameterTree* tree, TileScratchBuffer* scratch_buffer,
                     ResidualPtr* residual);   // 5.11.5.
   void ResetCdef(int row4x4, int column4x4);  // 5.11.55.
 
   // This function is used to decode a superblock when the parsing has already
   // been done for that superblock.
-  bool DecodeSuperBlock(ParameterTree* tree,
-                        DecoderScratchBuffer* scratch_buffer,
+  bool DecodeSuperBlock(ParameterTree* tree, TileScratchBuffer* scratch_buffer,
                         ResidualPtr* residual);
   // Helper function used by DecodeSuperBlock(). Note that the decode_block()
   // function in the spec is equivalent to ProcessBlock() in the code.
-  bool DecodeBlock(ParameterTree* tree, DecoderScratchBuffer* scratch_buffer,
+  bool DecodeBlock(ParameterTree* tree, TileScratchBuffer* scratch_buffer,
                    ResidualPtr* residual);
 
-  void ClearBlockDecoded(DecoderScratchBuffer* scratch_buffer, int row4x4,
+  void ClearBlockDecoded(TileScratchBuffer* scratch_buffer, int row4x4,
                          int column4x4);  // 5.11.3.
   bool ProcessSuperBlock(int row4x4, int column4x4, int block_width4x4,
-                         DecoderScratchBuffer* scratch_buffer,
+                         TileScratchBuffer* scratch_buffer,
                          ProcessingMode mode);
   void ResetLoopRestorationParams();
   void ReadLoopRestorationCoefficients(int row4x4, int column4x4,
                                        BlockSize block_size);  // 5.11.57.
-  // Build bit masks for vertical edges followed by horizontal edges.
-  // Traverse through each transform edge in the current coding block, and
-  // determine if a 4x4 edge needs filtering. If filtering is needed, determine
-  // filter length. Set corresponding bit mask to 1.
-  void BuildBitMask(int row4x4, int column4x4, BlockSize block_size);
-  void BuildBitMaskHelper(int row4x4, int column4x4, BlockSize block_size,
-                          bool is_vertical_block_border,
-                          bool is_horizontal_block_border);
 
   // Helper functions for DecodeBlock.
   bool ReadSegmentId(const Block& block);       // 5.11.9.
@@ -261,19 +348,7 @@
   void ReadIsInter(const Block& block);                        // 5.11.20.
   bool ReadIntraBlockModeInfo(const Block& block,
                               bool intra_y_mode);  // 5.11.22.
-  int GetUseCompoundReferenceContext(const Block& block);
   CompoundReferenceType ReadCompoundReferenceType(const Block& block);
-  // Calculates count0 by calling block.CountReferences() on the frame types
-  // from type0_start to type0_end, inclusive, and summing the results.
-  // Calculates count1 by calling block.CountReferences() on the frame types
-  // from type1_start to type1_end, inclusive, and summing the results.
-  // Compares count0 with count1 and returns 0, 1 or 2.
-  //
-  // See count_refs and ref_count_ctx in 8.3.2.
-  int GetReferenceContext(const Block& block, ReferenceFrameType type0_start,
-                          ReferenceFrameType type0_end,
-                          ReferenceFrameType type1_start,
-                          ReferenceFrameType type1_end) const;
   template <bool is_single, bool is_backward, int index>
   uint16_t* GetReferenceCdf(const Block& block, CompoundReferenceType type =
                                                     kNumCompoundReferenceTypes);
@@ -293,7 +368,8 @@
   bool DecodeInterModeInfo(const Block& block);                // 5.11.18.
   bool DecodeModeInfo(const Block& block);                     // 5.11.6.
   bool IsMvValid(const Block& block, bool is_compound) const;  // 6.10.25.
-  bool AssignMv(const Block& block, bool is_compound);         // 5.11.26.
+  bool AssignInterMv(const Block& block, bool is_compound);    // 5.11.26.
+  bool AssignIntraMv(const Block& block);                      // 5.11.26.
   int GetTopTransformWidth(const Block& block, int row4x4, int column4x4,
                            bool ignore_skip);
   int GetLeftTransformHeight(const Block& block, int row4x4, int column4x4,
@@ -303,7 +379,7 @@
   void ReadVariableTransformTree(const Block& block, int row4x4, int column4x4,
                                  TransformSize tx_size);
   void DecodeTransformSize(const Block& block);  // 5.11.16.
-  void ComputePrediction(const Block& block);    // 5.11.33.
+  bool ComputePrediction(const Block& block);    // 5.11.33.
   // |x4| and |y4| are the column and row positions of the 4x4 block. |w4| and
   // |h4| are the width and height in 4x4 units of |tx_size|.
   int GetTransformAllZeroContext(const Block& block, Plane plane,
@@ -316,37 +392,39 @@
                                      int block_y);  // 5.11.40.
   void ReadTransformType(const Block& block, int x4, int y4,
                          TransformSize tx_size);  // 5.11.47.
-  int GetCoeffBaseContextEob(TransformSize tx_size, int index);
-  int GetCoeffBaseContext2D(const int32_t* quantized_buffer,
-                            TransformSize tx_size, int adjusted_tx_width_log2,
-                            uint16_t pos);
-  int GetCoeffBaseContextHorizontal(const int32_t* quantized_buffer,
-                                    TransformSize tx_size,
-                                    int adjusted_tx_width_log2, uint16_t pos);
-  int GetCoeffBaseContextVertical(const int32_t* quantized_buffer,
-                                  TransformSize tx_size,
-                                  int adjusted_tx_width_log2, uint16_t pos);
-  int GetCoeffBaseRangeContext2D(const int32_t* quantized_buffer,
-                                 int adjusted_tx_width_log2, int pos);
-  int GetCoeffBaseRangeContextHorizontal(const int32_t* quantized_buffer,
-                                         int adjusted_tx_width_log2, int pos);
-  int GetCoeffBaseRangeContextVertical(const int32_t* quantized_buffer,
-                                       int adjusted_tx_width_log2, int pos);
+  template <typename ResidualType>
+  void ReadCoeffBase2D(
+      const uint16_t* scan, PlaneType plane_type, TransformSize tx_size,
+      int clamped_tx_size_context, int adjusted_tx_width_log2, int eob,
+      uint16_t coeff_base_cdf[kCoeffBaseContexts][kCoeffBaseSymbolCount + 1],
+      ResidualType* quantized_buffer);
+  template <typename ResidualType>
+  void ReadCoeffBaseHorizontal(
+      const uint16_t* scan, PlaneType plane_type, TransformSize tx_size,
+      int clamped_tx_size_context, int adjusted_tx_width_log2, int eob,
+      uint16_t coeff_base_cdf[kCoeffBaseContexts][kCoeffBaseSymbolCount + 1],
+      ResidualType* quantized_buffer);
+  template <typename ResidualType>
+  void ReadCoeffBaseVertical(
+      const uint16_t* scan, PlaneType plane_type, TransformSize tx_size,
+      int clamped_tx_size_context, int adjusted_tx_width_log2, int eob,
+      uint16_t coeff_base_cdf[kCoeffBaseContexts][kCoeffBaseSymbolCount + 1],
+      ResidualType* quantized_buffer);
   int GetDcSignContext(int x4, int y4, int w4, int h4, Plane plane);
   void SetEntropyContexts(int x4, int y4, int w4, int h4, Plane plane,
                           uint8_t coefficient_level, int8_t dc_category);
   void InterIntraPrediction(
-      uint16_t* prediction[2], ptrdiff_t prediction_stride,
-      const uint8_t* prediction_mask, ptrdiff_t prediction_mask_stride,
+      uint16_t* prediction_0, const uint8_t* prediction_mask,
+      ptrdiff_t prediction_mask_stride,
       const PredictionParameters& prediction_parameters, int prediction_width,
       int prediction_height, int subsampling_x, int subsampling_y,
       uint8_t* dest,
       ptrdiff_t dest_stride);  // Part of section 7.11.3.1 in the spec.
   void CompoundInterPrediction(
-      const Block& block, ptrdiff_t prediction_stride,
+      const Block& block, const uint8_t* prediction_mask,
       ptrdiff_t prediction_mask_stride, int prediction_width,
-      int prediction_height, Plane plane, int subsampling_x, int subsampling_y,
-      int bitdepth, int candidate_row, int candidate_column, uint8_t* dest,
+      int prediction_height, int subsampling_x, int subsampling_y,
+      int candidate_row, int candidate_column, uint8_t* dest,
       ptrdiff_t dest_stride);  // Part of section 7.11.3.1 in the spec.
   GlobalMotion* GetWarpParams(const Block& block, Plane plane,
                               int prediction_width, int prediction_height,
@@ -356,7 +434,7 @@
                               GlobalMotion* global_motion_params,
                               GlobalMotion* local_warp_params)
       const;  // Part of section 7.11.3.1 in the spec.
-  void InterPrediction(const Block& block, Plane plane, int x, int y,
+  bool InterPrediction(const Block& block, Plane plane, int x, int y,
                        int prediction_width, int prediction_height,
                        int candidate_row, int candidate_column,
                        bool* is_local_valid,
@@ -364,48 +442,46 @@
   void ScaleMotionVector(const MotionVector& mv, Plane plane,
                          int reference_frame_index, int x, int y, int* start_x,
                          int* start_y, int* step_x, int* step_y);  // 7.11.3.3.
-  bool GetReferenceBlockPosition(int reference_frame_index, bool is_scaled,
-                                 int width, int height, int ref_start_x,
-                                 int ref_last_x, int ref_start_y,
-                                 int ref_last_y, int start_x, int start_y,
-                                 int step_x, int step_y, int left_border,
-                                 int right_border, int top_border,
-                                 int bottom_border, int* ref_block_start_x,
-                                 int* ref_block_start_y, int* ref_block_end_x,
-                                 int* ref_block_end_y);
+  // If the method returns false, the caller only uses the output parameters
+  // *ref_block_start_x and *ref_block_start_y. If the method returns true, the
+  // caller uses all three output parameters.
+  static bool GetReferenceBlockPosition(
+      int reference_frame_index, bool is_scaled, int width, int height,
+      int ref_start_x, int ref_last_x, int ref_start_y, int ref_last_y,
+      int start_x, int start_y, int step_x, int step_y, int left_border,
+      int right_border, int top_border, int bottom_border,
+      int* ref_block_start_x, int* ref_block_start_y, int* ref_block_end_x);
+
   template <typename Pixel>
   void BuildConvolveBlock(Plane plane, int reference_frame_index,
                           bool is_scaled, int height, int ref_start_x,
                           int ref_last_x, int ref_start_y, int ref_last_y,
                           int step_y, int ref_block_start_x,
                           int ref_block_end_x, int ref_block_start_y,
-                          uint8_t* block_buffer, ptrdiff_t block_stride);
-  void BlockInterPrediction(const Block& block, Plane plane,
+                          uint8_t* block_buffer,
+                          ptrdiff_t convolve_buffer_stride,
+                          ptrdiff_t block_extended_width);
+  bool BlockInterPrediction(const Block& block, Plane plane,
                             int reference_frame_index, const MotionVector& mv,
                             int x, int y, int width, int height,
                             int candidate_row, int candidate_column,
-                            uint16_t* prediction, ptrdiff_t prediction_stride,
-                            int round_bits, bool is_compound,
+                            uint16_t* prediction, bool is_compound,
                             bool is_inter_intra, uint8_t* dest,
                             ptrdiff_t dest_stride);  // 7.11.3.4.
-  void BlockWarpProcess(const Block& block, Plane plane, int index,
+  bool BlockWarpProcess(const Block& block, Plane plane, int index,
                         int block_start_x, int block_start_y, int width,
-                        int height, ptrdiff_t prediction_stride,
-                        GlobalMotion* warp_params, int round_bits,
-                        bool is_compound, bool is_inter_intra, uint8_t* dest,
+                        int height, GlobalMotion* warp_params, bool is_compound,
+                        bool is_inter_intra, uint8_t* dest,
                         ptrdiff_t dest_stride);  // 7.11.3.5.
-  void ObmcBlockPrediction(const Block& block, const MotionVector& mv,
+  bool ObmcBlockPrediction(const Block& block, const MotionVector& mv,
                            Plane plane, int reference_frame_index, int width,
                            int height, int x, int y, int candidate_row,
                            int candidate_column,
-                           ObmcDirection blending_direction, int round_bits);
-  void ObmcPrediction(const Block& block, Plane plane, int width, int height,
-                      int round_bits);  // 7.11.3.9.
-  void DistanceWeightedPrediction(uint16_t* prediction_0,
-                                  ptrdiff_t prediction_stride_0,
-                                  uint16_t* prediction_1,
-                                  ptrdiff_t prediction_stride_1, int width,
-                                  int height, int candidate_row,
+                           ObmcDirection blending_direction);
+  bool ObmcPrediction(const Block& block, Plane plane, int width,
+                      int height);  // 7.11.3.9.
+  void DistanceWeightedPrediction(void* prediction_0, void* prediction_1,
+                                  int width, int height, int candidate_row,
                                   int candidate_column, uint8_t* dest,
                                   ptrdiff_t dest_stride);  // 7.11.3.15.
   // This function specializes the parsing of DC coefficient by removing some of
@@ -414,22 +490,21 @@
   // parameter that is populated when |is_dc_coefficient| is true.
   // |coefficient_level| is an output parameter which accumulates the
   // coefficient level.
-  template <bool is_dc_coefficient>
-  bool ReadSignAndApplyDequantization(
-      const Block& block, int32_t* quantized_buffer, const uint16_t* scan,
-      int i, int adjusted_tx_width_log2, int tx_width, int q_value,
-      const uint8_t* quantizer_matrix, int shift, int min_value, int max_value,
-      uint16_t* dc_sign_cdf, int8_t* dc_category,
-      int* coefficient_level);  // Part of 5.11.39.
+  template <typename ResidualType, bool is_dc_coefficient>
+  LIBGAV1_ALWAYS_INLINE bool ReadSignAndApplyDequantization(
+      const uint16_t* scan, int i, int q_value, const uint8_t* quantizer_matrix,
+      int shift, int max_value, uint16_t* dc_sign_cdf, int8_t* dc_category,
+      int* coefficient_level,
+      ResidualType* residual_buffer);  // Part of 5.11.39.
   int ReadCoeffBaseRange(int clamped_tx_size_context, int cdf_context,
                          int plane_type);  // Part of 5.11.39.
   // Returns the number of non-zero coefficients that were read. |tx_type| is an
   // output parameter that stores the computed transform type for the plane
   // whose coefficients were read. Returns -1 on failure.
-  int16_t ReadTransformCoefficients(const Block& block, Plane plane,
-                                    int start_x, int start_y,
-                                    TransformSize tx_size,
-                                    TransformType* tx_type);  // 5.11.39.
+  template <typename ResidualType>
+  int ReadTransformCoefficients(const Block& block, Plane plane, int start_x,
+                                int start_y, TransformSize tx_size,
+                                TransformType* tx_type);  // 5.11.39.
   bool TransformBlock(const Block& block, Plane plane, int base_x, int base_y,
                       TransformSize tx_size, int x, int y,
                       ProcessingMode mode);  // 5.11.35.
@@ -439,7 +514,7 @@
   void ReconstructBlock(const Block& block, Plane plane, int start_x,
                         int start_y, TransformSize tx_size,
                         TransformType tx_type,
-                        int16_t non_zero_coeff_count);     // Part of 7.12.3.
+                        int non_zero_coeff_count);         // Part of 7.12.3.
   bool Residual(const Block& block, ProcessingMode mode);  // 5.11.34.
   // part of 5.11.5 (reset_block_context() in the spec).
   void ResetEntropyContext(const Block& block);
@@ -460,8 +535,9 @@
                              Plane plane) const;  // 7.11.2.8.
   template <typename Pixel>
   void DirectionalPrediction(const Block& block, Plane plane, int x, int y,
-                             bool has_left, bool has_top, int prediction_angle,
-                             int width, int height, int max_x, int max_y,
+                             bool has_left, bool has_top, bool needs_left,
+                             bool needs_top, int prediction_angle, int width,
+                             int height, int max_x, int max_y,
                              TransformSize tx_size, Pixel* top_row,
                              Pixel* left_column);  // 7.11.2.4.
   template <typename Pixel>
@@ -500,8 +576,8 @@
   }
 
   const int number_;
-  int row_;
-  int column_;
+  const int row_;
+  const int column_;
   const uint8_t* const data_;
   size_t size_;
   int row4x4_start_;
@@ -513,6 +589,8 @@
   bool read_deltas_;
   const int8_t subsampling_x_[kMaxPlanes];
   const int8_t subsampling_y_[kMaxPlanes];
+  int deblock_row_limit_[kMaxPlanes];
+  int deblock_column_limit_[kMaxPlanes];
 
   // The dimensions (in order) are: segment_id, level_index (based on plane and
   // direction), reference_frame and mode_id.
@@ -553,13 +631,12 @@
   std::array<Array2D<int8_t>, 2> dc_categories_;
   const ObuSequenceHeader& sequence_header_;
   const ObuFrameHeader& frame_header_;
-  RefCountedBuffer& current_frame_;
   const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias_;
   const std::array<RefCountedBufferPtr, kNumReferenceFrameTypes>&
       reference_frames_;
-  Array2D<TemporalMotionVector>* const motion_field_mv_;
+  TemporalMotionField& motion_field_;
   const std::array<uint8_t, kNumReferenceFrameTypes>& reference_order_hint_;
-  const std::array<uint8_t, kWedgeMaskSize>& wedge_masks_;
+  const WedgeMaskArray& wedge_masks_;
   DaalaBitReader reader_;
   SymbolDecoderContext symbol_decoder_context_;
   SymbolDecoderContext* const saved_symbol_decoder_context_;
@@ -584,9 +661,11 @@
   //   2) In Reconstruct(), this buffer is used as the input to the row
   //   transform process.
   // The size of this buffer would be:
-  //    For |residual_buffer_|: 4096 * |residual_size_|. Where 4096 =
-  //        64x64 which is the maximum transform size. This memory is allocated
-  //        and owned by the Tile class.
+  //    For |residual_buffer_|: (4096 + 32 * |kResidualPaddingVertical|) *
+  //        |residual_size_|. Where 4096 = 64x64 which is the maximum transform
+  //        size, and 32 * |kResidualPaddingVertical| is the padding to avoid
+  //        bottom boundary checks when parsing quantized coefficients. This
+  //        memory is allocated and owned by the Tile class.
   //    For |residual_buffer_threaded_|: See the comment below. This memory is
   //        not allocated or owned by the Tile class.
   AlignedUniquePtr<uint8_t> residual_buffer_;
@@ -605,7 +684,22 @@
   // use_128x128_superblock ? 3 : 5. This is the allowed range of reference for
   // the top rows for intrabc.
   const int intra_block_copy_lag_;
+
+  // In the Tile class, we use the "current_frame" in two ways:
+  //   1) To write the decoded output into (using the |buffer_| view).
+  //   2) To read the pixels for intra block copy (using the |current_frame_|
+  //      reference).
+  //
+  // When intra block copy is off, |buffer_| and |current_frame_| may or may not
+  // point to the same plane pointers. But it is okay since |current_frame_| is
+  // never used in this case.
+  //
+  // When intra block copy is on, |buffer_| and |current_frame_| always point to
+  // the same plane pointers (since post filtering is disabled). So the usage in
+  // both case 1 and case 2 remain valid.
   Array2DView<uint8_t> buffer_[kMaxPlanes];
+  RefCountedBuffer& current_frame_;
+
   Array2D<int16_t>& cdef_index_;
   Array2D<TransformSize>& inter_transform_sizes_;
   std::array<RestorationUnitInfo, kMaxPlanes> reference_unit_info_;
@@ -616,7 +710,7 @@
   ThreadPool* const thread_pool_;
   ThreadingParameters threading_;
   ResidualBufferPool* const residual_buffer_pool_;
-  DecoderScratchBufferPool* const decoder_scratch_buffer_pool_;
+  TileScratchBufferPool* const tile_scratch_buffer_pool_;
   BlockingCounterWithStatus* const pending_tiles_;
   bool split_parse_and_decode_;
   // This is used only when |split_parse_and_decode_| is false.
@@ -629,86 +723,99 @@
   int8_t delta_lf_[kFrameLfCount];
   // True if all the values in |delta_lf_| are zero. False otherwise.
   bool delta_lf_all_zero_;
-  bool build_bit_mask_when_parsing_;
+  const bool frame_parallel_;
+  const bool use_intra_prediction_buffer_;
+  // Buffer used to store the unfiltered pixels that are necessary for decoding
+  // the next superblock row (for the intra prediction process). Used only if
+  // |use_intra_prediction_buffer_| is true. The |frame_scratch_buffer| contains
+  // one row buffer for each tile row. This tile will have to use the buffer
+  // corresponding to this tile's row.
+  IntraPredictionBuffer* const intra_prediction_buffer_;
+  // Stores the progress of the reference frames. This will be used to avoid
+  // unnecessary calls into RefCountedBuffer::WaitUntil().
+  std::array<int, kNumReferenceFrameTypes> reference_frame_progress_cache_;
 };
 
 struct Tile::Block {
-  Block(const Tile& tile, int row4x4, int column4x4, BlockSize size,
-        DecoderScratchBuffer* const scratch_buffer, ResidualPtr* residual,
-        BlockParameters* const parameters)
+  Block(const Tile& tile, BlockSize size, int row4x4, int column4x4,
+        TileScratchBuffer* const scratch_buffer, ResidualPtr* residual)
       : tile(tile),
+        size(size),
         row4x4(row4x4),
         column4x4(column4x4),
-        size(size),
-        left_available(tile.IsInside(row4x4, column4x4 - 1)),
-        top_available(tile.IsInside(row4x4 - 1, column4x4)),
-        residual_size{kPlaneResidualSize[size][0][0],
-                      kPlaneResidualSize[size][tile.subsampling_x_[kPlaneU]]
-                                        [tile.subsampling_y_[kPlaneU]]},
-        bp_top(top_available
-                   ? tile.block_parameters_holder_.Find(row4x4 - 1, column4x4)
-                   : nullptr),
-        bp_left(left_available
-                    ? tile.block_parameters_holder_.Find(row4x4, column4x4 - 1)
-                    : nullptr),
-        bp(parameters),
+        width(kBlockWidthPixels[size]),
+        height(kBlockHeightPixels[size]),
+        width4x4(width >> 2),
+        height4x4(height >> 2),
         scratch_buffer(scratch_buffer),
         residual(residual) {
     assert(size != kBlockInvalid);
-    assert(residual_size[kPlaneTypeY] != kBlockInvalid);
+    residual_size[kPlaneY] = kPlaneResidualSize[size][0][0];
+    residual_size[kPlaneU] = residual_size[kPlaneV] =
+        kPlaneResidualSize[size][tile.subsampling_x_[kPlaneU]]
+                          [tile.subsampling_y_[kPlaneU]];
+    assert(residual_size[kPlaneY] != kBlockInvalid);
     if (tile.PlaneCount() > 1) {
-      assert(residual_size[kPlaneTypeUV] != kBlockInvalid);
+      assert(residual_size[kPlaneU] != kBlockInvalid);
+    }
+    if ((row4x4 & 1) == 0 &&
+        (tile.sequence_header_.color_config.subsampling_y & height4x4) == 1) {
+      has_chroma = false;
+    } else if ((column4x4 & 1) == 0 &&
+               (tile.sequence_header_.color_config.subsampling_x & width4x4) ==
+                   1) {
+      has_chroma = false;
+    } else {
+      has_chroma = !tile.sequence_header_.color_config.is_monochrome;
+    }
+    top_available[kPlaneY] = tile.IsTopInside(row4x4);
+    left_available[kPlaneY] = tile.IsLeftInside(column4x4);
+    if (has_chroma) {
+      // top_available[kPlaneU] and top_available[kPlaneV] are valid only if
+      // has_chroma is true.
+      // The next 3 lines are equivalent to:
+      // top_available[kPlaneU] = top_available[kPlaneV] =
+      //     top_available[kPlaneY] &&
+      //     ((tile.sequence_header_.color_config.subsampling_y & height4x4) ==
+      //     0 || tile.IsTopInside(row4x4 - 1));
+      top_available[kPlaneU] = top_available[kPlaneV] = tile.IsTopInside(
+          row4x4 -
+          (tile.sequence_header_.color_config.subsampling_y & height4x4));
+      // left_available[kPlaneU] and left_available[kPlaneV] are valid only if
+      // has_chroma is true.
+      // The next 3 lines are equivalent to:
+      // left_available[kPlaneU] = left_available[kPlaneV] =
+      //     left_available[kPlaneY] &&
+      //     ((tile.sequence_header_.color_config.subsampling_x & width4x4) == 0
+      //      || tile.IsLeftInside(column4x4 - 1));
+      left_available[kPlaneU] = left_available[kPlaneV] = tile.IsLeftInside(
+          column4x4 -
+          (tile.sequence_header_.color_config.subsampling_x & width4x4));
+    }
+    const ptrdiff_t stride = tile.BlockParametersStride();
+    BlockParameters** const bps =
+        tile.BlockParametersAddress(row4x4, column4x4);
+    bp = *bps;
+    // bp_top is valid only if top_available[kPlaneY] is true.
+    if (top_available[kPlaneY]) {
+      bp_top = *(bps - stride);
+    }
+    // bp_left is valid only if left_available[kPlaneY] is true.
+    if (left_available[kPlaneY]) {
+      bp_left = *(bps - 1);
     }
   }
 
-  bool HasChroma() const {
-    if (kNum4x4BlocksHigh[size] == 1 &&
-        tile.sequence_header_.color_config.subsampling_y != 0 &&
-        (row4x4 & 1) == 0) {
-      return false;
-    }
-    if (kNum4x4BlocksWide[size] == 1 &&
-        tile.sequence_header_.color_config.subsampling_x != 0 &&
-        (column4x4 & 1) == 0) {
-      return false;
-    }
-    return !tile.sequence_header_.color_config.is_monochrome;
-  }
-
-  bool TopAvailableChroma() const {
-    if (!HasChroma()) return false;
-    if ((tile.sequence_header_.color_config.subsampling_y &
-         kNum4x4BlocksHigh[size]) == 1) {
-      return tile.IsInside(row4x4 - 2, column4x4);
-    }
-    return top_available;
-  }
-
-  bool LeftAvailableChroma() const {
-    if (!HasChroma()) return false;
-    if ((tile.sequence_header_.color_config.subsampling_x &
-         kNum4x4BlocksWide[size]) == 1) {
-      return tile.IsInside(row4x4, column4x4 - 2);
-    }
-    return left_available;
-  }
+  bool HasChroma() const { return has_chroma; }
 
   // These return values of these group of functions are valid only if the
   // corresponding top_available or left_available is true.
   ReferenceFrameType TopReference(int index) const {
-    const ReferenceFrameType default_type =
-        (index == 0) ? kReferenceFrameIntra : kReferenceFrameNone;
-    return top_available
-               ? tile.Parameters(row4x4 - 1, column4x4).reference_frame[index]
-               : default_type;
+    return bp_top->reference_frame[index];
   }
 
   ReferenceFrameType LeftReference(int index) const {
-    const ReferenceFrameType default_type =
-        (index == 0) ? kReferenceFrameIntra : kReferenceFrameNone;
-    return left_available
-               ? tile.Parameters(row4x4, column4x4 - 1).reference_frame[index]
-               : default_type;
+    return bp_left->reference_frame[index];
   }
 
   bool IsTopIntra() const { return TopReference(0) <= kReferenceFrameIntra; }
@@ -718,10 +825,14 @@
   bool IsLeftSingle() const { return LeftReference(1) <= kReferenceFrameIntra; }
 
   int CountReferences(ReferenceFrameType type) const {
-    return static_cast<int>(TopReference(0) == type) +
-           static_cast<int>(TopReference(1) == type) +
-           static_cast<int>(LeftReference(0) == type) +
-           static_cast<int>(LeftReference(1) == type);
+    return static_cast<int>(top_available[kPlaneY] &&
+                            bp_top->reference_frame[0] == type) +
+           static_cast<int>(top_available[kPlaneY] &&
+                            bp_top->reference_frame[1] == type) +
+           static_cast<int>(left_available[kPlaneY] &&
+                            bp_left->reference_frame[0] == type) +
+           static_cast<int>(left_available[kPlaneY] &&
+                            bp_left->reference_frame[1] == type);
   }
 
   // 7.10.3.
@@ -729,45 +840,61 @@
   // returns true indicating that the block has neighbors that are suitable for
   // use by overlapped motion compensation.
   bool HasOverlappableCandidates() const {
-    if (top_available) {
-      for (int x = column4x4; x < std::min(tile.frame_header_.columns4x4,
-                                           column4x4 + kNum4x4BlocksWide[size]);
-           x += 2) {
-        if (tile.Parameters(row4x4 - 1, x | 1).reference_frame[0] >
-            kReferenceFrameIntra) {
+    const ptrdiff_t stride = tile.BlockParametersStride();
+    BlockParameters** const bps = tile.BlockParametersAddress(0, 0);
+    if (top_available[kPlaneY]) {
+      BlockParameters** bps_top = bps + (row4x4 - 1) * stride + (column4x4 | 1);
+      const int columns = std::min(tile.frame_header_.columns4x4 - column4x4,
+                                   static_cast<int>(width4x4));
+      BlockParameters** const bps_top_end = bps_top + columns;
+      do {
+        if ((*bps_top)->reference_frame[0] > kReferenceFrameIntra) {
           return true;
         }
-      }
+        bps_top += 2;
+      } while (bps_top < bps_top_end);
     }
-    if (left_available) {
-      for (int y = row4x4; y < std::min(tile.frame_header_.rows4x4,
-                                        row4x4 + kNum4x4BlocksHigh[size]);
-           y += 2) {
-        if (tile.Parameters(y | 1, column4x4 - 1).reference_frame[0] >
-            kReferenceFrameIntra) {
+    if (left_available[kPlaneY]) {
+      BlockParameters** bps_left = bps + (row4x4 | 1) * stride + column4x4 - 1;
+      const int rows = std::min(tile.frame_header_.rows4x4 - row4x4,
+                                static_cast<int>(height4x4));
+      BlockParameters** const bps_left_end = bps_left + rows * stride;
+      do {
+        if ((*bps_left)->reference_frame[0] > kReferenceFrameIntra) {
           return true;
         }
-      }
+        bps_left += 2 * stride;
+      } while (bps_left < bps_left_end);
     }
     return false;
   }
 
-  const BlockParameters& parameters() const { return *bp; }
-
   const Tile& tile;
+  bool has_chroma;
+  const BlockSize size;
+  bool top_available[kMaxPlanes];
+  bool left_available[kMaxPlanes];
+  BlockSize residual_size[kMaxPlanes];
   const int row4x4;
   const int column4x4;
-  const BlockSize size;
-  const bool left_available;
-  const bool top_available;
-  const BlockSize residual_size[kNumPlaneTypes];
-  BlockParameters* const bp_top;
-  BlockParameters* const bp_left;
-  BlockParameters* const bp;
-  DecoderScratchBuffer* const scratch_buffer;
+  const int width;
+  const int height;
+  const int width4x4;
+  const int height4x4;
+  const BlockParameters* bp_top;
+  const BlockParameters* bp_left;
+  BlockParameters* bp;
+  TileScratchBuffer* const scratch_buffer;
   ResidualPtr* const residual;
 };
 
+extern template bool
+Tile::ProcessSuperBlockRow<kProcessingModeDecodeOnly, false>(
+    int row4x4, TileScratchBuffer* scratch_buffer);
+extern template bool
+Tile::ProcessSuperBlockRow<kProcessingModeParseAndDecode, true>(
+    int row4x4, TileScratchBuffer* scratch_buffer);
+
 }  // namespace libgav1
 
 #endif  // LIBGAV1_SRC_TILE_H_
diff --git a/libgav1/src/tile/bitstream/mode_info.cc b/libgav1/src/tile/bitstream/mode_info.cc
index a50a195..d73ebed 100644
--- a/libgav1/src/tile/bitstream/mode_info.cc
+++ b/libgav1/src/tile/bitstream/mode_info.cc
@@ -109,29 +109,21 @@
 }
 
 // This is called DrlCtxStack in section 7.10.2.14 of the spec.
-int GetRefMvIndexContext(
-    const CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize], int count,
-    int index) {
-  if (index + 1 >= count ||
-      (ref_mv_stack[index].weight >= kExtraWeightForNearestMvs &&
-       ref_mv_stack[index + 1].weight >= kExtraWeightForNearestMvs)) {
+// In the spec, the weights of all the nearest mvs are incremented by a bonus
+// weight which is larger than any natural weight, and the weights of the mvs
+// are compared with this bonus weight to determine their contexts. We replace
+// this procedure by introducing |nearest_mv_count| in PredictionParameters,
+// which records the count of the nearest mvs. Since all the nearest mvs are in
+// the beginning of the mv stack, the |index| of a mv in the mv stack can be
+// compared with |nearest_mv_count| to get that mv's context.
+int GetRefMvIndexContext(int nearest_mv_count, int index) {
+  if (index + 1 < nearest_mv_count) {
     return 0;
   }
-  if (ref_mv_stack[index].weight >= kExtraWeightForNearestMvs &&
-      ref_mv_stack[index + 1].weight < kExtraWeightForNearestMvs) {
+  if (index + 1 == nearest_mv_count) {
     return 1;
   }
-  if (ref_mv_stack[index].weight < kExtraWeightForNearestMvs &&
-      ref_mv_stack[index + 1].weight < kExtraWeightForNearestMvs) {
-    return 2;
-  }
-  return 0;
-}
-
-// Returns true if the either the width or the height of the block is equal to
-// four.
-bool IsBlockDimension4(BlockSize size) {
-  return size < kBlock8x8 || size == kBlock16x4;
+  return 2;
 }
 
 // Returns true if both the width and height of the block is less than 64.
@@ -139,21 +131,73 @@
   return size <= kBlock32x32 && size != kBlock16x64;
 }
 
+int GetUseCompoundReferenceContext(const Tile::Block& block) {
+  if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
+    if (block.IsTopSingle() && block.IsLeftSingle()) {
+      return static_cast<int>(IsBackwardReference(block.TopReference(0))) ^
+             static_cast<int>(IsBackwardReference(block.LeftReference(0)));
+    }
+    if (block.IsTopSingle()) {
+      return 2 + static_cast<int>(IsBackwardReference(block.TopReference(0)) ||
+                                  block.IsTopIntra());
+    }
+    if (block.IsLeftSingle()) {
+      return 2 + static_cast<int>(IsBackwardReference(block.LeftReference(0)) ||
+                                  block.IsLeftIntra());
+    }
+    return 4;
+  }
+  if (block.top_available[kPlaneY]) {
+    return block.IsTopSingle()
+               ? static_cast<int>(IsBackwardReference(block.TopReference(0)))
+               : 3;
+  }
+  if (block.left_available[kPlaneY]) {
+    return block.IsLeftSingle()
+               ? static_cast<int>(IsBackwardReference(block.LeftReference(0)))
+               : 3;
+  }
+  return 1;
+}
+
+// Calculates count0 by calling block.CountReferences() on the frame types from
+// type0_start to type0_end, inclusive, and summing the results.
+// Calculates count1 by calling block.CountReferences() on the frame types from
+// type1_start to type1_end, inclusive, and summing the results.
+// Compares count0 with count1 and returns 0, 1 or 2.
+//
+// See count_refs and ref_count_ctx in 8.3.2.
+int GetReferenceContext(const Tile::Block& block,
+                        ReferenceFrameType type0_start,
+                        ReferenceFrameType type0_end,
+                        ReferenceFrameType type1_start,
+                        ReferenceFrameType type1_end) {
+  int count0 = 0;
+  int count1 = 0;
+  for (int type = type0_start; type <= type0_end; ++type) {
+    count0 += block.CountReferences(static_cast<ReferenceFrameType>(type));
+  }
+  for (int type = type1_start; type <= type1_end; ++type) {
+    count1 += block.CountReferences(static_cast<ReferenceFrameType>(type));
+  }
+  return (count0 < count1) ? 0 : (count0 == count1 ? 1 : 2);
+}
+
 }  // namespace
 
 bool Tile::ReadSegmentId(const Block& block) {
   int top_left = -1;
-  if (block.top_available && block.left_available) {
+  if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
     top_left =
         block_parameters_holder_.Find(block.row4x4 - 1, block.column4x4 - 1)
             ->segment_id;
   }
   int top = -1;
-  if (block.top_available) {
+  if (block.top_available[kPlaneY]) {
     top = block.bp_top->segment_id;
   }
   int left = -1;
-  if (block.left_available) {
+  if (block.left_available[kPlaneY]) {
     left = block.bp_left->segment_id;
   }
   int pred;
@@ -215,10 +259,10 @@
     return;
   }
   int context = 0;
-  if (block.top_available && block.bp_top->skip) {
+  if (block.top_available[kPlaneY] && block.bp_top->skip) {
     ++context;
   }
-  if (block.left_available && block.bp_left->skip) {
+  if (block.left_available[kPlaneY] && block.bp_left->skip) {
     ++context;
   }
   uint16_t* const skip_cdf = symbol_decoder_context_.skip_cdf[context];
@@ -239,8 +283,11 @@
     return;
   }
   const int context =
-      (block.left_available ? static_cast<int>(block.bp_left->skip_mode) : 0) +
-      (block.top_available ? static_cast<int>(block.bp_top->skip_mode) : 0);
+      (block.left_available[kPlaneY]
+           ? static_cast<int>(block.bp_left->skip_mode)
+           : 0) +
+      (block.top_available[kPlaneY] ? static_cast<int>(block.bp_top->skip_mode)
+                                    : 0);
   bp.skip_mode =
       reader_.ReadSymbol(symbol_decoder_context_.skip_mode_cdf[context]);
 }
@@ -259,11 +306,12 @@
   const int column = DivideBy16(column4x4);
   if (cdef_index_[row][column] == -1) {
     cdef_index_[row][column] =
-        static_cast<int16_t>(reader_.ReadLiteral(frame_header_.cdef.bits));
-    const int width4x4 = kNum4x4BlocksWide[block.size];
-    const int height4x4 = kNum4x4BlocksHigh[block.size];
-    for (int i = row4x4; i < row4x4 + height4x4; i += cdef_size4x4) {
-      for (int j = column4x4; j < column4x4 + width4x4; j += cdef_size4x4) {
+        frame_header_.cdef.bits > 0
+            ? static_cast<int16_t>(reader_.ReadLiteral(frame_header_.cdef.bits))
+            : 0;
+    for (int i = row4x4; i < row4x4 + block.height4x4; i += cdef_size4x4) {
+      for (int j = column4x4; j < column4x4 + block.width4x4;
+           j += cdef_size4x4) {
         cdef_index_[DivideBy16(i)][DivideBy16(j)] = cdef_index_[row][column];
       }
     }
@@ -337,9 +385,10 @@
   uint16_t* cdf;
   if (intra_y_mode) {
     const PredictionMode top_mode =
-        block.top_available ? block.bp_top->y_mode : kPredictionModeDc;
-    const PredictionMode left_mode =
-        block.left_available ? block.bp_left->y_mode : kPredictionModeDc;
+        block.top_available[kPlaneY] ? block.bp_top->y_mode : kPredictionModeDc;
+    const PredictionMode left_mode = block.left_available[kPlaneY]
+                                         ? block.bp_left->y_mode
+                                         : kPredictionModeDc;
     const int top_context = kIntraYModeContext[top_mode];
     const int left_context = kIntraYModeContext[left_mode];
     cdf = symbol_decoder_context_
@@ -398,17 +447,20 @@
   BlockParameters& bp = *block.bp;
   bool chroma_from_luma_allowed;
   if (frame_header_.segmentation.lossless[bp.segment_id]) {
-    chroma_from_luma_allowed = block.residual_size[kPlaneTypeUV] == kBlock4x4;
+    chroma_from_luma_allowed = block.residual_size[kPlaneU] == kBlock4x4;
   } else {
     chroma_from_luma_allowed = IsBlockDimensionLessThan64(block.size);
   }
   uint16_t* const cdf =
       symbol_decoder_context_
           .uv_mode_cdf[static_cast<int>(chroma_from_luma_allowed)][bp.y_mode];
-  const int symbol_count =
-      kIntraPredictionModesUV - static_cast<int>(!chroma_from_luma_allowed);
-  bp.uv_mode =
-      static_cast<PredictionMode>(reader_.ReadSymbol(cdf, symbol_count));
+  if (chroma_from_luma_allowed) {
+    bp.uv_mode = static_cast<PredictionMode>(
+        reader_.ReadSymbol<kIntraPredictionModesUV>(cdf));
+  } else {
+    bp.uv_mode = static_cast<PredictionMode>(
+        reader_.ReadSymbol<kIntraPredictionModesUV - 1>(cdf));
+  }
 }
 
 int Tile::ReadMotionVectorComponent(const Block& block, const int component) {
@@ -463,11 +515,11 @@
                          static_cast<int>(kNumMvJointTypes)));
   if (mv_joint == kMvJointTypeHorizontalZeroVerticalNonZero ||
       mv_joint == kMvJointTypeNonZero) {
-    bp.mv[index].mv[0] = ReadMotionVectorComponent(block, 0);
+    bp.mv.mv[index].mv[0] = ReadMotionVectorComponent(block, 0);
   }
   if (mv_joint == kMvJointTypeHorizontalNonZeroVerticalZero ||
       mv_joint == kMvJointTypeNonZero) {
-    bp.mv[index].mv[1] = ReadMotionVectorComponent(block, 1);
+    bp.mv.mv[index].mv[1] = ReadMotionVectorComponent(block, 1);
   }
 }
 
@@ -529,11 +581,9 @@
     bp.palette_mode_info.size[kPlaneTypeUV] = 0;
     bp.interpolation_filter[0] = kInterpolationFilterBilinear;
     bp.interpolation_filter[1] = kInterpolationFilterBilinear;
-    FindMvStack(block, /*is_compound=*/false, reference_frame_sign_bias_,
-                *motion_field_mv_, prediction_parameters.ref_mv_stack,
-                &prediction_parameters.ref_mv_count, /*contexts=*/nullptr,
-                prediction_parameters.global_mv);
-    return AssignMv(block, /*is_compound=*/false);
+    MvContexts dummy_mode_contexts;
+    FindMvStack(block, /*is_compound=*/false, &dummy_mode_contexts);
+    return AssignIntraMv(block);
   }
   bp.is_inter = false;
   return ReadIntraBlockModeInfo(block, /*intra_y_mode=*/true);
@@ -545,9 +595,9 @@
   if (prev_segment_ids_ == nullptr) return 0;
 
   const int x_limit = std::min(frame_header_.columns4x4 - block.column4x4,
-                               static_cast<int>(kNum4x4BlocksWide[block.size]));
+                               static_cast<int>(block.width4x4));
   const int y_limit = std::min(frame_header_.rows4x4 - block.row4x4,
-                               static_cast<int>(kNum4x4BlocksHigh[block.size]));
+                               static_cast<int>(block.height4x4));
   int8_t id = 7;
   for (int y = 0; y < y_limit; ++y) {
     for (int x = 0; x < x_limit; ++x) {
@@ -580,10 +630,10 @@
   }
   if (frame_header_.segmentation.temporal_update) {
     const int context =
-        (block.left_available
+        (block.left_available[kPlaneY]
              ? static_cast<int>(block.bp_left->use_predicted_segment_id)
              : 0) +
-        (block.top_available
+        (block.top_available[kPlaneY]
              ? static_cast<int>(block.bp_top->use_predicted_segment_id)
              : 0);
     bp.use_predicted_segment_id = reader_.ReadSymbol(
@@ -616,13 +666,14 @@
     return;
   }
   int context = 0;
-  if (block.top_available && block.left_available) {
+  if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
     context = (block.IsTopIntra() && block.IsLeftIntra())
                   ? 3
                   : static_cast<int>(block.IsTopIntra() || block.IsLeftIntra());
-  } else if (block.top_available || block.left_available) {
-    context = 2 * static_cast<int>(block.top_available ? block.IsTopIntra()
-                                                       : block.IsLeftIntra());
+  } else if (block.top_available[kPlaneY] || block.left_available[kPlaneY]) {
+    context = 2 * static_cast<int>(block.top_available[kPlaneY]
+                                       ? block.IsTopIntra()
+                                       : block.IsLeftIntra());
   }
   bp.is_inter =
       reader_.ReadSymbol(symbol_decoder_context_.is_inter_cdf[context]);
@@ -646,41 +697,12 @@
   return true;
 }
 
-int Tile::GetUseCompoundReferenceContext(const Block& block) {
-  if (block.top_available && block.left_available) {
-    if (block.IsTopSingle() && block.IsLeftSingle()) {
-      return static_cast<int>(IsBackwardReference(block.TopReference(0))) ^
-             static_cast<int>(IsBackwardReference(block.LeftReference(0)));
-    }
-    if (block.IsTopSingle()) {
-      return 2 + static_cast<int>(IsBackwardReference(block.TopReference(0)) ||
-                                  block.IsTopIntra());
-    }
-    if (block.IsLeftSingle()) {
-      return 2 + static_cast<int>(IsBackwardReference(block.LeftReference(0)) ||
-                                  block.IsLeftIntra());
-    }
-    return 4;
-  }
-  if (block.top_available) {
-    return block.IsTopSingle()
-               ? static_cast<int>(IsBackwardReference(block.TopReference(0)))
-               : 3;
-  }
-  if (block.left_available) {
-    return block.IsLeftSingle()
-               ? static_cast<int>(IsBackwardReference(block.LeftReference(0)))
-               : 3;
-  }
-  return 1;
-}
-
 CompoundReferenceType Tile::ReadCompoundReferenceType(const Block& block) {
   // compound and inter.
-  const bool top_comp_inter =
-      block.top_available && !block.IsTopIntra() && !block.IsTopSingle();
-  const bool left_comp_inter =
-      block.left_available && !block.IsLeftIntra() && !block.IsLeftSingle();
+  const bool top_comp_inter = block.top_available[kPlaneY] &&
+                              !block.IsTopIntra() && !block.IsTopSingle();
+  const bool left_comp_inter = block.left_available[kPlaneY] &&
+                               !block.IsLeftIntra() && !block.IsLeftSingle();
   // unidirectional compound.
   const bool top_uni_comp =
       top_comp_inter && IsSameDirectionReferencePair(block.TopReference(0),
@@ -689,8 +711,8 @@
       left_comp_inter && IsSameDirectionReferencePair(block.LeftReference(0),
                                                       block.LeftReference(1));
   int context;
-  if (block.top_available && !block.IsTopIntra() && block.left_available &&
-      !block.IsLeftIntra()) {
+  if (block.top_available[kPlaneY] && !block.IsTopIntra() &&
+      block.left_available[kPlaneY] && !block.IsLeftIntra()) {
     const int same_direction = static_cast<int>(IsSameDirectionReferencePair(
         block.TopReference(0), block.LeftReference(0)));
     if (!top_comp_inter && !left_comp_inter) {
@@ -710,7 +732,7 @@
                           (block.LeftReference(0) == kReferenceFrameBackward));
       }
     }
-  } else if (block.top_available && block.left_available) {
+  } else if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
     if (top_comp_inter) {
       context = 1 + MultiplyBy2(static_cast<int>(top_uni_comp));
     } else if (left_comp_inter) {
@@ -729,22 +751,6 @@
       symbol_decoder_context_.compound_reference_type_cdf[context]));
 }
 
-int Tile::GetReferenceContext(const Block& block,
-                              ReferenceFrameType type0_start,
-                              ReferenceFrameType type0_end,
-                              ReferenceFrameType type1_start,
-                              ReferenceFrameType type1_end) const {
-  int count0 = 0;
-  int count1 = 0;
-  for (int type = type0_start; type <= type0_end; ++type) {
-    count0 += block.CountReferences(static_cast<ReferenceFrameType>(type));
-  }
-  for (int type = type1_start; type <= type1_end; ++type) {
-    count1 += block.CountReferences(static_cast<ReferenceFrameType>(type));
-  }
-  return (count0 < count1) ? 0 : (count0 == count1 ? 1 : 2);
-}
-
 template <bool is_single, bool is_backward, int index>
 uint16_t* Tile::GetReferenceCdf(
     const Block& block,
@@ -826,11 +832,9 @@
     bp.reference_frame[1] = kReferenceFrameNone;
     return;
   }
-  const int block_width4x4 = kNum4x4BlocksWide[block.size];
-  const int block_height4x4 = kNum4x4BlocksHigh[block.size];
   const bool use_compound_reference =
       frame_header_.reference_mode_select &&
-      std::min(block_width4x4, block_height4x4) >= 2 &&
+      std::min(block.width4x4, block.height4x4) >= 2 &&
       reader_.ReadSymbol(symbol_decoder_context_.use_compound_reference_cdf
                              [GetUseCompoundReferenceContext(block)]);
   if (use_compound_reference) {
@@ -982,12 +986,11 @@
       static_cast<int>(kPredictionModeHasNearMvMask.Contains(bp.y_mode));
   prediction_parameters.ref_mv_index = start;
   for (int i = start; i < start + 2; ++i) {
-    if (prediction_parameters.ref_mv_count <= i + 1) continue;
+    if (prediction_parameters.ref_mv_count <= i + 1) break;
     // drl_mode in the spec.
     const bool ref_mv_index_bit = reader_.ReadSymbol(
         symbol_decoder_context_.ref_mv_index_cdf[GetRefMvIndexContext(
-            prediction_parameters.ref_mv_stack,
-            prediction_parameters.ref_mv_count, i)]);
+            prediction_parameters.nearest_mv_count, i)]);
     prediction_parameters.ref_mv_index = i + static_cast<int>(ref_mv_index_bit);
     if (!ref_mv_index_bit) return;
   }
@@ -1084,14 +1087,14 @@
 
 uint16_t* Tile::GetIsExplicitCompoundTypeCdf(const Block& block) {
   int context = 0;
-  if (block.top_available) {
+  if (block.top_available[kPlaneY]) {
     if (!block.IsTopSingle()) {
       context += static_cast<int>(block.bp_top->is_explicit_compound_type);
     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
       context += 3;
     }
   }
-  if (block.left_available) {
+  if (block.left_available[kPlaneY]) {
     if (!block.IsLeftSingle()) {
       context += static_cast<int>(block.bp_left->is_explicit_compound_type);
     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
@@ -1104,23 +1107,20 @@
 
 uint16_t* Tile::GetIsCompoundTypeAverageCdf(const Block& block) {
   const BlockParameters& bp = *block.bp;
-  const int forward = std::abs(GetRelativeDistance(
-      current_frame_.order_hint(bp.reference_frame[0]),
-      frame_header_.order_hint, sequence_header_.enable_order_hint,
-      sequence_header_.order_hint_bits));
-  const int backward = std::abs(GetRelativeDistance(
-      current_frame_.order_hint(bp.reference_frame[1]),
-      frame_header_.order_hint, sequence_header_.enable_order_hint,
-      sequence_header_.order_hint_bits));
+  const ReferenceInfo& reference_info = *current_frame_.reference_info();
+  const int forward =
+      std::abs(reference_info.relative_distance_from[bp.reference_frame[0]]);
+  const int backward =
+      std::abs(reference_info.relative_distance_from[bp.reference_frame[1]]);
   int context = (forward == backward) ? 3 : 0;
-  if (block.top_available) {
+  if (block.top_available[kPlaneY]) {
     if (!block.IsTopSingle()) {
       context += static_cast<int>(block.bp_top->is_compound_type_average);
     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
       ++context;
     }
   }
-  if (block.left_available) {
+  if (block.left_available[kPlaneY]) {
     if (!block.IsLeftSingle()) {
       context += static_cast<int>(block.bp_left->is_compound_type_average);
     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
@@ -1200,14 +1200,14 @@
                 MultiplyBy4(static_cast<int>(bp.reference_frame[1] >
                                              kReferenceFrameIntra));
   int top_type = kNumExplicitInterpolationFilters;
-  if (block.top_available) {
+  if (block.top_available[kPlaneY]) {
     if (block.bp_top->reference_frame[0] == bp.reference_frame[0] ||
         block.bp_top->reference_frame[1] == bp.reference_frame[0]) {
       top_type = block.bp_top->interpolation_filter[direction];
     }
   }
   int left_type = kNumExplicitInterpolationFilters;
-  if (block.left_available) {
+  if (block.left_available[kPlaneY]) {
     if (block.bp_left->reference_frame[0] == bp.reference_frame[0] ||
         block.bp_left->reference_frame[1] == bp.reference_frame[0]) {
       left_type = block.bp_left->interpolation_filter[direction];
@@ -1273,16 +1273,11 @@
   bp.palette_mode_info.size[kPlaneTypeUV] = 0;
   ReadReferenceFrames(block);
   const bool is_compound = bp.reference_frame[1] > kReferenceFrameIntra;
-  PredictionParameters& prediction_parameters =
-      *block.bp->prediction_parameters;
   MvContexts mode_contexts;
-  FindMvStack(block, is_compound, reference_frame_sign_bias_, *motion_field_mv_,
-              prediction_parameters.ref_mv_stack,
-              &prediction_parameters.ref_mv_count, &mode_contexts,
-              prediction_parameters.global_mv);
+  FindMvStack(block, is_compound, &mode_contexts);
   ReadInterPredictionModeY(block, mode_contexts);
   ReadRefMvIndex(block);
-  if (!AssignMv(block, is_compound)) return false;
+  if (!AssignInterMv(block, is_compound)) return false;
   ReadInterIntraMode(block, is_compound);
   ReadMotionMode(block, is_compound);
   ReadCompoundType(block, is_compound);
diff --git a/libgav1/src/tile/bitstream/palette.cc b/libgav1/src/tile/bitstream/palette.cc
index ee75839..674d210 100644
--- a/libgav1/src/tile/bitstream/palette.cc
+++ b/libgav1/src/tile/bitstream/palette.cc
@@ -34,21 +34,21 @@
 int Tile::GetPaletteCache(const Block& block, PlaneType plane_type,
                           uint16_t* const cache) {
   const int top_size =
-      (block.top_available && Mod64(MultiplyBy4(block.row4x4)) != 0)
+      (block.top_available[kPlaneY] && Mod64(MultiplyBy4(block.row4x4)) != 0)
           ? block.bp_top->palette_mode_info.size[plane_type]
           : 0;
-  const int left_size = block.left_available
+  const int left_size = block.left_available[kPlaneY]
                             ? block.bp_left->palette_mode_info.size[plane_type]
                             : 0;
   if (left_size == 0 && top_size == 0) return 0;
   // Merge the left and top colors in sorted order and store them in |cache|.
   uint16_t dummy[1];
-  uint16_t* top = (top_size > 0)
-                      ? block.bp_top->palette_mode_info.color[plane_type]
+  const uint16_t* top = (top_size > 0)
+                            ? block.bp_top->palette_mode_info.color[plane_type]
+                            : dummy;
+  const uint16_t* left =
+      (left_size > 0) ? block.bp_left->palette_mode_info.color[plane_type]
                       : dummy;
-  uint16_t* left = (left_size > 0)
-                       ? block.bp_left->palette_mode_info.color[plane_type]
-                       : dummy;
   std::merge(top, top + top_size, left, left + left_size, cache);
   // Deduplicate the entries in |cache| and return the number of unique
   // entries.
@@ -140,10 +140,10 @@
       k4x4WidthLog2[block.size] + k4x4HeightLog2[block.size] - 2;
   if (bp.y_mode == kPredictionModeDc) {
     const int context =
-        static_cast<int>(block.top_available &&
+        static_cast<int>(block.top_available[kPlaneY] &&
                          block.bp_top->palette_mode_info.size[kPlaneTypeY] >
                              0) +
-        static_cast<int>(block.left_available &&
+        static_cast<int>(block.left_available[kPlaneY] &&
                          block.bp_left->palette_mode_info.size[kPlaneTypeY] >
                              0);
     const bool has_palette_y = reader_.ReadSymbol(
@@ -252,8 +252,8 @@
        ++plane_type) {
     const int palette_size = palette_mode_info.size[plane_type];
     if (palette_size == 0) continue;
-    int block_height = kBlockHeightPixels[block.size];
-    int block_width = kBlockWidthPixels[block.size];
+    int block_height = block.height;
+    int block_width = block.width;
     int screen_height = std::min(
         block_height, MultiplyBy4(frame_header_.rows4x4 - block.row4x4));
     int screen_width = std::min(
diff --git a/libgav1/src/tile/bitstream/partition.cc b/libgav1/src/tile/bitstream/partition.cc
index 84e8cf0..60899a2 100644
--- a/libgav1/src/tile/bitstream/partition.cc
+++ b/libgav1/src/tile/bitstream/partition.cc
@@ -26,53 +26,57 @@
 namespace libgav1 {
 namespace {
 
-uint16_t InverseCdfProbability(uint16_t probability) {
-  return kCdfMaxProbability - probability;
-}
-
-uint16_t CdfElementProbability(const uint16_t* const cdf, uint8_t element) {
-  return (element > 0 ? cdf[element - 1] : uint16_t{kCdfMaxProbability}) -
-         cdf[element];
-}
-
-void PartitionCdfGatherHorizontalAlike(const uint16_t* const partition_cdf,
-                                       BlockSize block_size,
-                                       uint16_t* const cdf) {
-  cdf[0] = kCdfMaxProbability;
-  cdf[0] -= CdfElementProbability(partition_cdf, kPartitionHorizontal);
-  cdf[0] -= CdfElementProbability(partition_cdf, kPartitionSplit);
-  cdf[0] -=
-      CdfElementProbability(partition_cdf, kPartitionHorizontalWithTopSplit);
-  cdf[0] -=
-      CdfElementProbability(partition_cdf, kPartitionHorizontalWithBottomSplit);
-  cdf[0] -=
-      CdfElementProbability(partition_cdf, kPartitionVerticalWithLeftSplit);
+uint16_t PartitionCdfGatherHorizontalAlike(const uint16_t* const partition_cdf,
+                                           BlockSize block_size) {
+  // The spec computes the cdf value using the following formula (not writing
+  // partition_cdf[] and using short forms for partition names for clarity):
+  //   cdf = None - H + V - S + S - HTS + HTS - HBS + HBS - VLS;
+  //   if (block_size != 128x128) {
+  //     cdf += VRS - H4;
+  //   }
+  // After canceling out the repeated terms with opposite signs, we have:
+  //   cdf = None - H + V - VLS;
+  //   if (block_size != 128x128) {
+  //     cdf += VRS - H4;
+  //   }
+  uint16_t cdf = partition_cdf[kPartitionNone] -
+                 partition_cdf[kPartitionHorizontal] +
+                 partition_cdf[kPartitionVertical] -
+                 partition_cdf[kPartitionVerticalWithLeftSplit];
   if (block_size != kBlock128x128) {
-    cdf[0] -= CdfElementProbability(partition_cdf, kPartitionHorizontal4);
+    cdf += partition_cdf[kPartitionVerticalWithRightSplit] -
+           partition_cdf[kPartitionHorizontal4];
   }
-  cdf[0] = InverseCdfProbability(cdf[0]);
-  cdf[1] = 0;
-  cdf[2] = 0;
+  return cdf;
 }
 
-void PartitionCdfGatherVerticalAlike(const uint16_t* const partition_cdf,
-                                     BlockSize block_size,
-                                     uint16_t* const cdf) {
-  cdf[0] = kCdfMaxProbability;
-  cdf[0] -= CdfElementProbability(partition_cdf, kPartitionVertical);
-  cdf[0] -= CdfElementProbability(partition_cdf, kPartitionSplit);
-  cdf[0] -=
-      CdfElementProbability(partition_cdf, kPartitionVerticalWithLeftSplit);
-  cdf[0] -=
-      CdfElementProbability(partition_cdf, kPartitionVerticalWithRightSplit);
-  cdf[0] -=
-      CdfElementProbability(partition_cdf, kPartitionHorizontalWithTopSplit);
+uint16_t PartitionCdfGatherVerticalAlike(const uint16_t* const partition_cdf,
+                                         BlockSize block_size) {
+  // The spec computes the cdf value using the following formula (not writing
+  // partition_cdf[] and using short forms for partition names for clarity):
+  //   cdf = H - V + V - S + HBS - VLS + VLS - VRS + S - HTS;
+  //   if (block_size != 128x128) {
+  //     cdf += H4 - V4;
+  //   }
+  // V4 is always zero. So, after canceling out the repeated terms with opposite
+  // signs, we have:
+  //   cdf = H + HBS - VRS - HTS;
+  //   if (block_size != 128x128) {
+  //     cdf += H4;
+  //   }
+  // VRS is zero for 128x128 blocks. So, further simplifying we have:
+  //   cdf = H + HBS - HTS;
+  //   if (block_size != 128x128) {
+  //     cdf += H4 - VRS;
+  //   }
+  uint16_t cdf = partition_cdf[kPartitionHorizontal] +
+                 partition_cdf[kPartitionHorizontalWithBottomSplit] -
+                 partition_cdf[kPartitionHorizontalWithTopSplit];
   if (block_size != kBlock128x128) {
-    cdf[0] -= CdfElementProbability(partition_cdf, kPartitionVertical4);
+    cdf += partition_cdf[kPartitionHorizontal4] -
+           partition_cdf[kPartitionVerticalWithRightSplit];
   }
-  cdf[0] = InverseCdfProbability(cdf[0]);
-  cdf[1] = 0;
-  cdf[2] = 0;
+  return cdf;
 }
 
 }  // namespace
@@ -81,13 +85,13 @@
                                 BlockSize block_size) {
   const int block_size_log2 = k4x4WidthLog2[block_size];
   int top = 0;
-  if (IsInside(row4x4 - 1, column4x4)) {
+  if (IsTopInside(row4x4)) {
     top = static_cast<int>(
         k4x4WidthLog2[block_parameters_holder_.Find(row4x4 - 1, column4x4)
                           ->size] < block_size_log2);
   }
   int left = 0;
-  if (IsInside(row4x4, column4x4 - 1)) {
+  if (IsLeftInside(column4x4)) {
     left = static_cast<int>(
         k4x4HeightLog2[block_parameters_holder_.Find(row4x4, column4x4 - 1)
                            ->size] < block_size_log2);
@@ -116,17 +120,25 @@
     const int bsize_log2 = k4x4WidthLog2[block_size];
     // The partition block size should be 8x8 or above.
     assert(bsize_log2 > 0);
-    const int cdf_size = SymbolDecoderContext::PartitionCdfSize(bsize_log2);
-    *partition =
-        static_cast<Partition>(reader_.ReadSymbol(partition_cdf, cdf_size));
+    if (bsize_log2 == 1) {
+      *partition = static_cast<Partition>(
+          reader_.ReadSymbol<kPartitionSplit + 1>(partition_cdf));
+    } else if (bsize_log2 == 5) {
+      *partition = static_cast<Partition>(
+          reader_.ReadSymbol<kPartitionVerticalWithRightSplit + 1>(
+              partition_cdf));
+    } else {
+      *partition = static_cast<Partition>(
+          reader_.ReadSymbol<kMaxPartitionTypes>(partition_cdf));
+    }
   } else if (has_columns) {
-    uint16_t cdf[3];
-    PartitionCdfGatherVerticalAlike(partition_cdf, block_size, cdf);
+    uint16_t cdf[3] = {
+        PartitionCdfGatherVerticalAlike(partition_cdf, block_size), 0, 0};
     *partition = reader_.ReadSymbolWithoutCdfUpdate(cdf) ? kPartitionSplit
                                                          : kPartitionHorizontal;
   } else {
-    uint16_t cdf[3];
-    PartitionCdfGatherHorizontalAlike(partition_cdf, block_size, cdf);
+    uint16_t cdf[3] = {
+        PartitionCdfGatherHorizontalAlike(partition_cdf, block_size), 0, 0};
     *partition = reader_.ReadSymbolWithoutCdfUpdate(cdf) ? kPartitionSplit
                                                          : kPartitionVertical;
   }
diff --git a/libgav1/src/tile/bitstream/transform_size.cc b/libgav1/src/tile/bitstream/transform_size.cc
index 1d95fca..c5ee757 100644
--- a/libgav1/src/tile/bitstream/transform_size.cc
+++ b/libgav1/src/tile/bitstream/transform_size.cc
@@ -70,7 +70,7 @@
 int Tile::GetTopTransformWidth(const Block& block, int row4x4, int column4x4,
                                bool ignore_skip) {
   if (row4x4 == block.row4x4) {
-    if (!block.top_available) return 64;
+    if (!block.top_available[kPlaneY]) return 64;
     const BlockParameters& bp_top =
         *block_parameters_holder_.Find(row4x4 - 1, column4x4);
     if ((ignore_skip || bp_top.skip) && bp_top.is_inter) {
@@ -83,7 +83,7 @@
 int Tile::GetLeftTransformHeight(const Block& block, int row4x4, int column4x4,
                                  bool ignore_skip) {
   if (column4x4 == block.column4x4) {
-    if (!block.left_available) return 64;
+    if (!block.left_available[kPlaneY]) return 64;
     const BlockParameters& bp_left =
         *block_parameters_holder_.Find(row4x4, column4x4 - 1);
     if ((ignore_skip || bp_left.skip) && bp_left.is_inter) {
@@ -107,11 +107,11 @@
   const int max_tx_width = kTransformWidth[max_rect_tx_size];
   const int max_tx_height = kTransformHeight[max_rect_tx_size];
   const int top_width =
-      block.top_available
+      block.top_available[kPlaneY]
           ? GetTopTransformWidth(block, block.row4x4, block.column4x4, true)
           : 0;
   const int left_height =
-      block.left_available
+      block.left_available[kPlaneY]
           ? GetLeftTransformHeight(block, block.row4x4, block.column4x4, true)
           : 0;
   const auto context = static_cast<int>(top_width >= max_tx_width) +
@@ -130,8 +130,7 @@
 
 void Tile::ReadVariableTransformTree(const Block& block, int row4x4,
                                      int column4x4, TransformSize tx_size) {
-  const uint8_t pixels =
-      std::max(kBlockWidthPixels[block.size], kBlockHeightPixels[block.size]);
+  const uint8_t pixels = std::max(block.width, block.height);
   const TransformSize max_tx_size = GetSquareTransformSize(pixels);
   const int context_delta = (kNumSquareTransformSizes - 1 -
                              TransformSizeToSquareTransformIndex(max_tx_size)) *
@@ -142,7 +141,7 @@
   Stack<TransformTreeNode, 7> stack;
   stack.Push(TransformTreeNode(column4x4, row4x4, tx_size, 0));
 
-  while (!stack.Empty()) {
+  do {
     TransformTreeNode node = stack.Pop();
     const int tx_width4x4 = kTransformWidth4x4[node.tx_size];
     const int tx_height4x4 = kTransformHeight4x4[node.tx_size];
@@ -190,12 +189,10 @@
     }
     block_parameters_holder_.Find(node.y, node.x)->transform_size =
         node.tx_size;
-  }
+  } while (!stack.Empty());
 }
 
 void Tile::DecodeTransformSize(const Block& block) {
-  const int block_width4x4 = kNum4x4BlocksWide[block.size];
-  const int block_height4x4 = kNum4x4BlocksHigh[block.size];
   BlockParameters& bp = *block.bp;
   if (frame_header_.tx_mode == kTxModeSelect && block.size > kBlock4x4 &&
       bp.is_inter && !bp.skip &&
@@ -203,19 +200,19 @@
     const TransformSize max_tx_size = kMaxTransformSizeRectangle[block.size];
     const int tx_width4x4 = kTransformWidth4x4[max_tx_size];
     const int tx_height4x4 = kTransformHeight4x4[max_tx_size];
-    for (int row = block.row4x4; row < block.row4x4 + block_height4x4;
+    for (int row = block.row4x4; row < block.row4x4 + block.height4x4;
          row += tx_height4x4) {
       for (int column = block.column4x4;
-           column < block.column4x4 + block_width4x4; column += tx_width4x4) {
+           column < block.column4x4 + block.width4x4; column += tx_width4x4) {
         ReadVariableTransformTree(block, row, column, max_tx_size);
       }
     }
   } else {
     bp.transform_size = ReadFixedTransformSize(block);
-    for (int row = block.row4x4; row < block.row4x4 + block_height4x4; ++row) {
+    for (int row = block.row4x4; row < block.row4x4 + block.height4x4; ++row) {
       static_assert(sizeof(TransformSize) == 1, "");
       memset(&inter_transform_sizes_[row][block.column4x4], bp.transform_size,
-             block_width4x4);
+             block.width4x4);
     }
   }
 }
diff --git a/libgav1/src/tile/prediction.cc b/libgav1/src/tile/prediction.cc
index 5232b71..a234a19 100644
--- a/libgav1/src/tile/prediction.cc
+++ b/libgav1/src/tile/prediction.cc
@@ -42,34 +42,21 @@
 namespace libgav1 {
 namespace {
 
-constexpr int kObmcBufferSize = 4096;  // 64x64
+// Import all the constants in the anonymous namespace.
+#include "src/inter_intra_masks.inc"
+
 constexpr int kAngleStep = 3;
 constexpr int kPredictionModeToAngle[kIntraPredictionModesUV] = {
     0, 90, 180, 45, 135, 113, 157, 203, 67, 0, 0, 0, 0};
 
-enum : uint8_t {
-  kNeedsLeft = 1,
-  kNeedsTop = 2,
-};
-
-// The values for directional and dc modes are not used since the left/top
-// requirement for those modes depend on the prediction angle and the type of dc
-// mode.
-constexpr BitMaskSet kPredictionModeNeedsMask[kIntraPredictionModesY] = {
-    BitMaskSet(0),                      // kPredictionModeDc
-    BitMaskSet(kNeedsTop),              // kPredictionModeVertical
-    BitMaskSet(kNeedsLeft),             // kPredictionModeHorizontal
-    BitMaskSet(kNeedsTop),              // kPredictionModeD45
-    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeD135
-    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeD113
-    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeD157
-    BitMaskSet(kNeedsLeft),             // kPredictionModeD203
-    BitMaskSet(kNeedsTop),              // kPredictionModeD67
-    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeSmooth
-    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeSmoothVertical
-    BitMaskSet(kNeedsLeft, kNeedsTop),  // kPredictionModeSmoothHorizontal
-    BitMaskSet(kNeedsLeft, kNeedsTop)   // kPredictionModePaeth
-};
+// The following modes need both the left_column and top_row for intra
+// prediction. For directional modes left/top requirement is inferred based on
+// the prediction angle. For Dc modes, left/top requirement is inferred based on
+// whether or not left/top is available.
+constexpr BitMaskSet kNeedsLeftAndTop(kPredictionModeSmooth,
+                                      kPredictionModeSmoothHorizontal,
+                                      kPredictionModeSmoothVertical,
+                                      kPredictionModePaeth);
 
 int16_t GetDirectionalIntraPredictorDerivative(const int angle) {
   assert(angle >= 3);
@@ -93,6 +80,13 @@
          static_cast<int>(block_size >= kBlock32x8);
 }
 
+// Maps a dimension of 4, 8, 16 and 32 to indices 0, 1, 2 and 3 respectively.
+int GetInterIntraMaskLookupIndex(int dimension) {
+  assert(dimension == 4 || dimension == 8 || dimension == 16 ||
+         dimension == 32);
+  return FloorLog2(dimension) - 2;
+}
+
 // 7.11.2.9.
 int GetIntraEdgeFilterStrength(int width, int height, int filter_type,
                                int delta) {
@@ -134,7 +128,9 @@
 bool DoIntraEdgeUpsampling(int width, int height, int filter_type, int delta) {
   const int sum = width + height;
   delta = std::abs(delta);
-  if (delta == 0 || delta >= 40) return false;
+  // This function should not be called when the prediction angle is 90 or 180.
+  assert(delta != 0);
+  if (delta >= 40) return false;
   return (filter_type == 1) ? sum <= 8 : sum <= 16;
 }
 
@@ -167,26 +163,6 @@
   }
 }
 
-template <int bitdepth, typename Pixel>
-void ClipPrediction(const uint16_t* prediction,
-                    const ptrdiff_t prediction_stride, const int width,
-                    const int height, uint8_t* clipped_prediction,
-                    ptrdiff_t clipped_prediction_stride) {
-  // An offset to cancel offsets used in compound predictor generation that
-  // make intermediate computations non negative.
-  const int single_round_offset = (1 << bitdepth) + (1 << (bitdepth - 1));
-  auto* clipped_pred = reinterpret_cast<Pixel*>(clipped_prediction);
-  clipped_prediction_stride /= sizeof(Pixel);
-  for (int y = 0; y < height; ++y) {
-    for (int x = 0; x < width; ++x) {
-      clipped_pred[x] = static_cast<Pixel>(
-          Clip3(prediction[x] - single_round_offset, 0, (1 << bitdepth) - 1));
-    }
-    prediction += prediction_stride;
-    clipped_pred += clipped_prediction_stride;
-  }
-}
-
 dsp::IntraPredictor GetIntraPredictor(PredictionMode mode, bool has_left,
                                       bool has_top) {
   if (mode == kPredictionModeDc) {
@@ -215,17 +191,6 @@
   }
 }
 
-// 7.11.3.2. Note InterRoundBits0 is derived in the dsp layer.
-int GetInterRoundingBits(const bool is_compound, const int bitdepth) {
-  if (is_compound) return 7;
-#if LIBGAV1_MAX_BITDEPTH == 12
-  if (bitdepth == 12) return 9;
-#else
-  static_cast<void>(bitdepth);
-#endif
-  return 11;
-}
-
 uint8_t* GetStartPoint(Array2DView<uint8_t>* const buffer, const int plane,
                        const int x, const int y, const int bitdepth) {
 #if LIBGAV1_MAX_BITDEPTH >= 10
@@ -244,30 +209,14 @@
   return (start + step * offset) >> kScaleSubPixelBits;
 }
 
-dsp::MaskBlendFunc GetMaskBlendFunc(const dsp::Dsp& dsp,
-                                    InterIntraMode inter_intra_mode,
+dsp::MaskBlendFunc GetMaskBlendFunc(const dsp::Dsp& dsp, bool is_inter_intra,
                                     bool is_wedge_inter_intra,
                                     int subsampling_x, int subsampling_y) {
-  const int is_inter_intra =
-      static_cast<int>(inter_intra_mode != kNumInterIntraModes);
-  return (is_inter_intra == 1 && !is_wedge_inter_intra)
-             ? dsp.mask_blend[0][is_inter_intra]
+  return (is_inter_intra && !is_wedge_inter_intra)
+             ? dsp.mask_blend[0][/*is_inter_intra=*/true]
              : dsp.mask_blend[subsampling_x + subsampling_y][is_inter_intra];
 }
 
-void PopulatePredictionMaskFromWedgeMask(const uint8_t* wedge_mask,
-                                         int wedge_mask_stride,
-                                         int prediction_width,
-                                         int prediction_height,
-                                         uint8_t* prediction_mask,
-                                         int prediction_mask_stride) {
-  for (int y = 0; y < prediction_height; ++y) {
-    memcpy(prediction_mask, wedge_mask, prediction_width);
-    prediction_mask += prediction_mask_stride;
-    wedge_mask += wedge_mask_stride;
-  }
-}
-
 }  // namespace
 
 template <typename Pixel>
@@ -281,8 +230,15 @@
   const int y_shift = subsampling_y_[plane];
   const int max_x = (MultiplyBy4(frame_header_.columns4x4) >> x_shift) - 1;
   const int max_y = (MultiplyBy4(frame_header_.rows4x4) >> y_shift) - 1;
-  alignas(kMaxAlignment) Pixel top_row_data[160] = {};
-  alignas(kMaxAlignment) Pixel left_column_data[160] = {};
+  // For performance reasons, do not initialize the following two buffers.
+  alignas(kMaxAlignment) Pixel top_row_data[160];
+  alignas(kMaxAlignment) Pixel left_column_data[160];
+#if LIBGAV1_MSAN
+  if (IsDirectionalMode(mode)) {
+    memset(top_row_data, 0, sizeof(top_row_data));
+    memset(left_column_data, 0, sizeof(left_column_data));
+  }
+#endif
   // Some predictors use |top_row_data| and |left_column_data| with a negative
   // offset to access pixels to the top-left of the current block. So have some
   // space before the arrays to allow populating those without having to move
@@ -302,71 +258,114 @@
                 prediction_parameters.angle_delta[GetPlaneType(plane)] *
                     kAngleStep
           : 0;
-  const bool needs_top = use_filter_intra ||
-                         kPredictionModeNeedsMask[mode].Contains(kNeedsTop) ||
-                         (is_directional_mode && prediction_angle < 180) ||
-                         (mode == kPredictionModeDc && has_top);
+  // Directional prediction requires buffers larger than the width or height.
+  const int top_size = is_directional_mode ? top_and_left_size : width;
+  const int left_size = is_directional_mode ? top_and_left_size : height;
+  const int top_right_size =
+      is_directional_mode ? (has_top_right ? 2 : 1) * width : width;
+  const int bottom_left_size =
+      is_directional_mode ? (has_bottom_left ? 2 : 1) * height : height;
+
   Array2DView<Pixel> buffer(buffer_[plane].rows(),
                             buffer_[plane].columns() / sizeof(Pixel),
                             reinterpret_cast<Pixel*>(&buffer_[plane][0][0]));
+  const bool needs_top = use_filter_intra || kNeedsLeftAndTop.Contains(mode) ||
+                         (is_directional_mode && prediction_angle < 180) ||
+                         (mode == kPredictionModeDc && has_top);
+  const bool needs_left = use_filter_intra || kNeedsLeftAndTop.Contains(mode) ||
+                          (is_directional_mode && prediction_angle > 90) ||
+                          (mode == kPredictionModeDc && has_left);
+
+  const Pixel* top_row_src = buffer[y - 1];
+
+  // Determine if we need to retrieve the top row from
+  // |intra_prediction_buffer_|.
+  if ((needs_top || needs_left) && use_intra_prediction_buffer_) {
+    // Superblock index of block.row4x4. block.row4x4 is always in luma
+    // dimension (no subsampling).
+    const int current_superblock_index =
+        block.row4x4 >> (sequence_header_.use_128x128_superblock ? 5 : 4);
+    // Superblock index of y - 1. y is in the plane dimension (chroma planes
+    // could be subsampled).
+    const int plane_shift = (sequence_header_.use_128x128_superblock ? 7 : 6) -
+                            subsampling_y_[plane];
+    const int top_row_superblock_index = (y - 1) >> plane_shift;
+    // If the superblock index of y - 1 is not that of the current superblock,
+    // then we will have to retrieve the top row from the
+    // |intra_prediction_buffer_|.
+    if (current_superblock_index != top_row_superblock_index) {
+      top_row_src = reinterpret_cast<const Pixel*>(
+          (*intra_prediction_buffer_)[plane].get());
+    }
+  }
+
   if (needs_top) {
     // Compute top_row.
-    top_row[-1] = (has_top || has_left)
-                      ? buffer[has_top ? y - 1 : y][has_left ? x - 1 : x]
-                      : (1 << (bitdepth - 1));
-    if (!has_top && has_left) {
-      Memset(top_row, buffer[y][x - 1], top_and_left_size);
-    } else if (!has_top && !has_left) {
-      Memset(top_row, (1 << (bitdepth - 1)) - 1, top_and_left_size);
+    if (has_top || has_left) {
+      const int left_index = has_left ? x - 1 : x;
+      top_row[-1] = has_top ? top_row_src[left_index] : buffer[y][left_index];
     } else {
-      const int top_limit =
-          std::min(max_x, x - 1 + ((has_top_right ? 2 : 1) * width));
-      for (int i = 0; i < top_and_left_size; ++i) {
-        top_row[i] = buffer[y - 1][std::min(top_limit, x + i)];
+      top_row[-1] = 1 << (bitdepth - 1);
+    }
+    if (!has_top && has_left) {
+      Memset(top_row, buffer[y][x - 1], top_size);
+    } else if (!has_top && !has_left) {
+      Memset(top_row, (1 << (bitdepth - 1)) - 1, top_size);
+    } else {
+      const int top_limit = std::min(max_x - x + 1, top_right_size);
+      memcpy(top_row, &top_row_src[x], top_limit * sizeof(Pixel));
+      // Even though it is safe to call Memset with a size of 0, accessing
+      // top_row_src[top_limit - x + 1] is not allowed when this condition is
+      // false.
+      if (top_size - top_limit > 0) {
+        Memset(top_row + top_limit, top_row_src[top_limit + x - 1],
+               top_size - top_limit);
       }
     }
   }
-  const bool needs_left = use_filter_intra ||
-                          kPredictionModeNeedsMask[mode].Contains(kNeedsLeft) ||
-                          (is_directional_mode && prediction_angle > 90) ||
-                          (mode == kPredictionModeDc && has_left);
   if (needs_left) {
     // Compute left_column.
-    left_column[-1] = (has_top || has_left)
-                          ? buffer[has_top ? y - 1 : y][has_left ? x - 1 : x]
-                          : (1 << (bitdepth - 1));
-    if (!has_left && has_top) {
-      Memset(left_column, buffer[y - 1][x], top_and_left_size);
-    } else if (!has_left && !has_top) {
-      Memset(left_column, (1 << (bitdepth - 1)) + 1, top_and_left_size);
+    if (has_top || has_left) {
+      const int left_index = has_left ? x - 1 : x;
+      left_column[-1] =
+          has_top ? top_row_src[left_index] : buffer[y][left_index];
     } else {
-      const int left_limit =
-          std::min(max_y, y - 1 + ((has_bottom_left ? 2 : 1) * height));
-      for (int i = 0; i < top_and_left_size; ++i) {
-        left_column[i] = buffer[std::min(left_limit, y + i)][x - 1];
+      left_column[-1] = 1 << (bitdepth - 1);
+    }
+    if (!has_left && has_top) {
+      Memset(left_column, top_row_src[x], left_size);
+    } else if (!has_left && !has_top) {
+      Memset(left_column, (1 << (bitdepth - 1)) + 1, left_size);
+    } else {
+      const int left_limit = std::min(max_y - y + 1, bottom_left_size);
+      for (int i = 0; i < left_limit; ++i) {
+        left_column[i] = buffer[y + i][x - 1];
+      }
+      // Even though it is safe to call Memset with a size of 0, accessing
+      // buffer[left_limit - y + 1][x - 1] is not allowed when this condition is
+      // false.
+      if (left_size - left_limit > 0) {
+        Memset(left_column + left_limit, buffer[left_limit + y - 1][x - 1],
+               left_size - left_limit);
       }
     }
   }
   Pixel* const dest = &buffer[y][x];
   const ptrdiff_t dest_stride = buffer_[plane].columns();
   if (use_filter_intra) {
-    dsp_.filter_intra_predictor(reinterpret_cast<uint8_t*>(dest), dest_stride,
-                                reinterpret_cast<uint8_t*>(top_row),
-                                reinterpret_cast<uint8_t*>(left_column),
+    dsp_.filter_intra_predictor(dest, dest_stride, top_row, left_column,
                                 prediction_parameters.filter_intra_mode, width,
                                 height);
   } else if (is_directional_mode) {
-    DirectionalPrediction(block, plane, x, y, has_left, has_top,
-                          prediction_angle, width, height, max_x, max_y,
-                          tx_size, top_row, left_column);
+    DirectionalPrediction(block, plane, x, y, has_left, has_top, needs_left,
+                          needs_top, prediction_angle, width, height, max_x,
+                          max_y, tx_size, top_row, left_column);
   } else {
     const dsp::IntraPredictor predictor =
         GetIntraPredictor(mode, has_left, has_top);
     assert(predictor != dsp::kNumIntraPredictors);
-    dsp_.intra_predictors[tx_size][predictor](
-        reinterpret_cast<uint8_t*>(dest), dest_stride,
-        reinterpret_cast<uint8_t*>(top_row),
-        reinterpret_cast<uint8_t*>(left_column));
+    dsp_.intra_predictors[tx_size][predictor](dest, dest_stride, top_row,
+                                              left_column);
   }
 }
 
@@ -404,8 +403,7 @@
 int Tile::GetIntraEdgeFilterType(const Block& block, Plane plane) const {
   const int subsampling_x = subsampling_x_[plane];
   const int subsampling_y = subsampling_y_[plane];
-  if ((plane == kPlaneY && block.top_available) ||
-      (plane != kPlaneY && block.TopAvailableChroma())) {
+  if (block.top_available[plane]) {
     const int row =
         block.row4x4 - 1 -
         static_cast<int>(subsampling_y != 0 && (block.row4x4 & 1) != 0);
@@ -414,8 +412,7 @@
         static_cast<int>(subsampling_x != 0 && (block.column4x4 & 1) == 0);
     if (IsSmoothPrediction(row, column, plane)) return 1;
   }
-  if ((plane == kPlaneY && block.left_available) ||
-      (plane != kPlaneY && block.LeftAvailableChroma())) {
+  if (block.left_available[plane]) {
     const int row = block.row4x4 + static_cast<int>(subsampling_y != 0 &&
                                                     (block.row4x4 & 1) == 0);
     const int column =
@@ -428,88 +425,87 @@
 
 template <typename Pixel>
 void Tile::DirectionalPrediction(const Block& block, Plane plane, int x, int y,
-                                 bool has_left, bool has_top,
-                                 int prediction_angle, int width, int height,
-                                 int max_x, int max_y, TransformSize tx_size,
-                                 Pixel* const top_row,
+                                 bool has_left, bool has_top, bool needs_left,
+                                 bool needs_top, int prediction_angle,
+                                 int width, int height, int max_x, int max_y,
+                                 TransformSize tx_size, Pixel* const top_row,
                                  Pixel* const left_column) {
+  Array2DView<Pixel> buffer(buffer_[plane].rows(),
+                            buffer_[plane].columns() / sizeof(Pixel),
+                            reinterpret_cast<Pixel*>(&buffer_[plane][0][0]));
+  Pixel* const dest = &buffer[y][x];
+  const ptrdiff_t stride = buffer_[plane].columns();
+  if (prediction_angle == 90) {
+    dsp_.intra_predictors[tx_size][dsp::kIntraPredictorVertical](
+        dest, stride, top_row, left_column);
+    return;
+  }
+  if (prediction_angle == 180) {
+    dsp_.intra_predictors[tx_size][dsp::kIntraPredictorHorizontal](
+        dest, stride, top_row, left_column);
+    return;
+  }
+
   bool upsampled_top = false;
   bool upsampled_left = false;
   if (sequence_header_.enable_intra_edge_filter) {
     const int filter_type = GetIntraEdgeFilterType(block, plane);
-    if (prediction_angle != 90 && prediction_angle != 180) {
-      if (prediction_angle > 90 && prediction_angle < 180 &&
-          (width + height) >= 24) {
-        // 7.11.2.7.
-        left_column[-1] = top_row[-1] = RightShiftWithRounding(
-            left_column[0] * 5 + top_row[-1] * 6 + top_row[0] * 5, 4);
+    if (prediction_angle > 90 && prediction_angle < 180 &&
+        (width + height) >= 24) {
+      // 7.11.2.7.
+      left_column[-1] = top_row[-1] = RightShiftWithRounding(
+          left_column[0] * 5 + top_row[-1] * 6 + top_row[0] * 5, 4);
+    }
+    if (has_top && needs_top) {
+      const int strength = GetIntraEdgeFilterStrength(
+          width, height, filter_type, prediction_angle - 90);
+      if (strength > 0) {
+        const int num_pixels = std::min(width, max_x - x + 1) +
+                               ((prediction_angle < 90) ? height : 0) + 1;
+        dsp_.intra_edge_filter(top_row - 1, num_pixels, strength);
       }
-      if (has_top) {
-        const int strength = GetIntraEdgeFilterStrength(
-            width, height, filter_type, prediction_angle - 90);
-        if (strength > 0) {
-          const int num_pixels = std::min(width, max_x - x + 1) +
-                                 ((prediction_angle < 90) ? height : 0) + 1;
-          dsp_.intra_edge_filter(top_row - 1, num_pixels, strength);
-        }
-      }
-      if (has_left) {
-        const int strength = GetIntraEdgeFilterStrength(
-            width, height, filter_type, prediction_angle - 180);
-        if (strength > 0) {
-          const int num_pixels = std::min(height, max_y - y + 1) +
-                                 ((prediction_angle > 180) ? width : 0) + 1;
-          dsp_.intra_edge_filter(left_column - 1, num_pixels, strength);
-        }
+    }
+    if (has_left && needs_left) {
+      const int strength = GetIntraEdgeFilterStrength(
+          width, height, filter_type, prediction_angle - 180);
+      if (strength > 0) {
+        const int num_pixels = std::min(height, max_y - y + 1) +
+                               ((prediction_angle > 180) ? width : 0) + 1;
+        dsp_.intra_edge_filter(left_column - 1, num_pixels, strength);
       }
     }
     upsampled_top = DoIntraEdgeUpsampling(width, height, filter_type,
                                           prediction_angle - 90);
-    if (upsampled_top) {
+    if (upsampled_top && needs_top) {
       const int num_pixels = width + ((prediction_angle < 90) ? height : 0);
       dsp_.intra_edge_upsampler(top_row, num_pixels);
     }
     upsampled_left = DoIntraEdgeUpsampling(width, height, filter_type,
                                            prediction_angle - 180);
-    if (upsampled_left) {
+    if (upsampled_left && needs_left) {
       const int num_pixels = height + ((prediction_angle > 180) ? width : 0);
       dsp_.intra_edge_upsampler(left_column, num_pixels);
     }
   }
-  Array2DView<Pixel> buffer(buffer_[plane].rows(),
-                            buffer_[plane].columns() / sizeof(Pixel),
-                            reinterpret_cast<Pixel*>(&buffer_[plane][0][0]));
-  auto* const dest = reinterpret_cast<uint8_t* const>(&buffer[y][x]);
-  const ptrdiff_t stride = buffer_[plane].columns();
-  if (prediction_angle == 90) {
-    dsp_.intra_predictors[tx_size][dsp::kIntraPredictorVertical](
-        dest, stride, reinterpret_cast<uint8_t*>(top_row),
-        reinterpret_cast<uint8_t*>(left_column));
-  } else if (prediction_angle == 180) {
-    dsp_.intra_predictors[tx_size][dsp::kIntraPredictorHorizontal](
-        dest, stride, reinterpret_cast<uint8_t*>(top_row),
-        reinterpret_cast<uint8_t*>(left_column));
-  } else if (prediction_angle < 90) {
+
+  if (prediction_angle < 90) {
     const int dx = GetDirectionalIntraPredictorDerivative(prediction_angle);
-    dsp_.directional_intra_predictor_zone1(dest, stride,
-                                           reinterpret_cast<uint8_t*>(top_row),
-                                           width, height, dx, upsampled_top);
+    dsp_.directional_intra_predictor_zone1(dest, stride, top_row, width, height,
+                                           dx, upsampled_top);
   } else if (prediction_angle < 180) {
     const int dx =
         GetDirectionalIntraPredictorDerivative(180 - prediction_angle);
     const int dy =
         GetDirectionalIntraPredictorDerivative(prediction_angle - 90);
-    dsp_.directional_intra_predictor_zone2(
-        dest, stride, reinterpret_cast<uint8_t*>(top_row),
-        reinterpret_cast<uint8_t*>(left_column), width, height, dx, dy,
-        upsampled_top, upsampled_left);
+    dsp_.directional_intra_predictor_zone2(dest, stride, top_row, left_column,
+                                           width, height, dx, dy, upsampled_top,
+                                           upsampled_left);
   } else {
     assert(prediction_angle < 270);
     const int dy =
         GetDirectionalIntraPredictorDerivative(270 - prediction_angle);
-    dsp_.directional_intra_predictor_zone3(
-        dest, stride, reinterpret_cast<uint8_t*>(left_column), width, height,
-        dy, upsampled_left);
+    dsp_.directional_intra_predictor_zone3(dest, stride, left_column, width,
+                                           height, dy, upsampled_left);
   }
 }
 
@@ -588,8 +584,7 @@
 #endif
 
 void Tile::InterIntraPrediction(
-    uint16_t* prediction[2], const ptrdiff_t prediction_stride,
-    const uint8_t* const prediction_mask,
+    uint16_t* const prediction_0, const uint8_t* const prediction_mask,
     const ptrdiff_t prediction_mask_stride,
     const PredictionParameters& prediction_parameters,
     const int prediction_width, const int prediction_height,
@@ -602,81 +597,74 @@
              kCompoundPredictionTypeWedge);
   // The first buffer of InterIntra is from inter prediction.
   // The second buffer is from intra prediction.
-  ptrdiff_t intra_stride;
-  const int bitdepth = sequence_header_.color_config.bitdepth;
-  if (bitdepth == 8) {
-    // Both the input predictors must be of type uint16_t. For bitdepth ==
-    // 8, |buffer_| is uint8_t and hence a copy has to be made. For higher
-    // bitdepths, the |buffer_| itself can act as an uint16_t buffer so no
-    // copy is necessary.
-    uint8_t* dest_ptr = dest;
-    Array2DView<uint16_t> intra_prediction(
-        kMaxSuperBlockSizeInPixels, kMaxSuperBlockSizeInPixels, prediction[1]);
-    for (int r = 0; r < prediction_height; ++r) {
-      for (int c = 0; c < prediction_width; ++c) {
-        intra_prediction[r][c] = dest_ptr[c];
-      }
-      dest_ptr += dest_stride;
-    }
-    intra_stride = kMaxSuperBlockSizeInPixels;
-  } else {
-    prediction[1] = reinterpret_cast<uint16_t*>(dest);
-    intra_stride = dest_stride / sizeof(uint16_t);
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  if (sequence_header_.color_config.bitdepth > 8) {
+    GetMaskBlendFunc(dsp_, /*is_inter_intra=*/true,
+                     prediction_parameters.is_wedge_inter_intra, subsampling_x,
+                     subsampling_y)(
+        prediction_0, reinterpret_cast<uint16_t*>(dest),
+        dest_stride / sizeof(uint16_t), prediction_mask, prediction_mask_stride,
+        prediction_width, prediction_height, dest, dest_stride);
+    return;
   }
-  GetMaskBlendFunc(dsp_, prediction_parameters.inter_intra_mode,
-                   prediction_parameters.is_wedge_inter_intra, subsampling_x,
-                   subsampling_y)(prediction[0], prediction_stride,
-                                  prediction[1], intra_stride, prediction_mask,
-                                  prediction_mask_stride, prediction_width,
-                                  prediction_height, dest, dest_stride);
+#endif
+  const int function_index = prediction_parameters.is_wedge_inter_intra
+                                 ? subsampling_x + subsampling_y
+                                 : 0;
+  // |is_inter_intra| prediction values are stored in a Pixel buffer but it is
+  // currently declared as a uint16_t buffer.
+  // TODO(johannkoenig): convert the prediction buffer to a uint8_t buffer and
+  // remove the reinterpret_cast.
+  dsp_.inter_intra_mask_blend_8bpp[function_index](
+      reinterpret_cast<uint8_t*>(prediction_0), dest, dest_stride,
+      prediction_mask, prediction_mask_stride, prediction_width,
+      prediction_height);
 }
 
 void Tile::CompoundInterPrediction(
-    const Block& block, const ptrdiff_t prediction_stride,
+    const Block& block, const uint8_t* const prediction_mask,
     const ptrdiff_t prediction_mask_stride, const int prediction_width,
-    const int prediction_height, const Plane plane, const int subsampling_x,
-    const int subsampling_y, const int bitdepth, const int candidate_row,
+    const int prediction_height, const int subsampling_x,
+    const int subsampling_y, const int candidate_row,
     const int candidate_column, uint8_t* dest, const ptrdiff_t dest_stride) {
   const PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
-  uint16_t* prediction[2] = {block.scratch_buffer->prediction_buffer[0],
-                             block.scratch_buffer->prediction_buffer[1]};
+
+  void* prediction[2];
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  const int bitdepth = sequence_header_.color_config.bitdepth;
+  if (bitdepth > 8) {
+    prediction[0] = block.scratch_buffer->prediction_buffer[0];
+    prediction[1] = block.scratch_buffer->prediction_buffer[1];
+  } else {
+#endif
+    prediction[0] = block.scratch_buffer->compound_prediction_buffer_8bpp[0];
+    prediction[1] = block.scratch_buffer->compound_prediction_buffer_8bpp[1];
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  }
+#endif
+
   switch (prediction_parameters.compound_prediction_type) {
     case kCompoundPredictionTypeWedge:
-      GetMaskBlendFunc(dsp_, prediction_parameters.inter_intra_mode,
-                       prediction_parameters.is_wedge_inter_intra,
-                       subsampling_x, subsampling_y)(
-          prediction[0], prediction_stride, prediction[1], prediction_stride,
-          block.scratch_buffer->prediction_mask, prediction_mask_stride,
-          prediction_width, prediction_height, dest, dest_stride);
-      break;
     case kCompoundPredictionTypeDiffWeighted:
-      if (plane == kPlaneY) {
-        GenerateWeightMask(
-            prediction[0], prediction_stride, prediction[1], prediction_stride,
-            prediction_parameters.mask_is_inverse, prediction_width,
-            prediction_height, bitdepth, block.scratch_buffer->prediction_mask,
-            prediction_mask_stride);
-      }
-      GetMaskBlendFunc(dsp_, prediction_parameters.inter_intra_mode,
+      GetMaskBlendFunc(dsp_, /*is_inter_intra=*/false,
                        prediction_parameters.is_wedge_inter_intra,
                        subsampling_x, subsampling_y)(
-          prediction[0], prediction_stride, prediction[1], prediction_stride,
-          block.scratch_buffer->prediction_mask, prediction_mask_stride,
-          prediction_width, prediction_height, dest, dest_stride);
+          prediction[0], prediction[1],
+          /*prediction_stride=*/prediction_width, prediction_mask,
+          prediction_mask_stride, prediction_width, prediction_height, dest,
+          dest_stride);
       break;
     case kCompoundPredictionTypeDistance:
-      DistanceWeightedPrediction(
-          prediction[0], prediction_stride, prediction[1], prediction_stride,
-          prediction_width, prediction_height, candidate_row, candidate_column,
-          dest, dest_stride);
+      DistanceWeightedPrediction(prediction[0], prediction[1], prediction_width,
+                                 prediction_height, candidate_row,
+                                 candidate_column, dest, dest_stride);
       break;
     default:
       assert(prediction_parameters.compound_prediction_type ==
              kCompoundPredictionTypeAverage);
-      dsp_.average_blend(prediction[0], prediction_stride, prediction[1],
-                         prediction_stride, prediction_width, prediction_height,
-                         dest, dest_stride);
+      dsp_.average_blend(prediction[0], prediction[1], prediction_width,
+                         prediction_height, dest, dest_stride);
       break;
   }
 }
@@ -698,7 +686,7 @@
         WarpEstimation(
             prediction_parameters.num_warp_samples, DivideBy4(prediction_width),
             DivideBy4(prediction_height), block.row4x4, block.column4x4,
-            block.bp->mv[0], prediction_parameters.warp_estimate_candidates,
+            block.bp->mv.mv[0], prediction_parameters.warp_estimate_candidates,
             local_warp_params) &&
         SetupShear(local_warp_params);
   }
@@ -712,7 +700,7 @@
             ? global_motion_params->type
             : kNumGlobalMotionTransformationTypes;
     const bool is_global_valid =
-        IsGlobalMvBlock(block.bp->y_mode, global_motion_type, block.size) &&
+        IsGlobalMvBlock(block.bp->is_global_mv_block, global_motion_type) &&
         SetupShear(global_motion_params);
     // Valid global motion type implies reference type can't be intra.
     assert(!is_global_valid || reference_type != kReferenceFrameIntra);
@@ -721,7 +709,7 @@
   return nullptr;
 }
 
-void Tile::InterPrediction(const Block& block, const Plane plane, const int x,
+bool Tile::InterPrediction(const Block& block, const Plane plane, const int x,
                            const int y, const int prediction_width,
                            const int prediction_height, int candidate_row,
                            int candidate_column, bool* const is_local_valid,
@@ -732,16 +720,13 @@
       *block_parameters_holder_.Find(candidate_row, candidate_column);
   const bool is_compound =
       bp_reference.reference_frame[1] > kReferenceFrameIntra;
-  const bool is_inter_intra =
-      bp.is_inter && bp.reference_frame[1] == kReferenceFrameIntra;
-  const ptrdiff_t prediction_stride = prediction_width;
+  assert(bp.is_inter);
+  const bool is_inter_intra = bp.reference_frame[1] == kReferenceFrameIntra;
 
   const PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
   uint8_t* const dest = GetStartPoint(buffer_, plane, x, y, bitdepth);
   const ptrdiff_t dest_stride = buffer_[plane].columns();  // In bytes.
-  const int round_bits =
-      GetInterRoundingBits(is_compound, sequence_header_.color_config.bitdepth);
   for (int index = 0; index < 1 + static_cast<int>(is_compound); ++index) {
     const ReferenceFrameType reference_type =
         bp_reference.reference_frame[index];
@@ -752,119 +737,126 @@
                       prediction_parameters, reference_type, is_local_valid,
                       &global_motion_params, local_warp_params);
     if (warp_params != nullptr) {
-      BlockWarpProcess(block, plane, index, x, y, prediction_width,
-                       prediction_height, prediction_stride, warp_params,
-                       round_bits, is_compound, is_inter_intra, dest,
-                       dest_stride);
+      if (!BlockWarpProcess(block, plane, index, x, y, prediction_width,
+                            prediction_height, warp_params, is_compound,
+                            is_inter_intra, dest, dest_stride)) {
+        return false;
+      }
     } else {
       const int reference_index =
           prediction_parameters.use_intra_block_copy
               ? -1
               : frame_header_.reference_frame_index[reference_type -
                                                     kReferenceFrameLast];
-      BlockInterPrediction(
-          block, plane, reference_index, bp_reference.mv[index], x, y,
-          prediction_width, prediction_height, candidate_row, candidate_column,
-          block.scratch_buffer->prediction_buffer[index], prediction_stride,
-          round_bits, is_compound, is_inter_intra, dest, dest_stride);
+      if (!BlockInterPrediction(
+              block, plane, reference_index, bp_reference.mv.mv[index], x, y,
+              prediction_width, prediction_height, candidate_row,
+              candidate_column, block.scratch_buffer->prediction_buffer[index],
+              is_compound, is_inter_intra, dest, dest_stride)) {
+        return false;
+      }
     }
   }
 
-  const ptrdiff_t prediction_mask_stride = kMaxSuperBlockSizeInPixels;
   const int subsampling_x = subsampling_x_[plane];
   const int subsampling_y = subsampling_y_[plane];
+  ptrdiff_t prediction_mask_stride = 0;
+  const uint8_t* prediction_mask = nullptr;
   if (prediction_parameters.compound_prediction_type ==
-          kCompoundPredictionTypeWedge &&
-      plane == kPlaneY) {
-    // Wedge masks are generated only once per decoder. We only need to
-    // populate wedge masks to prediction_mask_.
-    const int wedge_mask_stride_1 = kMaxMaskBlockSize;
-    const int wedge_mask_stride_2 = wedge_mask_stride_1 * 16;
-    const int wedge_mask_stride_3 = wedge_mask_stride_2 * 2;
-    const int block_size_index = GetWedgeBlockSizeIndex(block.size);
-    assert(block_size_index >= 0);
-    const int offset = block_size_index * wedge_mask_stride_3 +
-                       prediction_parameters.wedge_sign * wedge_mask_stride_2 +
-                       prediction_parameters.wedge_index * wedge_mask_stride_1;
-    PopulatePredictionMaskFromWedgeMask(
-        &wedge_masks_[offset], kWedgeMaskMasterSize, prediction_width,
-        prediction_height, block.scratch_buffer->prediction_mask,
-        kMaxSuperBlockSizeInPixels);
+      kCompoundPredictionTypeWedge) {
+    const Array2D<uint8_t>& wedge_mask =
+        wedge_masks_[GetWedgeBlockSizeIndex(block.size)]
+                    [prediction_parameters.wedge_sign]
+                    [prediction_parameters.wedge_index];
+    prediction_mask = wedge_mask[0];
+    prediction_mask_stride = wedge_mask.columns();
   } else if (prediction_parameters.compound_prediction_type ==
              kCompoundPredictionTypeIntra) {
-    GenerateInterIntraMask(prediction_parameters.inter_intra_mode,
-                           prediction_width, prediction_height,
-                           block.scratch_buffer->prediction_mask,
-                           prediction_mask_stride);
+    // 7.11.3.13. The inter intra masks are precomputed and stored as a set of
+    // look up tables.
+    assert(prediction_parameters.inter_intra_mode < kNumInterIntraModes);
+    prediction_mask =
+        kInterIntraMasks[prediction_parameters.inter_intra_mode]
+                        [GetInterIntraMaskLookupIndex(prediction_width)]
+                        [GetInterIntraMaskLookupIndex(prediction_height)];
+    prediction_mask_stride = prediction_width;
+  } else if (prediction_parameters.compound_prediction_type ==
+             kCompoundPredictionTypeDiffWeighted) {
+    if (plane == kPlaneY) {
+      assert(prediction_width >= 8);
+      assert(prediction_height >= 8);
+      dsp_.weight_mask[FloorLog2(prediction_width) - 3]
+                      [FloorLog2(prediction_height) - 3]
+                      [static_cast<int>(prediction_parameters.mask_is_inverse)](
+                          block.scratch_buffer->prediction_buffer[0],
+                          block.scratch_buffer->prediction_buffer[1],
+                          block.scratch_buffer->weight_mask,
+                          kMaxSuperBlockSizeInPixels);
+    }
+    prediction_mask = block.scratch_buffer->weight_mask;
+    prediction_mask_stride = kMaxSuperBlockSizeInPixels;
   }
 
   if (is_compound) {
-    CompoundInterPrediction(block, prediction_stride, prediction_mask_stride,
-                            prediction_width, prediction_height, plane,
-                            subsampling_x, subsampling_y, bitdepth,
-                            candidate_row, candidate_column, dest, dest_stride);
-  } else {
-    if (prediction_parameters.motion_mode == kMotionModeObmc) {
-      // Obmc mode is allowed only for single reference (!is_compound).
-      ObmcPrediction(block, plane, prediction_width, prediction_height,
-                     round_bits);
-    } else if (is_inter_intra) {
-      // InterIntra and obmc must be mutually exclusive.
-      uint16_t* prediction_ptr[2] = {
-          block.scratch_buffer->prediction_buffer[0],
-          block.scratch_buffer->prediction_buffer[1]};
-      InterIntraPrediction(prediction_ptr, prediction_stride,
-                           block.scratch_buffer->prediction_mask,
-                           prediction_mask_stride, prediction_parameters,
-                           prediction_width, prediction_height, subsampling_x,
-                           subsampling_y, dest, dest_stride);
-    }
+    CompoundInterPrediction(block, prediction_mask, prediction_mask_stride,
+                            prediction_width, prediction_height, subsampling_x,
+                            subsampling_y, candidate_row, candidate_column,
+                            dest, dest_stride);
+  } else if (prediction_parameters.motion_mode == kMotionModeObmc) {
+    // Obmc mode is allowed only for single reference (!is_compound).
+    return ObmcPrediction(block, plane, prediction_width, prediction_height);
+  } else if (is_inter_intra) {
+    // InterIntra and obmc must be mutually exclusive.
+    InterIntraPrediction(
+        block.scratch_buffer->prediction_buffer[0], prediction_mask,
+        prediction_mask_stride, prediction_parameters, prediction_width,
+        prediction_height, subsampling_x, subsampling_y, dest, dest_stride);
   }
+  return true;
 }
 
-void Tile::ObmcBlockPrediction(const Block& block, const MotionVector& mv,
+bool Tile::ObmcBlockPrediction(const Block& block, const MotionVector& mv,
                                const Plane plane,
                                const int reference_frame_index, const int width,
                                const int height, const int x, const int y,
                                const int candidate_row,
                                const int candidate_column,
-                               const ObmcDirection blending_direction,
-                               const int round_bits) {
+                               const ObmcDirection blending_direction) {
   const int bitdepth = sequence_header_.color_config.bitdepth;
   // Obmc's prediction needs to be clipped before blending with above/left
   // prediction blocks.
-  uint8_t obmc_clipped_prediction[kObmcBufferSize
-#if LIBGAV1_MAX_BITDEPTH >= 10
-                                  * 2
-#endif
-  ];
-  const ptrdiff_t obmc_clipped_prediction_stride =
+  // Obmc prediction is used only when is_compound is false. So it is safe to
+  // use prediction_buffer[1] as a temporary buffer for the Obmc prediction.
+  static_assert(sizeof(block.scratch_buffer->prediction_buffer[1]) >=
+                    64 * 64 * sizeof(uint16_t),
+                "");
+  auto* const obmc_buffer =
+      reinterpret_cast<uint8_t*>(block.scratch_buffer->prediction_buffer[1]);
+  const ptrdiff_t obmc_buffer_stride =
       (bitdepth == 8) ? width : width * sizeof(uint16_t);
-  BlockInterPrediction(block, plane, reference_frame_index, mv, x, y, width,
-                       height, candidate_row, candidate_column, nullptr, width,
-                       round_bits, false, false, obmc_clipped_prediction,
-                       obmc_clipped_prediction_stride);
+  if (!BlockInterPrediction(block, plane, reference_frame_index, mv, x, y,
+                            width, height, candidate_row, candidate_column,
+                            nullptr, false, false, obmc_buffer,
+                            obmc_buffer_stride)) {
+    return false;
+  }
 
   uint8_t* const prediction = GetStartPoint(buffer_, plane, x, y, bitdepth);
   const ptrdiff_t prediction_stride = buffer_[plane].columns();
   dsp_.obmc_blend[blending_direction](prediction, prediction_stride, width,
-                                      height, obmc_clipped_prediction,
-                                      obmc_clipped_prediction_stride);
+                                      height, obmc_buffer, obmc_buffer_stride);
+  return true;
 }
 
-void Tile::ObmcPrediction(const Block& block, const Plane plane,
-                          const int width, const int height,
-                          const int round_bits) {
+bool Tile::ObmcPrediction(const Block& block, const Plane plane,
+                          const int width, const int height) {
   const int subsampling_x = subsampling_x_[plane];
   const int subsampling_y = subsampling_y_[plane];
-  const int num4x4_wide = kNum4x4BlocksWide[block.size];
-  const int num4x4_high = kNum4x4BlocksHigh[block.size];
-
-  if (block.top_available &&
-      !IsBlockSmallerThan8x8(block.residual_size[GetPlaneType(plane)])) {
+  if (block.top_available[kPlaneY] &&
+      !IsBlockSmallerThan8x8(block.residual_size[plane])) {
     const int num_limit = std::min(uint8_t{4}, k4x4WidthLog2[block.size]);
     const int column4x4_max =
-        std::min(block.column4x4 + num4x4_wide, frame_header_.columns4x4);
+        std::min(block.column4x4 + block.width4x4, frame_header_.columns4x4);
     const int candidate_row = block.row4x4 - 1;
     const int block_start_y = MultiplyBy4(block.row4x4) >> subsampling_y;
     int column4x4 = block.column4x4;
@@ -883,20 +875,21 @@
                                                 kReferenceFrameLast];
         const int prediction_width =
             std::min(width, MultiplyBy4(step) >> subsampling_x);
-        ObmcBlockPrediction(block, bp_top.mv[0], plane,
-                            candidate_reference_frame_index, prediction_width,
-                            prediction_height,
-                            MultiplyBy4(column4x4) >> subsampling_x,
-                            block_start_y, candidate_row, candidate_column,
-                            kObmcDirectionVertical, round_bits);
+        if (!ObmcBlockPrediction(
+                block, bp_top.mv.mv[0], plane, candidate_reference_frame_index,
+                prediction_width, prediction_height,
+                MultiplyBy4(column4x4) >> subsampling_x, block_start_y,
+                candidate_row, candidate_column, kObmcDirectionVertical)) {
+          return false;
+        }
       }
     }
   }
 
-  if (block.left_available) {
+  if (block.left_available[kPlaneY]) {
     const int num_limit = std::min(uint8_t{4}, k4x4HeightLog2[block.size]);
     const int row4x4_max =
-        std::min(block.row4x4 + num4x4_high, frame_header_.rows4x4);
+        std::min(block.row4x4 + block.height4x4, frame_header_.rows4x4);
     const int candidate_column = block.column4x4 - 1;
     int row4x4 = block.row4x4;
     const int block_start_x = MultiplyBy4(block.column4x4) >> subsampling_x;
@@ -915,43 +908,44 @@
                                                 kReferenceFrameLast];
         const int prediction_height =
             std::min(height, MultiplyBy4(step) >> subsampling_y);
-        ObmcBlockPrediction(
-            block, bp_left.mv[0], plane, candidate_reference_frame_index,
-            prediction_width, prediction_height, block_start_x,
-            MultiplyBy4(row4x4) >> subsampling_y, candidate_row,
-            candidate_column, kObmcDirectionHorizontal, round_bits);
+        if (!ObmcBlockPrediction(
+                block, bp_left.mv.mv[0], plane, candidate_reference_frame_index,
+                prediction_width, prediction_height, block_start_x,
+                MultiplyBy4(row4x4) >> subsampling_y, candidate_row,
+                candidate_column, kObmcDirectionHorizontal)) {
+          return false;
+        }
       }
     }
   }
+  return true;
 }
 
-void Tile::DistanceWeightedPrediction(
-    uint16_t* prediction_0, ptrdiff_t prediction_stride_0,
-    uint16_t* prediction_1, ptrdiff_t prediction_stride_1, const int width,
-    const int height, const int candidate_row, const int candidate_column,
-    uint8_t* dest, ptrdiff_t dest_stride) {
+void Tile::DistanceWeightedPrediction(void* prediction_0, void* prediction_1,
+                                      const int width, const int height,
+                                      const int candidate_row,
+                                      const int candidate_column, uint8_t* dest,
+                                      ptrdiff_t dest_stride) {
   int distance[2];
   int weight[2];
   for (int reference = 0; reference < 2; ++reference) {
     const BlockParameters& bp =
         *block_parameters_holder_.Find(candidate_row, candidate_column);
-    const int reference_hint =
-        current_frame_.order_hint(bp.reference_frame[reference]);
     // Note: distance[0] and distance[1] correspond to relative distance
     // between current frame and reference frame [1] and [0], respectively.
-    distance[1 - reference] = Clip3(
-        std::abs(GetRelativeDistance(reference_hint, frame_header_.order_hint,
-                                     sequence_header_.enable_order_hint,
-                                     sequence_header_.order_hint_bits)),
-        0, kMaxFrameDistance);
+    distance[1 - reference] = std::min(
+        std::abs(static_cast<int>(
+            current_frame_.reference_info()
+                ->relative_distance_from[bp.reference_frame[reference]])),
+        static_cast<int>(kMaxFrameDistance));
   }
   GetDistanceWeights(distance, weight);
 
-  dsp_.distance_weighted_blend(prediction_0, prediction_stride_0, prediction_1,
-                               prediction_stride_1, weight[0], weight[1], width,
-                               height, dest, dest_stride);
+  dsp_.distance_weighted_blend(prediction_0, prediction_1, weight[0], weight[1],
+                               width, height, dest, dest_stride);
 }
 
+// static.
 bool Tile::GetReferenceBlockPosition(
     const int reference_frame_index, const bool is_scaled, const int width,
     const int height, const int ref_start_x, const int ref_last_x,
@@ -959,7 +953,7 @@
     const int start_y, const int step_x, const int step_y,
     const int left_border, const int right_border, const int top_border,
     const int bottom_border, int* ref_block_start_x, int* ref_block_start_y,
-    int* ref_block_end_x, int* ref_block_end_y) {
+    int* ref_block_end_x) {
   *ref_block_start_x = GetPixelPositionFromHighScale(start_x, 0, 0);
   *ref_block_start_y = GetPixelPositionFromHighScale(start_y, 0, 0);
   if (reference_frame_index == -1) {
@@ -968,36 +962,35 @@
   *ref_block_start_x -= kConvolveBorderLeftTop;
   *ref_block_start_y -= kConvolveBorderLeftTop;
   *ref_block_end_x = GetPixelPositionFromHighScale(start_x, step_x, width - 1) +
-                     kConvolveBorderRightBottom;
-  *ref_block_end_y =
+                     kConvolveBorderRight;
+  int ref_block_end_y =
       GetPixelPositionFromHighScale(start_y, step_y, height - 1) +
-      kConvolveBorderRightBottom;
+      kConvolveBorderBottom;
   if (is_scaled) {
     const int block_height =
         (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
          kScaleSubPixelBits) +
         kSubPixelTaps;
-    *ref_block_end_y = *ref_block_start_y + block_height - 1;
+    ref_block_end_y = *ref_block_start_y + block_height - 1;
   }
   // Determines if we need to extend beyond the left/right/top/bottom border.
   return *ref_block_start_x < (ref_start_x - left_border) ||
          *ref_block_end_x > (ref_last_x + right_border) ||
          *ref_block_start_y < (ref_start_y - top_border) ||
-         *ref_block_end_y > (ref_last_y + bottom_border);
+         ref_block_end_y > (ref_last_y + bottom_border);
 }
 
 // Builds a block as the input for convolve, by copying the content of
 // reference frame (either a decoded reference frame, or current frame).
+// |block_extended_width| is the combined width of the block and its borders.
 template <typename Pixel>
-void Tile::BuildConvolveBlock(const Plane plane,
-                              const int reference_frame_index,
-                              const bool is_scaled, const int height,
-                              const int ref_start_x, const int ref_last_x,
-                              const int ref_start_y, const int ref_last_y,
-                              const int step_y, const int ref_block_start_x,
-                              const int ref_block_end_x,
-                              const int ref_block_start_y,
-                              uint8_t* block_buffer, ptrdiff_t block_stride) {
+void Tile::BuildConvolveBlock(
+    const Plane plane, const int reference_frame_index, const bool is_scaled,
+    const int height, const int ref_start_x, const int ref_last_x,
+    const int ref_start_y, const int ref_last_y, const int step_y,
+    const int ref_block_start_x, const int ref_block_end_x,
+    const int ref_block_start_y, uint8_t* block_buffer,
+    ptrdiff_t convolve_buffer_stride, ptrdiff_t block_extended_width) {
   const YuvBuffer* const reference_buffer =
       (reference_frame_index == -1)
           ? current_frame_.buffer()
@@ -1007,9 +1000,8 @@
       reference_buffer->stride(plane) / sizeof(Pixel),
       reinterpret_cast<const Pixel*>(reference_buffer->data(plane)));
   auto* const block_head = reinterpret_cast<Pixel*>(block_buffer);
-  block_stride /= sizeof(Pixel);
-  int block_height =
-      height + kConvolveBorderLeftTop + kConvolveBorderRightBottom;
+  convolve_buffer_stride /= sizeof(Pixel);
+  int block_height = height + kConvolveBorderLeftTop + kConvolveBorderBottom;
   if (is_scaled) {
     block_height = (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
                     kScaleSubPixelBits) +
@@ -1030,12 +1022,12 @@
     const int ref_x = out_of_left ? copy_start_x : copy_end_x;
     Pixel* buf_ptr = block_head;
     for (int y = 0, ref_y = copy_start_y; y < block_height; ++y) {
-      Memset(buf_ptr, reference_block[ref_y][ref_x], block_stride);
+      Memset(buf_ptr, reference_block[ref_y][ref_x], block_extended_width);
       if (ref_block_start_y + y >= ref_start_y &&
           ref_block_start_y + y < ref_last_y) {
         ++ref_y;
       }
-      buf_ptr += block_stride;
+      buf_ptr += convolve_buffer_stride;
     }
   } else {
     Pixel* buf_ptr = block_head;
@@ -1049,24 +1041,24 @@
       if (extend_right) {
         Memset(buf_ptr + left_width + block_width,
                reference_block[ref_y][copy_end_x],
-               block_stride - left_width - block_width);
+               block_extended_width - left_width - block_width);
       }
       if (ref_block_start_y + y >= ref_start_y &&
           ref_block_start_y + y < ref_last_y) {
         ++ref_y;
       }
-      buf_ptr += block_stride;
+      buf_ptr += convolve_buffer_stride;
     }
   }
 }
 
-void Tile::BlockInterPrediction(
+bool Tile::BlockInterPrediction(
     const Block& block, const Plane plane, const int reference_frame_index,
     const MotionVector& mv, const int x, const int y, const int width,
     const int height, const int candidate_row, const int candidate_column,
-    uint16_t* const prediction, const ptrdiff_t prediction_stride,
-    const int round_bits, const bool is_compound, const bool is_inter_intra,
-    uint8_t* const dest, const ptrdiff_t dest_stride) {
+    uint16_t* const prediction, const bool is_compound,
+    const bool is_inter_intra, uint8_t* const dest,
+    const ptrdiff_t dest_stride) {
   const BlockParameters& bp =
       *block_parameters_holder_.Find(candidate_row, candidate_column);
   int start_x;
@@ -1095,43 +1087,60 @@
           : reference_frames_[reference_frame_index]->frame_height();
   const int ref_start_x = 0;
   const int ref_last_x =
-      SubsampledValue(
-          reference_upscaled_width,
-          (plane == kPlaneY) ? 0 : reference_buffer->subsampling_x()) -
-      1;
+      SubsampledValue(reference_upscaled_width, subsampling_x) - 1;
   const int ref_start_y = 0;
-  const int ref_last_y =
-      SubsampledValue(
-          reference_height,
-          (plane == kPlaneY) ? 0 : reference_buffer->subsampling_y()) -
-      1;
+  const int ref_last_y = SubsampledValue(reference_height, subsampling_y) - 1;
 
   const bool is_scaled = (reference_frame_index != -1) &&
                          (frame_header_.width != reference_upscaled_width ||
                           frame_header_.height != reference_height);
   const int bitdepth = sequence_header_.color_config.bitdepth;
-  const size_t pixel_size =
-      (bitdepth == 8) ? sizeof(uint8_t) : sizeof(uint16_t);
+  const int pixel_size = (bitdepth == 8) ? sizeof(uint8_t) : sizeof(uint16_t);
   int ref_block_start_x;
   int ref_block_start_y;
   int ref_block_end_x;
-  int ref_block_end_y;
-  bool extend_block = GetReferenceBlockPosition(
+  const bool extend_block = GetReferenceBlockPosition(
       reference_frame_index, is_scaled, width, height, ref_start_x, ref_last_x,
       ref_start_y, ref_last_y, start_x, start_y, step_x, step_y,
       reference_buffer->left_border(plane),
       reference_buffer->right_border(plane),
       reference_buffer->top_border(plane),
       reference_buffer->bottom_border(plane), &ref_block_start_x,
-      &ref_block_start_y, &ref_block_end_x, &ref_block_end_y);
+      &ref_block_start_y, &ref_block_end_x);
+
+  // In frame parallel mode, ensure that the reference block has been decoded
+  // and available for referencing.
+  if (reference_frame_index != -1 && frame_parallel_) {
+    int reference_y_max;
+    if (is_scaled) {
+      // TODO(vigneshv): For now, we wait for the entire reference frame to be
+      // decoded if we are using scaled references. This will eventually be
+      // fixed.
+      reference_y_max = reference_height;
+    } else {
+      reference_y_max =
+          std::min(ref_block_start_y + height + kSubPixelTaps, ref_last_y);
+      // For U and V planes with subsampling, we need to multiply
+      // reference_y_max by 2 since we only track the progress of Y planes.
+      reference_y_max = LeftShift(reference_y_max, subsampling_y);
+    }
+    if (reference_frame_progress_cache_[reference_frame_index] <
+            reference_y_max &&
+        !reference_frames_[reference_frame_index]->WaitUntil(
+            reference_y_max,
+            &reference_frame_progress_cache_[reference_frame_index])) {
+      return false;
+    }
+  }
+
   const uint8_t* block_start = nullptr;
-  ptrdiff_t block_stride;
+  ptrdiff_t convolve_buffer_stride;
   if (!extend_block) {
     const YuvBuffer* const reference_buffer =
         (reference_frame_index == -1)
             ? current_frame_.buffer()
             : reference_frames_[reference_frame_index]->buffer();
-    block_stride = reference_buffer->stride(plane);
+    convolve_buffer_stride = reference_buffer->stride(plane);
     if (reference_frame_index == -1 || is_scaled) {
       block_start = reference_buffer->data(plane) +
                     ref_block_start_y * reference_buffer->stride(plane) +
@@ -1143,63 +1152,86 @@
                     (ref_block_start_x + kConvolveBorderLeftTop) * pixel_size;
     }
   } else {
-    // The reference block width can be at most 2 times as much as current
+    // The block width can be at most 2 times as much as current
     // block's width because of scaling.
-    block_stride =
-        (2 * width + kConvolveBorderLeftTop + kConvolveBorderRightBottom) *
-        pixel_size;
-    if (bitdepth == 8) {
-      BuildConvolveBlock<uint8_t>(
-          plane, reference_frame_index, is_scaled, height, ref_start_x,
-          ref_last_x, ref_start_y, ref_last_y, step_y, ref_block_start_x,
-          ref_block_end_x, ref_block_start_y,
-          block.scratch_buffer->convolve_block_buffer, block_stride);
+    auto block_extended_width = Align<ptrdiff_t>(
+        (2 * width + kConvolveBorderLeftTop + kConvolveBorderRight) *
+            pixel_size,
+        kMaxAlignment);
+    convolve_buffer_stride = block.scratch_buffer->convolve_block_buffer_stride;
 #if LIBGAV1_MAX_BITDEPTH >= 10
-    } else {
+    if (bitdepth > 8) {
       BuildConvolveBlock<uint16_t>(
           plane, reference_frame_index, is_scaled, height, ref_start_x,
           ref_last_x, ref_start_y, ref_last_y, step_y, ref_block_start_x,
           ref_block_end_x, ref_block_start_y,
-          block.scratch_buffer->convolve_block_buffer, block_stride);
+          block.scratch_buffer->convolve_block_buffer.get(),
+          convolve_buffer_stride, block_extended_width);
+    } else {
 #endif
+      BuildConvolveBlock<uint8_t>(
+          plane, reference_frame_index, is_scaled, height, ref_start_x,
+          ref_last_x, ref_start_y, ref_last_y, step_y, ref_block_start_x,
+          ref_block_end_x, ref_block_start_y,
+          block.scratch_buffer->convolve_block_buffer.get(),
+          convolve_buffer_stride, block_extended_width);
+#if LIBGAV1_MAX_BITDEPTH >= 10
     }
-    block_start = block.scratch_buffer->convolve_block_buffer +
+#endif
+    block_start = block.scratch_buffer->convolve_block_buffer.get() +
                   (is_scaled ? 0
-                             : kConvolveBorderLeftTop * block_stride +
+                             : kConvolveBorderLeftTop * convolve_buffer_stride +
                                    kConvolveBorderLeftTop * pixel_size);
   }
 
   const int has_horizontal_filter = static_cast<int>(
-      ((mv.mv[MotionVector::kColumn] * (1 << (1 - subsampling_x))) & 15) != 0);
+      ((mv.mv[MotionVector::kColumn] * (1 << (1 - subsampling_x))) &
+       kSubPixelMask) != 0);
   const int has_vertical_filter = static_cast<int>(
-      ((mv.mv[MotionVector::kRow] * (1 << (1 - subsampling_y))) & 15) != 0);
+      ((mv.mv[MotionVector::kRow] * (1 << (1 - subsampling_y))) &
+       kSubPixelMask) != 0);
   void* const output =
       (is_compound || is_inter_intra) ? prediction : static_cast<void*>(dest);
-  const ptrdiff_t output_stride =
-      (is_compound || is_inter_intra) ? prediction_stride : dest_stride;
-  assert(output != nullptr);
-  dsp::ConvolveFunc convolve_func =
-      is_scaled ? dsp_.convolve_scale[is_compound || is_inter_intra]
-                : dsp_.convolve[reference_frame_index == -1][is_compound]
-                               [has_vertical_filter][has_horizontal_filter];
-  assert(convolve_func != nullptr);
-  // TODO(b/127805357): Refactor is_inter_intra into single prediction.
-  if (is_inter_intra && !is_scaled) {
-    convolve_func = dsp_.convolve[0][1][1][1];
+  ptrdiff_t output_stride = (is_compound || is_inter_intra)
+                                ? /*prediction_stride=*/width
+                                : dest_stride;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  // |is_inter_intra| calculations are written to the |prediction| buffer.
+  // Unlike the |is_compound| calculations the output is Pixel and not uint16_t.
+  // convolve_func() expects |output_stride| to be in bytes and not Pixels.
+  // |prediction_stride| is in units of uint16_t. Adjust |output_stride| to
+  // account for this.
+  if (is_inter_intra && sequence_header_.color_config.bitdepth > 8) {
+    output_stride *= 2;
   }
-  convolve_func(block_start, block_stride, horizontal_filter_index,
-                vertical_filter_index, round_bits, start_x, start_y, step_x,
-                step_y, width, height, output, output_stride);
+#endif
+  assert(output != nullptr);
+  if (is_scaled) {
+    dsp::ConvolveScaleFunc convolve_func = dsp_.convolve_scale[is_compound];
+    assert(convolve_func != nullptr);
+
+    convolve_func(block_start, convolve_buffer_stride, horizontal_filter_index,
+                  vertical_filter_index, start_x, start_y, step_x, step_y,
+                  width, height, output, output_stride);
+  } else {
+    dsp::ConvolveFunc convolve_func =
+        dsp_.convolve[reference_frame_index == -1][is_compound]
+                     [has_vertical_filter][has_horizontal_filter];
+    assert(convolve_func != nullptr);
+
+    convolve_func(block_start, convolve_buffer_stride, horizontal_filter_index,
+                  vertical_filter_index, start_x, start_y, width, height,
+                  output, output_stride);
+  }
+  return true;
 }
 
-void Tile::BlockWarpProcess(const Block& block, const Plane plane,
+bool Tile::BlockWarpProcess(const Block& block, const Plane plane,
                             const int index, const int block_start_x,
                             const int block_start_y, const int width,
-                            const int height, const ptrdiff_t prediction_stride,
-                            GlobalMotion* const warp_params,
-                            const int round_bits, const bool is_compound,
-                            const bool is_inter_intra, uint8_t* const dest,
-                            const ptrdiff_t dest_stride) {
+                            const int height, GlobalMotion* const warp_params,
+                            const bool is_compound, const bool is_inter_intra,
+                            uint8_t* const dest, const ptrdiff_t dest_stride) {
   assert(width >= 8 && height >= 8);
   const BlockParameters& bp = *block.bp;
   const int reference_frame_index =
@@ -1210,29 +1242,69 @@
   ptrdiff_t source_stride =
       reference_frames_[reference_frame_index]->buffer()->stride(plane);
   const int source_width =
-      reference_frames_[reference_frame_index]->buffer()->displayed_width(
-          plane);
+      reference_frames_[reference_frame_index]->buffer()->width(plane);
   const int source_height =
-      reference_frames_[reference_frame_index]->buffer()->displayed_height(
-          plane);
+      reference_frames_[reference_frame_index]->buffer()->height(plane);
   uint16_t* const prediction = block.scratch_buffer->prediction_buffer[index];
-  dsp_.warp(source, source_stride, source_width, source_height,
-            warp_params->params, subsampling_x_[plane], subsampling_y_[plane],
-            round_bits, block_start_x, block_start_y, width, height,
-            warp_params->alpha, warp_params->beta, warp_params->gamma,
-            warp_params->delta, prediction, prediction_stride);
-  if (!is_compound && !is_inter_intra) {
-    const int bitdepth = sequence_header_.color_config.bitdepth;
-    if (bitdepth == 8) {
-      ClipPrediction<8, uint8_t>(prediction, prediction_stride, width, height,
-                                 dest, dest_stride);
-#if LIBGAV1_MAX_BITDEPTH >= 10
-    } else {
-      ClipPrediction<10, uint16_t>(prediction, prediction_stride, width, height,
-                                   dest, dest_stride);
-#endif
+
+  // In frame parallel mode, ensure that the reference block has been decoded
+  // and available for referencing.
+  if (frame_parallel_) {
+    int reference_y_max = -1;
+    // Find out the maximum y-coordinate for warping.
+    for (int start_y = block_start_y; start_y < block_start_y + height;
+         start_y += 8) {
+      for (int start_x = block_start_x; start_x < block_start_x + width;
+           start_x += 8) {
+        const int src_x = (start_x + 4) << subsampling_x_[plane];
+        const int src_y = (start_y + 4) << subsampling_y_[plane];
+        const int dst_y = src_x * warp_params->params[4] +
+                          src_y * warp_params->params[5] +
+                          warp_params->params[1];
+        const int y4 = dst_y >> subsampling_y_[plane];
+        const int iy4 = y4 >> kWarpedModelPrecisionBits;
+        reference_y_max = std::max(iy4 + 8, reference_y_max);
+      }
+    }
+    // For U and V planes with subsampling, we need to multiply reference_y_max
+    // by 2 since we only track the progress of Y planes.
+    reference_y_max = LeftShift(reference_y_max, subsampling_y_[plane]);
+    if (reference_frame_progress_cache_[reference_frame_index] <
+            reference_y_max &&
+        !reference_frames_[reference_frame_index]->WaitUntil(
+            reference_y_max,
+            &reference_frame_progress_cache_[reference_frame_index])) {
+      return false;
     }
   }
+  if (is_compound) {
+    dsp_.warp_compound(source, source_stride, source_width, source_height,
+                       warp_params->params, subsampling_x_[plane],
+                       subsampling_y_[plane], block_start_x, block_start_y,
+                       width, height, warp_params->alpha, warp_params->beta,
+                       warp_params->gamma, warp_params->delta, prediction,
+                       /*prediction_stride=*/width);
+  } else {
+    void* const output = is_inter_intra ? static_cast<void*>(prediction) : dest;
+    ptrdiff_t output_stride =
+        is_inter_intra ? /*prediction_stride=*/width : dest_stride;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    // |is_inter_intra| calculations are written to the |prediction| buffer.
+    // Unlike the |is_compound| calculations the output is Pixel and not
+    // uint16_t. warp_clip() expects |output_stride| to be in bytes and not
+    // Pixels. |prediction_stride| is in units of uint16_t. Adjust
+    // |output_stride| to account for this.
+    if (is_inter_intra && sequence_header_.color_config.bitdepth > 8) {
+      output_stride *= 2;
+    }
+#endif
+    dsp_.warp(source, source_stride, source_width, source_height,
+              warp_params->params, subsampling_x_[plane], subsampling_y_[plane],
+              block_start_x, block_start_y, width, height, warp_params->alpha,
+              warp_params->beta, warp_params->gamma, warp_params->delta, output,
+              output_stride);
+  }
+  return true;
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/tile/tile.cc b/libgav1/src/tile/tile.cc
index 96c724f..f79158f 100644
--- a/libgav1/src/tile/tile.cc
+++ b/libgav1/src/tile/tile.cc
@@ -17,6 +17,7 @@
 #include <algorithm>
 #include <array>
 #include <cassert>
+#include <climits>
 #include <cstdlib>
 #include <cstring>
 #include <memory>
@@ -25,9 +26,12 @@
 #include <type_traits>
 #include <utility>
 
+#include "src/frame_scratch_buffer.h"
 #include "src/motion_vector.h"
 #include "src/reconstruction.h"
 #include "src/utils/bit_mask_set.h"
+#include "src/utils/common.h"
+#include "src/utils/constants.h"
 #include "src/utils/logging.h"
 #include "src/utils/segmentation.h"
 #include "src/utils/stack.h"
@@ -45,8 +49,6 @@
 // process is activated.
 constexpr int kQuantizerCoefficientBaseRange = 12;
 constexpr int kNumQuantizerBaseLevels = 2;
-constexpr int kQuantizerCoefficientBaseRangeContextClamp =
-    kQuantizerCoefficientBaseRange + kNumQuantizerBaseLevels + 1;
 constexpr int kCoeffBaseRangeMaxIterations =
     kQuantizerCoefficientBaseRange / (kCoeffBaseRangeSymbolCount - 1);
 constexpr int kEntropyContextLeft = 0;
@@ -99,6 +101,14 @@
         kPredictionModeDc, kPredictionModeVertical, kPredictionModeHorizontal,
         kPredictionModeD157, kPredictionModeDc};
 
+// Mask used to determine the index for mode_deltas lookup.
+constexpr BitMaskSet kPredictionModeDeltasMask(
+    kPredictionModeNearestMv, kPredictionModeNearMv, kPredictionModeNewMv,
+    kPredictionModeNearestNearestMv, kPredictionModeNearNearMv,
+    kPredictionModeNearestNewMv, kPredictionModeNewNearestMv,
+    kPredictionModeNearNewMv, kPredictionModeNewNearMv,
+    kPredictionModeNewNewMv);
+
 // This is computed as:
 // min(transform_width_log2, 5) + min(transform_height_log2, 5) - 4.
 constexpr uint8_t kEobMultiSizeLookup[kNumTransformSizes] = {
@@ -146,7 +156,10 @@
      {6, 21, 21, 21, 21}, {21, 21, 21, 21, 21}}};
 /* clang-format on */
 
-constexpr uint8_t kCoeffBasePositionContextOffset[3] = {26, 31, 36};
+// Extended the table size from 3 to 16 by repeating the last element to avoid
+// the clips to row or column indices.
+constexpr uint8_t kCoeffBasePositionContextOffset[16] = {
+    26, 31, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36};
 
 constexpr PredictionMode kInterIntraToIntraMode[kNumInterIntraModes] = {
     kPredictionModeDc, kPredictionModeVertical, kPredictionModeHorizontal,
@@ -235,7 +248,7 @@
 
 constexpr int8_t kSgrProjDefaultMultiplier[2] = {-32, 31};
 
-constexpr int8_t kWienerDefaultFilter[3] = {3, -7, 15};
+constexpr int8_t kWienerDefaultFilter[kNumWienerCoefficients] = {3, -7, 15};
 
 // Maps compound prediction modes into single modes. For e.g.
 // kPredictionModeNearestNewMv will map to kPredictionModeNearestMv for index 0
@@ -264,31 +277,8 @@
 // log2(dqDenom) in section 7.12.3 of the spec. We use the log2 value because
 // dqDenom is always a power of two and hence right shift can be used instead of
 // division.
-constexpr BitMaskSet kQuantizationShift2Mask(kTransformSize32x64,
-                                             kTransformSize64x32,
-                                             kTransformSize64x64);
-constexpr BitMaskSet kQuantizationShift1Mask(kTransformSize16x32,
-                                             kTransformSize16x64,
-                                             kTransformSize32x16,
-                                             kTransformSize32x32,
-                                             kTransformSize64x16);
-int GetQuantizationShift(TransformSize tx_size) {
-  if (kQuantizationShift2Mask.Contains(tx_size)) {
-    return 2;
-  }
-  if (kQuantizationShift1Mask.Contains(tx_size)) {
-    return 1;
-  }
-  return 0;
-}
-
-// Input: 1d array index |index|, which indexes into a 2d array of width
-//     1 << |tx_width_log2|.
-// Output: 1d array index which indexes into a 2d array of width
-//     (1 << |tx_width_log2|) + kQuantizedCoefficientBufferPadding.
-int PaddedIndex(int index, int tx_width_log2) {
-  return index + MultiplyBy4(index >> tx_width_log2);
-}
+constexpr uint8_t kQuantizationShift[kNumTransformSizes] = {
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 2, 1, 2, 2};
 
 // Returns the minimum of |length| or |max|-|start|. This is used to clamp array
 // indices when accessing arrays whose bound is equal to |max|.
@@ -296,40 +286,151 @@
   return std::min(length, max - start);
 }
 
+template <typename T>
+void SetBlockValues(int rows, int columns, T value, T* dst, ptrdiff_t stride) {
+  // Specialize all columns cases (values in kTransformWidth4x4[]) for better
+  // performance.
+  switch (columns) {
+    case 1:
+      MemSetBlock<T>(rows, 1, value, dst, stride);
+      break;
+    case 2:
+      MemSetBlock<T>(rows, 2, value, dst, stride);
+      break;
+    case 4:
+      MemSetBlock<T>(rows, 4, value, dst, stride);
+      break;
+    case 8:
+      MemSetBlock<T>(rows, 8, value, dst, stride);
+      break;
+    default:
+      assert(columns == 16);
+      MemSetBlock<T>(rows, 16, value, dst, stride);
+      break;
+  }
+}
+
 void SetTransformType(const Tile::Block& block, int x4, int y4, int w4, int h4,
                       TransformType tx_type,
                       TransformType transform_types[32][32]) {
   const int y_offset = y4 - block.row4x4;
   const int x_offset = x4 - block.column4x4;
-  static_assert(sizeof(transform_types[0][0]) == 1, "");
-  for (int i = 0; i < h4; ++i) {
-    memset(&transform_types[y_offset + i][x_offset], tx_type, w4);
-  }
+  TransformType* const dst = &transform_types[y_offset][x_offset];
+  SetBlockValues<TransformType>(h4, w4, tx_type, dst, 32);
+}
+
+void StoreMotionFieldMvs(ReferenceFrameType reference_frame_to_store,
+                         const MotionVector& mv_to_store, ptrdiff_t stride,
+                         int rows, int columns,
+                         ReferenceFrameType* reference_frame_row_start,
+                         MotionVector* mv) {
+  static_assert(sizeof(*reference_frame_row_start) == sizeof(int8_t), "");
+  do {
+    // Don't switch the following two memory setting functions.
+    // Some ARM CPUs are quite sensitive to the order.
+    memset(reference_frame_row_start, reference_frame_to_store, columns);
+    std::fill(mv, mv + columns, mv_to_store);
+    reference_frame_row_start += stride;
+    mv += stride;
+  } while (--rows != 0);
+}
+
+// Inverse transform process assumes that the quantized coefficients are stored
+// as a virtual 2d array of size |tx_width| x tx_height. If transform width is
+// 64, then this assumption is broken because the scan order used for populating
+// the coefficients for such transforms is the same as the one used for
+// corresponding transform with width 32 (e.g. the scan order used for 64x16 is
+// the same as the one used for 32x16). So we must restore the coefficients to
+// their correct positions and clean the positions they occupied.
+template <typename ResidualType>
+void MoveCoefficientsForTxWidth64(int clamped_tx_height, int tx_width,
+                                  ResidualType* residual) {
+  if (tx_width != 64) return;
+  const int rows = clamped_tx_height - 2;
+  auto* src = residual + 32 * rows;
+  residual += 64 * rows;
+  // Process 2 rows in each loop in reverse order to avoid overwrite.
+  int x = rows >> 1;
+  do {
+    // The 2 rows can be processed in order.
+    memcpy(residual, src, 32 * sizeof(src[0]));
+    memcpy(residual + 64, src + 32, 32 * sizeof(src[0]));
+    memset(src + 32, 0, 32 * sizeof(src[0]));
+    src -= 64;
+    residual -= 128;
+  } while (--x);
+  // Process the second row. The first row is already correct.
+  memcpy(residual + 64, src + 32, 32 * sizeof(src[0]));
+  memset(src + 32, 0, 32 * sizeof(src[0]));
+}
+
+void GetClampParameters(const Tile::Block& block, int min[2], int max[2]) {
+  // 7.10.2.14 (part 1). (also contains implementations of 5.11.53
+  // and 5.11.54).
+  constexpr int kMvBorder4x4 = 4;
+  const int row_border = kMvBorder4x4 + block.height4x4;
+  const int column_border = kMvBorder4x4 + block.width4x4;
+  const int macroblocks_to_top_edge = -block.row4x4;
+  const int macroblocks_to_bottom_edge =
+      block.tile.frame_header().rows4x4 - block.height4x4 - block.row4x4;
+  const int macroblocks_to_left_edge = -block.column4x4;
+  const int macroblocks_to_right_edge =
+      block.tile.frame_header().columns4x4 - block.width4x4 - block.column4x4;
+  min[0] = MultiplyBy32(macroblocks_to_top_edge - row_border);
+  min[1] = MultiplyBy32(macroblocks_to_left_edge - column_border);
+  max[0] = MultiplyBy32(macroblocks_to_bottom_edge + row_border);
+  max[1] = MultiplyBy32(macroblocks_to_right_edge + column_border);
+}
+
+// Section 8.3.2 in the spec, under coeff_base_eob.
+int GetCoeffBaseContextEob(TransformSize tx_size, int index) {
+  if (index == 0) return 0;
+  const TransformSize adjusted_tx_size = kAdjustedTransformSize[tx_size];
+  const int tx_width_log2 = kTransformWidthLog2[adjusted_tx_size];
+  const int tx_height = kTransformHeight[adjusted_tx_size];
+  if (index <= DivideBy8(tx_height << tx_width_log2)) return 1;
+  if (index <= DivideBy4(tx_height << tx_width_log2)) return 2;
+  return 3;
+}
+
+// Section 8.3.2 in the spec, under coeff_br. Optimized for end of block based
+// on the fact that {0, 1}, {1, 0}, {1, 1}, {0, 2} and {2, 0} will all be 0 in
+// the end of block case.
+int GetCoeffBaseRangeContextEob(int adjusted_tx_width_log2, int pos,
+                                TransformClass tx_class) {
+  if (pos == 0) return 0;
+  const int tx_width = 1 << adjusted_tx_width_log2;
+  const int row = pos >> adjusted_tx_width_log2;
+  const int column = pos & (tx_width - 1);
+  // This return statement is equivalent to:
+  // return ((tx_class == kTransformClass2D && (row | column) < 2) ||
+  //         (tx_class == kTransformClassHorizontal && column == 0) ||
+  //         (tx_class == kTransformClassVertical && row == 0))
+  //            ? 7
+  //            : 14;
+  return 14 >> ((static_cast<int>(tx_class == kTransformClass2D) &
+                 static_cast<int>((row | column) < 2)) |
+                (tx_class & static_cast<int>(column == 0)) |
+                ((tx_class >> 1) & static_cast<int>(row == 0)));
 }
 
 }  // namespace
 
-Tile::Tile(
-    int tile_number, const uint8_t* const data, size_t size,
-    const ObuSequenceHeader& sequence_header,
-    const ObuFrameHeader& frame_header, RefCountedBuffer* const current_frame,
-    const std::array<bool, kNumReferenceFrameTypes>& reference_frame_sign_bias,
-    const std::array<RefCountedBufferPtr, kNumReferenceFrameTypes>&
-        reference_frames,
-    Array2D<TemporalMotionVector>* const motion_field_mv,
-    const std::array<uint8_t, kNumReferenceFrameTypes>& reference_order_hint,
-    const std::array<uint8_t, kWedgeMaskSize>& wedge_masks,
-    const SymbolDecoderContext& symbol_decoder_context,
-    SymbolDecoderContext* const saved_symbol_decoder_context,
-    const SegmentationMap* prev_segment_ids, PostFilter* const post_filter,
-    BlockParametersHolder* const block_parameters_holder,
-    Array2D<int16_t>* const cdef_index,
-    Array2D<TransformSize>* const inter_transform_sizes,
-    const dsp::Dsp* const dsp, ThreadPool* const thread_pool,
-    ResidualBufferPool* const residual_buffer_pool,
-    DecoderScratchBufferPool* const decoder_scratch_buffer_pool,
-    BlockingCounterWithStatus* const pending_tiles)
+Tile::Tile(int tile_number, const uint8_t* const data, size_t size,
+           const ObuSequenceHeader& sequence_header,
+           const ObuFrameHeader& frame_header,
+           RefCountedBuffer* const current_frame, const DecoderState& state,
+           FrameScratchBuffer* const frame_scratch_buffer,
+           const WedgeMaskArray& wedge_masks,
+           SymbolDecoderContext* const saved_symbol_decoder_context,
+           const SegmentationMap* prev_segment_ids,
+           PostFilter* const post_filter, const dsp::Dsp* const dsp,
+           ThreadPool* const thread_pool,
+           BlockingCounterWithStatus* const pending_tiles, bool frame_parallel,
+           bool use_intra_prediction_buffer)
     : number_(tile_number),
+      row_(number_ / frame_header.tile_info.tile_columns),
+      column_(number_ % frame_header.tile_info.tile_columns),
       data_(data),
       size_(size),
       read_deltas_(false),
@@ -340,19 +441,18 @@
       current_quantizer_index_(frame_header.quantizer.base_index),
       sequence_header_(sequence_header),
       frame_header_(frame_header),
-      current_frame_(*current_frame),
-      reference_frame_sign_bias_(reference_frame_sign_bias),
-      reference_frames_(reference_frames),
-      motion_field_mv_(motion_field_mv),
-      reference_order_hint_(reference_order_hint),
+      reference_frame_sign_bias_(state.reference_frame_sign_bias),
+      reference_frames_(state.reference_frame),
+      motion_field_(frame_scratch_buffer->motion_field),
+      reference_order_hint_(state.reference_order_hint),
       wedge_masks_(wedge_masks),
       reader_(data_, size_, frame_header_.enable_cdf_update),
-      symbol_decoder_context_(symbol_decoder_context),
+      symbol_decoder_context_(frame_scratch_buffer->symbol_decoder_context),
       saved_symbol_decoder_context_(saved_symbol_decoder_context),
       prev_segment_ids_(prev_segment_ids),
       dsp_(*dsp),
       post_filter_(*post_filter),
-      block_parameters_holder_(*block_parameters_holder),
+      block_parameters_holder_(frame_scratch_buffer->block_parameters_holder),
       quantizer_(sequence_header_.color_config.bitdepth,
                  &frame_header_.quantizer),
       residual_size_((sequence_header_.color_config.bitdepth == 8)
@@ -362,15 +462,20 @@
           frame_header_.allow_intrabc
               ? (sequence_header_.use_128x128_superblock ? 3 : 5)
               : 1),
-      cdef_index_(*cdef_index),
-      inter_transform_sizes_(*inter_transform_sizes),
+      current_frame_(*current_frame),
+      cdef_index_(frame_scratch_buffer->cdef_index),
+      inter_transform_sizes_(frame_scratch_buffer->inter_transform_sizes),
       thread_pool_(thread_pool),
-      residual_buffer_pool_(residual_buffer_pool),
-      decoder_scratch_buffer_pool_(decoder_scratch_buffer_pool),
+      residual_buffer_pool_(frame_scratch_buffer->residual_buffer_pool.get()),
+      tile_scratch_buffer_pool_(
+          &frame_scratch_buffer->tile_scratch_buffer_pool),
       pending_tiles_(pending_tiles),
-      build_bit_mask_when_parsing_(false) {
-  row_ = number_ / frame_header.tile_info.tile_columns;
-  column_ = number_ % frame_header.tile_info.tile_columns;
+      frame_parallel_(frame_parallel),
+      use_intra_prediction_buffer_(use_intra_prediction_buffer),
+      intra_prediction_buffer_(
+          use_intra_prediction_buffer_
+              ? &frame_scratch_buffer->intra_prediction_buffers.get()[row_]
+              : nullptr) {
   row4x4_start_ = frame_header.tile_info.tile_row_start[row_];
   row4x4_end_ = frame_header.tile_info.tile_row_start[row_ + 1];
   column4x4_start_ = frame_header.tile_info.tile_column_start[column_];
@@ -382,16 +487,71 @@
   superblock_columns_ =
       (column4x4_end_ - column4x4_start_ + block_width4x4 - 1) >>
       block_width4x4_log2;
-  // Enable multi-threading within a tile only if there are at least as many
-  // superblock columns as |intra_block_copy_lag_|.
-  split_parse_and_decode_ =
-      thread_pool_ != nullptr && superblock_columns_ > intra_block_copy_lag_;
+  // If |split_parse_and_decode_| is true, we do the necessary setup for
+  // splitting the parsing and the decoding steps. This is done in the following
+  // two cases:
+  //  1) If there is multi-threading within a tile (this is done if
+  //     |thread_pool_| is not nullptr and if there are at least as many
+  //     superblock columns as |intra_block_copy_lag_|).
+  //  2) If |frame_parallel| is true.
+  split_parse_and_decode_ = (thread_pool_ != nullptr &&
+                             superblock_columns_ > intra_block_copy_lag_) ||
+                            frame_parallel;
+  if (frame_parallel_) {
+    reference_frame_progress_cache_.fill(INT_MIN);
+  }
   memset(delta_lf_, 0, sizeof(delta_lf_));
   delta_lf_all_zero_ = true;
-  YuvBuffer* const buffer = current_frame->buffer();
+  const YuvBuffer& buffer = post_filter_.frame_buffer();
   for (int plane = 0; plane < PlaneCount(); ++plane) {
-    buffer_[plane].Reset(buffer->height(plane) + buffer->bottom_border(plane),
-                         buffer->stride(plane), buffer->data(plane));
+    // Verify that the borders are big enough for Reconstruct(). max_tx_length
+    // is the maximum value of tx_width and tx_height for the plane.
+    const int max_tx_length = (plane == kPlaneY) ? 64 : 32;
+    // Reconstruct() may overwrite on the right. Since the right border of a
+    // row is followed in memory by the left border of the next row, the
+    // number of extra pixels to the right of a row is at least the sum of the
+    // left and right borders.
+    //
+    // Note: This assertion actually checks the sum of the left and right
+    // borders of post_filter_.GetUnfilteredBuffer(), which is a horizontally
+    // and vertically shifted version of |buffer|. Since the sum of the left and
+    // right borders is not changed by the shift, we can just check the sum of
+    // the left and right borders of |buffer|.
+    assert(buffer.left_border(plane) + buffer.right_border(plane) >=
+           max_tx_length - 1);
+    // Reconstruct() may overwrite on the bottom. We need an extra border row
+    // on the bottom because we need the left border of that row.
+    //
+    // Note: This assertion checks the bottom border of
+    // post_filter_.GetUnfilteredBuffer(). So we need to calculate the vertical
+    // shift that the PostFilter constructor applied to |buffer| and reduce the
+    // bottom border by that amount.
+#ifndef NDEBUG
+    const int vertical_shift = static_cast<int>(
+        (post_filter_.GetUnfilteredBuffer(plane) - buffer.data(plane)) /
+        buffer.stride(plane));
+    const int bottom_border = buffer.bottom_border(plane) - vertical_shift;
+    assert(bottom_border >= max_tx_length);
+#endif
+    // In AV1, a transform block of height H starts at a y coordinate that is
+    // a multiple of H. If a transform block at the bottom of the frame has
+    // height H, then Reconstruct() will write up to the row with index
+    // Align(buffer.height(plane), H) - 1. Therefore the maximum number of
+    // rows Reconstruct() may write to is
+    // Align(buffer.height(plane), max_tx_length).
+    buffer_[plane].Reset(Align(buffer.height(plane), max_tx_length),
+                         buffer.stride(plane),
+                         post_filter_.GetUnfilteredBuffer(plane));
+    const int plane_height =
+        RightShiftWithRounding(frame_header_.height, subsampling_y_[plane]);
+    deblock_row_limit_[plane] =
+        std::min(frame_header_.rows4x4, DivideBy4(plane_height + 3)
+                                            << subsampling_y_[plane]);
+    const int plane_width =
+        RightShiftWithRounding(frame_header_.width, subsampling_x_[plane]);
+    deblock_column_limit_[plane] =
+        std::min(frame_header_.columns4x4, DivideBy4(plane_width + 3)
+                                               << subsampling_x_[plane]);
   }
 }
 
@@ -418,7 +578,10 @@
       return false;
     }
   } else {
-    residual_buffer_ = MakeAlignedUniquePtr<uint8_t>(32, 4096 * residual_size_);
+    // Add 32 * |kResidualPaddingVertical| padding to avoid bottom boundary
+    // checks when parsing quantized coefficients.
+    residual_buffer_ = MakeAlignedUniquePtr<uint8_t>(
+        32, (4096 + 32 * kResidualPaddingVertical) * residual_size_);
     if (residual_buffer_ == nullptr) {
       LIBGAV1_DLOG(ERROR, "Allocation of residual_buffer_ failed.");
       return false;
@@ -429,62 +592,165 @@
       return false;
     }
   }
+  if (frame_header_.use_ref_frame_mvs) {
+    assert(sequence_header_.enable_order_hint);
+    SetupMotionField(frame_header_, current_frame_, reference_frames_,
+                     row4x4_start_, row4x4_end_, column4x4_start_,
+                     column4x4_end_, &motion_field_);
+  }
+  ResetLoopRestorationParams();
   return true;
 }
 
-bool Tile::Decode(bool is_main_thread) {
-  if (!Init()) {
-    pending_tiles_->Decrement(false);
-    return false;
-  }
-  if (frame_header_.use_ref_frame_mvs) {
-    SetupMotionField(sequence_header_, frame_header_, current_frame_,
-                     reference_frames_, motion_field_mv_, row4x4_start_,
-                     row4x4_end_, column4x4_start_, column4x4_end_);
-  }
-  ResetLoopRestorationParams();
-  // If this is the main thread, we build the loop filter bit masks when parsing
-  // so that it happens in the current thread. This ensures that the main thread
-  // does as much work as possible.
-  build_bit_mask_when_parsing_ = is_main_thread;
-  if (split_parse_and_decode_) {
-    if (!ThreadedDecode()) return false;
-  } else {
-    const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
-    std::unique_ptr<DecoderScratchBuffer> scratch_buffer =
-        decoder_scratch_buffer_pool_->Get();
-    if (scratch_buffer == nullptr) {
-      pending_tiles_->Decrement(false);
-      LIBGAV1_DLOG(ERROR, "Failed to get scratch buffer.");
+template <ProcessingMode processing_mode, bool save_symbol_decoder_context>
+bool Tile::ProcessSuperBlockRow(int row4x4,
+                                TileScratchBuffer* const scratch_buffer) {
+  if (row4x4 < row4x4_start_ || row4x4 >= row4x4_end_) return true;
+  assert(scratch_buffer != nullptr);
+  const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
+  for (int column4x4 = column4x4_start_; column4x4 < column4x4_end_;
+       column4x4 += block_width4x4) {
+    if (!ProcessSuperBlock(row4x4, column4x4, block_width4x4, scratch_buffer,
+                           processing_mode)) {
+      LIBGAV1_DLOG(ERROR, "Error decoding super block row: %d column: %d",
+                   row4x4, column4x4);
       return false;
     }
-    for (int row4x4 = row4x4_start_; row4x4 < row4x4_end_;
-         row4x4 += block_width4x4) {
-      for (int column4x4 = column4x4_start_; column4x4 < column4x4_end_;
-           column4x4 += block_width4x4) {
-        if (!ProcessSuperBlock(row4x4, column4x4, block_width4x4,
-                               scratch_buffer.get(),
-                               kProcessingModeParseAndDecode)) {
-          pending_tiles_->Decrement(false);
-          LIBGAV1_DLOG(ERROR, "Error decoding super block row: %d column: %d",
-                       row4x4, column4x4);
-          return false;
-        }
-      }
-    }
-    decoder_scratch_buffer_pool_->Release(std::move(scratch_buffer));
   }
+  if (save_symbol_decoder_context && row4x4 + block_width4x4 >= row4x4_end_) {
+    SaveSymbolDecoderContext();
+  }
+  if (processing_mode == kProcessingModeDecodeOnly ||
+      processing_mode == kProcessingModeParseAndDecode) {
+    PopulateIntraPredictionBuffer(row4x4);
+  }
+  return true;
+}
+
+// Used in frame parallel mode. The symbol decoder context need not be saved in
+// this case since it was done when parsing was complete.
+template bool Tile::ProcessSuperBlockRow<kProcessingModeDecodeOnly, false>(
+    int row4x4, TileScratchBuffer* scratch_buffer);
+// Used in non frame parallel mode.
+template bool Tile::ProcessSuperBlockRow<kProcessingModeParseAndDecode, true>(
+    int row4x4, TileScratchBuffer* scratch_buffer);
+
+void Tile::SaveSymbolDecoderContext() {
   if (frame_header_.enable_frame_end_update_cdf &&
       number_ == frame_header_.tile_info.context_update_id) {
     *saved_symbol_decoder_context_ = symbol_decoder_context_;
   }
-  if (!split_parse_and_decode_) {
-    pending_tiles_->Decrement(true);
+}
+
+bool Tile::ParseAndDecode() {
+  // If this is the main thread, we build the loop filter bit masks when parsing
+  // so that it happens in the current thread. This ensures that the main thread
+  // does as much work as possible.
+  if (split_parse_and_decode_) {
+    if (!ThreadedParseAndDecode()) return false;
+    SaveSymbolDecoderContext();
+    return true;
   }
+  std::unique_ptr<TileScratchBuffer> scratch_buffer =
+      tile_scratch_buffer_pool_->Get();
+  if (scratch_buffer == nullptr) {
+    pending_tiles_->Decrement(false);
+    LIBGAV1_DLOG(ERROR, "Failed to get scratch buffer.");
+    return false;
+  }
+  const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
+  for (int row4x4 = row4x4_start_; row4x4 < row4x4_end_;
+       row4x4 += block_width4x4) {
+    if (!ProcessSuperBlockRow<kProcessingModeParseAndDecode, true>(
+            row4x4, scratch_buffer.get())) {
+      pending_tiles_->Decrement(false);
+      return false;
+    }
+  }
+  tile_scratch_buffer_pool_->Release(std::move(scratch_buffer));
+  pending_tiles_->Decrement(true);
   return true;
 }
 
-bool Tile::ThreadedDecode() {
+bool Tile::Parse() {
+  const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
+  std::unique_ptr<TileScratchBuffer> scratch_buffer =
+      tile_scratch_buffer_pool_->Get();
+  if (scratch_buffer == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Failed to get scratch buffer.");
+    return false;
+  }
+  for (int row4x4 = row4x4_start_; row4x4 < row4x4_end_;
+       row4x4 += block_width4x4) {
+    if (!ProcessSuperBlockRow<kProcessingModeParseOnly, false>(
+            row4x4, scratch_buffer.get())) {
+      return false;
+    }
+  }
+  tile_scratch_buffer_pool_->Release(std::move(scratch_buffer));
+  SaveSymbolDecoderContext();
+  return true;
+}
+
+bool Tile::Decode(
+    std::mutex* const mutex, int* const superblock_row_progress,
+    std::condition_variable* const superblock_row_progress_condvar) {
+  const int block_width4x4 = sequence_header_.use_128x128_superblock ? 32 : 16;
+  const int block_width4x4_log2 =
+      sequence_header_.use_128x128_superblock ? 5 : 4;
+  std::unique_ptr<TileScratchBuffer> scratch_buffer =
+      tile_scratch_buffer_pool_->Get();
+  if (scratch_buffer == nullptr) {
+    LIBGAV1_DLOG(ERROR, "Failed to get scratch buffer.");
+    return false;
+  }
+  for (int row4x4 = row4x4_start_, index = row4x4_start_ >> block_width4x4_log2;
+       row4x4 < row4x4_end_; row4x4 += block_width4x4, ++index) {
+    if (!ProcessSuperBlockRow<kProcessingModeDecodeOnly, false>(
+            row4x4, scratch_buffer.get())) {
+      return false;
+    }
+    if (post_filter_.DoDeblock()) {
+      // Apply vertical deblock filtering for all the columns in this tile
+      // except for the first 64 columns.
+      post_filter_.ApplyDeblockFilter(
+          kLoopFilterTypeVertical, row4x4,
+          column4x4_start_ + kNum4x4InLoopFilterUnit, column4x4_end_,
+          block_width4x4);
+      // If this is the first superblock row of the tile, then we cannot apply
+      // horizontal deblocking here since we don't know if the top row is
+      // available. So it will be done by the calling thread in that case.
+      if (row4x4 != row4x4_start_) {
+        // Apply horizontal deblock filtering for all the columns in this tile
+        // except for the first and the last 64 columns.
+        // Note about the last tile of each row: For the last tile,
+        // column4x4_end may not be a multiple of 16. In that case it is still
+        // okay to simply subtract 16 since ApplyDeblockFilter() will only do
+        // the filters in increments of 64 columns (or 32 columns for chroma
+        // with subsampling).
+        post_filter_.ApplyDeblockFilter(
+            kLoopFilterTypeHorizontal, row4x4,
+            column4x4_start_ + kNum4x4InLoopFilterUnit,
+            column4x4_end_ - kNum4x4InLoopFilterUnit, block_width4x4);
+      }
+    }
+    bool notify;
+    {
+      std::unique_lock<std::mutex> lock(*mutex);
+      notify = ++superblock_row_progress[index] ==
+               frame_header_.tile_info.tile_columns;
+    }
+    if (notify) {
+      // We are done decoding this superblock row. Notify the post filtering
+      // thread.
+      superblock_row_progress_condvar[index].notify_one();
+    }
+  }
+  tile_scratch_buffer_pool_->Release(std::move(scratch_buffer));
+  return true;
+}
+
+bool Tile::ThreadedParseAndDecode() {
   {
     std::lock_guard<std::mutex> lock(threading_.mutex);
     if (!threading_.sb_state.Reset(superblock_rows_, superblock_columns_)) {
@@ -499,8 +765,8 @@
   const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
 
   // Begin parsing.
-  std::unique_ptr<DecoderScratchBuffer> scratch_buffer =
-      decoder_scratch_buffer_pool_->Get();
+  std::unique_ptr<TileScratchBuffer> scratch_buffer =
+      tile_scratch_buffer_pool_->Get();
   if (scratch_buffer == nullptr) {
     pending_tiles_->Decrement(false);
     LIBGAV1_DLOG(ERROR, "Failed to get scratch buffer.");
@@ -535,7 +801,7 @@
     std::lock_guard<std::mutex> lock(threading_.mutex);
     if (threading_.abort) break;
   }
-  decoder_scratch_buffer_pool_->Release(std::move(scratch_buffer));
+  tile_scratch_buffer_pool_->Release(std::move(scratch_buffer));
 
   // We are done parsing. We can return here since the calling thread will make
   // sure that it waits for all the superblocks to be decoded.
@@ -593,13 +859,13 @@
                             int block_width4x4) {
   const int row4x4 = row4x4_start_ + (row_index * block_width4x4);
   const int column4x4 = column4x4_start_ + (column_index * block_width4x4);
-  std::unique_ptr<DecoderScratchBuffer> scratch_buffer =
-      decoder_scratch_buffer_pool_->Get();
+  std::unique_ptr<TileScratchBuffer> scratch_buffer =
+      tile_scratch_buffer_pool_->Get();
   bool ok = scratch_buffer != nullptr;
   if (ok) {
     ok = ProcessSuperBlock(row4x4, column4x4, block_width4x4,
                            scratch_buffer.get(), kProcessingModeDecodeOnly);
-    decoder_scratch_buffer_pool_->Release(std::move(scratch_buffer));
+    tile_scratch_buffer_pool_->Release(std::move(scratch_buffer));
   }
   std::unique_lock<std::mutex> lock(threading_.mutex);
   if (ok) {
@@ -647,9 +913,38 @@
   }
 }
 
-bool Tile::IsInside(int row4x4, int column4x4) const {
-  return row4x4 >= row4x4_start_ && row4x4 < row4x4_end_ &&
-         column4x4 >= column4x4_start_ && column4x4 < column4x4_end_;
+void Tile::PopulateIntraPredictionBuffer(int row4x4) {
+  const int block_width4x4 = kNum4x4BlocksWide[SuperBlockSize()];
+  if (!use_intra_prediction_buffer_ || row4x4 + block_width4x4 >= row4x4_end_) {
+    return;
+  }
+  const size_t pixel_size =
+      (sequence_header_.color_config.bitdepth == 8 ? sizeof(uint8_t)
+                                                   : sizeof(uint16_t));
+  for (int plane = 0; plane < PlaneCount(); ++plane) {
+    const int row_to_copy =
+        (MultiplyBy4(row4x4 + block_width4x4) >> subsampling_y_[plane]) - 1;
+    const size_t pixels_to_copy =
+        (MultiplyBy4(column4x4_end_ - column4x4_start_) >>
+         subsampling_x_[plane]) *
+        pixel_size;
+    const size_t column_start =
+        MultiplyBy4(column4x4_start_) >> subsampling_x_[plane];
+    void* start;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    if (sequence_header_.color_config.bitdepth > 8) {
+      Array2DView<uint16_t> buffer(
+          buffer_[plane].rows(), buffer_[plane].columns() / sizeof(uint16_t),
+          reinterpret_cast<uint16_t*>(&buffer_[plane][0][0]));
+      start = &buffer[row_to_copy][column_start];
+    } else  // NOLINT
+#endif
+    {
+      start = &buffer_[plane][row_to_copy][column_start];
+    }
+    memcpy((*intra_prediction_buffer_)[plane].get() + column_start * pixel_size,
+           start, pixels_to_copy);
+  }
 }
 
 int Tile::GetTransformAllZeroContext(const Block& block, Plane plane,
@@ -660,7 +955,7 @@
 
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
-  const BlockSize plane_size = block.residual_size[GetPlaneType(plane)];
+  const BlockSize plane_size = block.residual_size[plane];
   const int block_width = kBlockWidthPixels[plane_size];
   const int block_height = kBlockHeightPixels[plane_size];
 
@@ -785,150 +1080,167 @@
                    kTransformHeight4x4[tx_size], tx_type, transform_types_);
 }
 
-// Section 8.3.2 in the spec, under coeff_base_eob.
-int Tile::GetCoeffBaseContextEob(TransformSize tx_size, int index) {
-  if (index == 0) return 0;
-  const TransformSize adjusted_tx_size = kAdjustedTransformSize[tx_size];
-  const int tx_width_log2 = kTransformWidthLog2[adjusted_tx_size];
-  const int tx_height = kTransformHeight[adjusted_tx_size];
-  if (index <= DivideBy8(tx_height << tx_width_log2)) return 1;
-  if (index <= DivideBy4(tx_height << tx_width_log2)) return 2;
-  return 3;
-}
-
-// Section 8.3.2 in the spec, under coeff_base.
-int Tile::GetCoeffBaseContext2D(const int32_t* const quantized_buffer,
-                                TransformSize tx_size,
-                                int adjusted_tx_width_log2, uint16_t pos) {
-  if (pos == 0) return 0;
+// Section 8.3.2 in the spec, under coeff_base and coeff_br.
+// Bottom boundary checks are avoided by the padded rows.
+// For a coefficient near the right boundary, the two right neighbors and the
+// one bottom-right neighbor may be out of boundary. We don't check the right
+// boundary for them, because the out of boundary neighbors project to positions
+// above the diagonal line which goes through the current coefficient and these
+// positions are still all 0s according to the diagonal scan order.
+template <typename ResidualType>
+void Tile::ReadCoeffBase2D(
+    const uint16_t* scan, PlaneType plane_type, TransformSize tx_size,
+    int clamped_tx_size_context, int adjusted_tx_width_log2, int eob,
+    uint16_t coeff_base_cdf[kCoeffBaseContexts][kCoeffBaseSymbolCount + 1],
+    ResidualType* const quantized_buffer) {
   const int tx_width = 1 << adjusted_tx_width_log2;
-  const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  const int32_t* const quantized =
-      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
-  const int context = std::min(
-      4, DivideBy2(1 + (std::min(quantized[1], 3) +                    // {0, 1}
-                        std::min(quantized[padded_tx_width], 3) +      // {1, 0}
-                        std::min(quantized[padded_tx_width + 1], 3) +  // {1, 1}
-                        std::min(quantized[2], 3) +                    // {0, 2}
-                        std::min(quantized[MultiplyBy2(padded_tx_width)],
-                                 3))));  // {2, 0}
-  const int row = pos >> adjusted_tx_width_log2;
-  const int column = pos & (tx_width - 1);
-  return context + kCoeffBaseContextOffset[tx_size][std::min(row, 4)]
-                                          [std::min(column, 4)];
+  int i = eob - 2;
+  do {
+    constexpr auto threshold = static_cast<ResidualType>(3);
+    const uint16_t pos = scan[i];
+    const int row = pos >> adjusted_tx_width_log2;
+    const int column = pos & (tx_width - 1);
+    auto* const quantized = &quantized_buffer[pos];
+    int context;
+    if (pos == 0) {
+      context = 0;
+    } else {
+      context = std::min(
+          4, DivideBy2(
+                 1 + (std::min(quantized[1], threshold) +             // {0, 1}
+                      std::min(quantized[tx_width], threshold) +      // {1, 0}
+                      std::min(quantized[tx_width + 1], threshold) +  // {1, 1}
+                      std::min(quantized[2], threshold) +             // {0, 2}
+                      std::min(quantized[MultiplyBy2(tx_width)],
+                               threshold))));  // {2, 0}
+      context += kCoeffBaseContextOffset[tx_size][std::min(row, 4)]
+                                        [std::min(column, 4)];
+    }
+    int level =
+        reader_.ReadSymbol<kCoeffBaseSymbolCount>(coeff_base_cdf[context]);
+    if (level > kNumQuantizerBaseLevels) {
+      // No need to clip quantized values to COEFF_BASE_RANGE + NUM_BASE_LEVELS
+      // + 1, because we clip the overall output to 6 and the unclipped
+      // quantized values will always result in an output of greater than 6.
+      context = std::min(6, DivideBy2(1 + quantized[1] +          // {0, 1}
+                                      quantized[tx_width] +       // {1, 0}
+                                      quantized[tx_width + 1]));  // {1, 1}
+      if (pos != 0) {
+        context += 14 >> static_cast<int>((row | column) < 2);
+      }
+      level += ReadCoeffBaseRange(clamped_tx_size_context, context, plane_type);
+    }
+    quantized[0] = level;
+  } while (--i >= 0);
 }
 
-// Section 8.3.2 in the spec, under coeff_base.
-int Tile::GetCoeffBaseContextHorizontal(const int32_t* const quantized_buffer,
-                                        TransformSize /*tx_size*/,
-                                        int adjusted_tx_width_log2,
-                                        uint16_t pos) {
+// Section 8.3.2 in the spec, under coeff_base and coeff_br.
+// Bottom boundary checks are avoided by the padded rows.
+// For a coefficient near the right boundary, the four right neighbors may be
+// out of boundary. We don't do the boundary check for the first three right
+// neighbors, because even for the transform blocks with smallest width 4, the
+// first three out of boundary neighbors project to positions left of the
+// current coefficient and these positions are still all 0s according to the
+// column scan order. However, when transform block width is 4 and the current
+// coefficient is on the right boundary, its fourth right neighbor projects to
+// the under position on the same column, which could be nonzero. Therefore, we
+// must skip the fourth right neighbor. To make it simple, for any coefficient,
+// we always do the boundary check for its fourth right neighbor.
+template <typename ResidualType>
+void Tile::ReadCoeffBaseHorizontal(
+    const uint16_t* scan, PlaneType plane_type, TransformSize /*tx_size*/,
+    int clamped_tx_size_context, int adjusted_tx_width_log2, int eob,
+    uint16_t coeff_base_cdf[kCoeffBaseContexts][kCoeffBaseSymbolCount + 1],
+    ResidualType* const quantized_buffer) {
   const int tx_width = 1 << adjusted_tx_width_log2;
-  const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  const int32_t* const quantized =
-      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
-  const int context = std::min(
-      4, DivideBy2(1 + (std::min(quantized[1], 3) +                // {0, 1}
-                        std::min(quantized[padded_tx_width], 3) +  // {1, 0}
-                        std::min(quantized[2], 3) +                // {0, 2}
-                        std::min(quantized[3], 3) +                // {0, 3}
-                        std::min(quantized[4], 3))));              // {0, 4}
-  const int index = pos & (tx_width - 1);
-  return context + kCoeffBasePositionContextOffset[std::min(index, 2)];
+  int i = eob - 2;
+  do {
+    constexpr auto threshold = static_cast<ResidualType>(3);
+    const uint16_t pos = scan[i];
+    const int column = pos & (tx_width - 1);
+    auto* const quantized = &quantized_buffer[pos];
+    int context = std::min(
+        4,
+        DivideBy2(1 +
+                  (std::min(quantized[1], threshold) +         // {0, 1}
+                   std::min(quantized[tx_width], threshold) +  // {1, 0}
+                   std::min(quantized[2], threshold) +         // {0, 2}
+                   std::min(quantized[3], threshold) +         // {0, 3}
+                   std::min(quantized[4],
+                            static_cast<ResidualType>(
+                                (column + 4 < tx_width) ? 3 : 0)))));  // {0, 4}
+    context += kCoeffBasePositionContextOffset[column];
+    int level =
+        reader_.ReadSymbol<kCoeffBaseSymbolCount>(coeff_base_cdf[context]);
+    if (level > kNumQuantizerBaseLevels) {
+      // No need to clip quantized values to COEFF_BASE_RANGE + NUM_BASE_LEVELS
+      // + 1, because we clip the overall output to 6 and the unclipped
+      // quantized values will always result in an output of greater than 6.
+      context = std::min(6, DivideBy2(1 + quantized[1] +     // {0, 1}
+                                      quantized[tx_width] +  // {1, 0}
+                                      quantized[2]));        // {0, 2}
+      if (pos != 0) {
+        context += 14 >> static_cast<int>(column == 0);
+      }
+      level += ReadCoeffBaseRange(clamped_tx_size_context, context, plane_type);
+    }
+    quantized[0] = level;
+  } while (--i >= 0);
 }
 
-// Section 8.3.2 in the spec, under coeff_base.
-int Tile::GetCoeffBaseContextVertical(const int32_t* const quantized_buffer,
-                                      TransformSize /*tx_size*/,
-                                      int adjusted_tx_width_log2,
-                                      uint16_t pos) {
+// Section 8.3.2 in the spec, under coeff_base and coeff_br.
+// Bottom boundary checks are avoided by the padded rows.
+// Right boundary check is performed explicitly.
+template <typename ResidualType>
+void Tile::ReadCoeffBaseVertical(
+    const uint16_t* scan, PlaneType plane_type, TransformSize /*tx_size*/,
+    int clamped_tx_size_context, int adjusted_tx_width_log2, int eob,
+    uint16_t coeff_base_cdf[kCoeffBaseContexts][kCoeffBaseSymbolCount + 1],
+    ResidualType* const quantized_buffer) {
   const int tx_width = 1 << adjusted_tx_width_log2;
-  const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  const int32_t* const quantized =
-      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
-  const int context = std::min(
-      4, DivideBy2(1 + (std::min(quantized[1], 3) +                // {0, 1}
-                        std::min(quantized[padded_tx_width], 3) +  // {1, 0}
-                        std::min(quantized[MultiplyBy2(padded_tx_width)],
-                                 3) +                                  // {2, 0}
-                        std::min(quantized[padded_tx_width * 3], 3) +  // {3, 0}
-                        std::min(quantized[MultiplyBy4(padded_tx_width)],
-                                 3))));  // {4, 0}
-
-  const int index = pos >> adjusted_tx_width_log2;
-  return context + kCoeffBasePositionContextOffset[std::min(index, 2)];
-}
-
-// Section 8.3.2 in the spec, under coeff_br.
-int Tile::GetCoeffBaseRangeContext2D(const int32_t* const quantized_buffer,
-                                     int adjusted_tx_width_log2, int pos) {
-  const uint8_t tx_width = 1 << adjusted_tx_width_log2;
-  const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  const int32_t* const quantized =
-      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
-  const int context = std::min(
-      6, DivideBy2(
-             1 +
-             std::min(quantized[1],
-                      kQuantizerCoefficientBaseRangeContextClamp) +  // {0, 1}
-             std::min(quantized[padded_tx_width],
-                      kQuantizerCoefficientBaseRangeContextClamp) +  // {1, 0}
-             std::min(quantized[padded_tx_width + 1],
-                      kQuantizerCoefficientBaseRangeContextClamp)));  // {1, 1}
-  if (pos == 0) return context;
-  const int row = pos >> adjusted_tx_width_log2;
-  const int column = pos & (tx_width - 1);
-  return context + (((row | column) < 2) ? 7 : 14);
-}
-
-// Section 8.3.2 in the spec, under coeff_br.
-int Tile::GetCoeffBaseRangeContextHorizontal(
-    const int32_t* const quantized_buffer, int adjusted_tx_width_log2,
-    int pos) {
-  const uint8_t tx_width = 1 << adjusted_tx_width_log2;
-  const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  const int32_t* const quantized =
-      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
-  const int context = std::min(
-      6, DivideBy2(
-             1 +
-             std::min(quantized[1],
-                      kQuantizerCoefficientBaseRangeContextClamp) +  // {0, 1}
-             std::min(quantized[padded_tx_width],
-                      kQuantizerCoefficientBaseRangeContextClamp) +  // {1, 0}
-             std::min(quantized[2],
-                      kQuantizerCoefficientBaseRangeContextClamp)));  // {0, 2}
-  if (pos == 0) return context;
-  const int column = pos & (tx_width - 1);
-  return context + ((column == 0) ? 7 : 14);
-}
-
-// Section 8.3.2 in the spec, under coeff_br.
-int Tile::GetCoeffBaseRangeContextVertical(
-    const int32_t* const quantized_buffer, int adjusted_tx_width_log2,
-    int pos) {
-  const uint8_t tx_width = 1 << adjusted_tx_width_log2;
-  const int padded_tx_width = tx_width + kQuantizedCoefficientBufferPadding;
-  const int32_t* const quantized =
-      &quantized_buffer[PaddedIndex(pos, adjusted_tx_width_log2)];
-  const int context = std::min(
-      6, DivideBy2(
-             1 +
-             std::min(quantized[1],
-                      kQuantizerCoefficientBaseRangeContextClamp) +  // {0, 1}
-             std::min(quantized[padded_tx_width],
-                      kQuantizerCoefficientBaseRangeContextClamp) +  // {1, 0}
-             std::min(quantized[MultiplyBy2(padded_tx_width)],
-                      kQuantizerCoefficientBaseRangeContextClamp)));  // {2, 0}
-  if (pos == 0) return context;
-  const int row = pos >> adjusted_tx_width_log2;
-  return context + ((row == 0) ? 7 : 14);
+  int i = eob - 2;
+  do {
+    constexpr auto threshold = static_cast<ResidualType>(3);
+    const uint16_t pos = scan[i];
+    const int row = pos >> adjusted_tx_width_log2;
+    const int column = pos & (tx_width - 1);
+    auto* const quantized = &quantized_buffer[pos];
+    const int quantized_column1 = (column + 1 < tx_width) ? quantized[1] : 0;
+    int context =
+        std::min(4, DivideBy2(1 + (std::min(quantized_column1, 3) +  // {0, 1}
+                                   std::min(quantized[tx_width],
+                                            threshold) +  // {1, 0}
+                                   std::min(quantized[MultiplyBy2(tx_width)],
+                                            threshold) +  // {2, 0}
+                                   std::min(quantized[tx_width * 3],
+                                            threshold) +  // {3, 0}
+                                   std::min(quantized[MultiplyBy4(tx_width)],
+                                            threshold))));  // {4, 0}
+    context += kCoeffBasePositionContextOffset[row];
+    int level =
+        reader_.ReadSymbol<kCoeffBaseSymbolCount>(coeff_base_cdf[context]);
+    if (level > kNumQuantizerBaseLevels) {
+      // No need to clip quantized values to COEFF_BASE_RANGE + NUM_BASE_LEVELS
+      // + 1, because we clip the overall output to 6 and the unclipped
+      // quantized values will always result in an output of greater than 6.
+      int context =
+          std::min(6, DivideBy2(1 + quantized_column1 +              // {0, 1}
+                                quantized[tx_width] +                // {1, 0}
+                                quantized[MultiplyBy2(tx_width)]));  // {2, 0}
+      if (pos != 0) {
+        context += 14 >> static_cast<int>(row == 0);
+      }
+      level += ReadCoeffBaseRange(clamped_tx_size_context, context, plane_type);
+    }
+    quantized[0] = level;
+  } while (--i >= 0);
 }
 
 int Tile::GetDcSignContext(int x4, int y4, int w4, int h4, Plane plane) {
   const int max_x4x4 = frame_header_.columns4x4 >> subsampling_x_[plane];
   const int8_t* dc_categories = &dc_categories_[kEntropyContextTop][plane][x4];
-  int dc_sign = std::accumulate(
+  // Set dc_sign to 8-bit long so that std::accumulate() saves sign extension.
+  int8_t dc_sign = std::accumulate(
       dc_categories, dc_categories + GetNumElements(w4, x4, max_x4x4), 0);
   const int max_y4x4 = frame_header_.rows4x4 >> subsampling_y_[plane];
   dc_categories = &dc_categories_[kEntropyContextLeft][plane][y4];
@@ -938,6 +1250,8 @@
   //   if (dc_sign < 0) return 1;
   //   if (dc_sign > 0) return 2;
   //   return 0;
+  // And it is better than:
+  //   return static_cast<int>(dc_sign != 0) + static_cast<int>(dc_sign > 0);
   return static_cast<int>(dc_sign < 0) +
          MultiplyBy2(static_cast<int>(dc_sign > 0));
 }
@@ -1020,23 +1334,21 @@
   }
 }
 
-template <bool is_dc_coefficient>
+template <typename ResidualType, bool is_dc_coefficient>
 bool Tile::ReadSignAndApplyDequantization(
-    const Block& block, int32_t* const quantized_buffer,
-    const uint16_t* const scan, int i, int adjusted_tx_width_log2, int tx_width,
-    int q_value, const uint8_t* const quantizer_matrix, int shift,
-    int min_value, int max_value, uint16_t* const dc_sign_cdf,
-    int8_t* const dc_category, int* const coefficient_level) {
-  int pos = is_dc_coefficient ? 0 : scan[i];
-  const int pos_index =
-      is_dc_coefficient ? 0 : PaddedIndex(pos, adjusted_tx_width_log2);
-  // If quantized_buffer[pos_index] is zero, then the rest of the function has
-  // no effect.
-  if (quantized_buffer[pos_index] == 0) return true;
-  const bool sign = is_dc_coefficient ? reader_.ReadSymbol(dc_sign_cdf)
-                                      : static_cast<bool>(reader_.ReadBit());
-  if (quantized_buffer[pos_index] >
-      kNumQuantizerBaseLevels + kQuantizerCoefficientBaseRange) {
+    const uint16_t* const scan, int i, int q_value,
+    const uint8_t* const quantizer_matrix, int shift, int max_value,
+    uint16_t* const dc_sign_cdf, int8_t* const dc_category,
+    int* const coefficient_level, ResidualType* residual_buffer) {
+  const int pos = is_dc_coefficient ? 0 : scan[i];
+  // If residual_buffer[pos] is zero, then the rest of the function has no
+  // effect.
+  int level = residual_buffer[pos];
+  if (level == 0) return true;
+  const int sign = is_dc_coefficient
+                       ? static_cast<int>(reader_.ReadSymbol(dc_sign_cdf))
+                       : reader_.ReadBit();
+  if (level > kNumQuantizerBaseLevels + kQuantizerCoefficientBaseRange) {
     int length = 0;
     bool golomb_length_bit = false;
     do {
@@ -1051,13 +1363,13 @@
     for (int i = length - 2; i >= 0; --i) {
       x = (x << 1) | reader_.ReadBit();
     }
-    quantized_buffer[pos_index] += x - 1;
+    level += x - 1;
   }
-  if (is_dc_coefficient && quantized_buffer[0] > 0) {
-    *dc_category = sign ? -1 : 1;
+  if (is_dc_coefficient) {
+    *dc_category = (sign != 0) ? -1 : 1;
   }
-  quantized_buffer[pos_index] &= 0xfffff;
-  *coefficient_level += quantized_buffer[pos_index];
+  level &= 0xfffff;
+  *coefficient_level += level;
   // Apply dequantization. Step 1 of section 7.12.3 in the spec.
   int q = q_value;
   if (quantizer_matrix != nullptr) {
@@ -1065,34 +1377,21 @@
   }
   // The intermediate multiplication can exceed 32 bits, so it has to be
   // performed by promoting one of the values to int64_t.
-  int32_t dequantized_value =
-      (static_cast<int64_t>(q) * quantized_buffer[pos_index]) & 0xffffff;
+  int32_t dequantized_value = (static_cast<int64_t>(q) * level) & 0xffffff;
   dequantized_value >>= shift;
-  if (sign) {
-    dequantized_value = -dequantized_value;
-  }
-  // Inverse transform process assumes that the quantized coefficients are
-  // stored as a virtual 2d array of size |tx_width| x |tx_height|. If
-  // transform width is 64, then this assumption is broken because the scan
-  // order used for populating the coefficients for such transforms is the
-  // same as the one used for corresponding transform with width 32 (e.g. the
-  // scan order used for 64x16 is the same as the one used for 32x16). So we
-  // have to recompute the value of pos so that it reflects the index of the
-  // 2d array of size 64 x |tx_height|.
-  if (!is_dc_coefficient && tx_width == 64) {
-    const int row_index = DivideBy32(pos);
-    const int column_index = Mod32(pos);
-    pos = MultiplyBy64(row_index) + column_index;
-  }
-  if (sequence_header_.color_config.bitdepth == 8) {
-    auto* const residual_buffer = reinterpret_cast<int16_t*>(*block.residual);
-    residual_buffer[pos] = Clip3(dequantized_value, min_value, max_value);
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  } else {
-    auto* const residual_buffer = reinterpret_cast<int32_t*>(*block.residual);
-    residual_buffer[pos] = Clip3(dequantized_value, min_value, max_value);
-#endif
-  }
+  // At this point:
+  //   * |dequantized_value| is always non-negative.
+  //   * |sign| can be either 0 or 1.
+  //   * min_value = -(max_value + 1).
+  // We need to apply the following:
+  // dequantized_value = sign ? -dequantized_value : dequantized_value;
+  // dequantized_value = Clip3(dequantized_value, min_value, max_value);
+  //
+  // Note that -x == ~(x - 1).
+  //
+  // Now, The above two lines can be done with a std::min and xor as follows:
+  dequantized_value = std::min(dequantized_value - sign, max_value) ^ -sign;
+  residual_buffer[pos] = dequantized_value;
   return true;
 }
 
@@ -1109,10 +1408,11 @@
   return level;
 }
 
-int16_t Tile::ReadTransformCoefficients(const Block& block, Plane plane,
-                                        int start_x, int start_y,
-                                        TransformSize tx_size,
-                                        TransformType* const tx_type) {
+template <typename ResidualType>
+int Tile::ReadTransformCoefficients(const Block& block, Plane plane,
+                                    int start_x, int start_y,
+                                    TransformSize tx_size,
+                                    TransformType* const tx_type) {
   const int x4 = DivideBy4(start_x);
   const int y4 = DivideBy4(start_y);
   const int w4 = kTransformWidth4x4[tx_size];
@@ -1134,19 +1434,15 @@
   }
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
-  memset(*block.residual, 0, tx_width * tx_height * residual_size_);
-  const int clamped_tx_width = std::min(tx_width, 32);
+  const TransformSize adjusted_tx_size = kAdjustedTransformSize[tx_size];
+  const int adjusted_tx_width_log2 = kTransformWidthLog2[adjusted_tx_size];
+  const int tx_padding =
+      (1 << adjusted_tx_width_log2) * kResidualPaddingVertical;
+  auto* residual = reinterpret_cast<ResidualType*>(*block.residual);
+  // Clear padding to avoid bottom boundary checks when parsing quantized
+  // coefficients.
+  memset(residual, 0, (tx_width * tx_height + tx_padding) * residual_size_);
   const int clamped_tx_height = std::min(tx_height, 32);
-  const int padded_tx_width =
-      clamped_tx_width + kQuantizedCoefficientBufferPadding;
-  const int padded_tx_height =
-      clamped_tx_height + kQuantizedCoefficientBufferPadding;
-  int32_t* const quantized = block.scratch_buffer->quantized_buffer;
-  // Only the first |padded_tx_width| * |padded_tx_height| values of |quantized|
-  // will be used by this function and the functions to which it is passed into.
-  // So we simply need to zero out those values before it is being used.
-  memset(quantized, 0,
-         padded_tx_width * padded_tx_height * sizeof(quantized[0]));
   if (plane == kPlaneY) {
     ReadTransformType(block, x4, y4, tx_size);
   }
@@ -1181,9 +1477,9 @@
       cdf = symbol_decoder_context_.eob_pt_1024_cdf[plane_type];
       break;
   }
-  const int16_t eob_pt =
+  const int eob_pt =
       1 + reader_.ReadSymbol(cdf, kEobPt16SymbolCount + eob_multi_size);
-  int16_t eob = (eob_pt < 2) ? eob_pt : ((1 << (eob_pt - 2)) + 1);
+  int eob = (eob_pt < 2) ? eob_pt : ((1 << (eob_pt - 2)) + 1);
   if (eob_pt >= 3) {
     context = eob_pt - 3;
     const bool eob_extra = reader_.ReadSymbol(
@@ -1199,23 +1495,6 @@
     }
   }
   const uint16_t* scan = kScan[tx_class][tx_size];
-  const TransformSize adjusted_tx_size = kAdjustedTransformSize[tx_size];
-  const int adjusted_tx_width_log2 = kTransformWidthLog2[adjusted_tx_size];
-  // Lookup used to call the right variant of GetCoeffBaseContext*() based on
-  // the transform class.
-  static constexpr int (Tile::*kGetCoeffBaseContextFunc[])(
-      const int32_t*, TransformSize, int, uint16_t) = {
-      &Tile::GetCoeffBaseContext2D, &Tile::GetCoeffBaseContextHorizontal,
-      &Tile::GetCoeffBaseContextVertical};
-  auto get_coeff_base_context_func = kGetCoeffBaseContextFunc[tx_class];
-  // Lookup used to call the right variant of GetCoeffBaseRangeContext*() based
-  // on the transform class.
-  static constexpr int (Tile::*kGetCoeffBaseRangeContextFunc[])(
-      const int32_t*, int, int) = {&Tile::GetCoeffBaseRangeContext2D,
-                                   &Tile::GetCoeffBaseRangeContextHorizontal,
-                                   &Tile::GetCoeffBaseRangeContextVertical};
-  auto get_coeff_base_range_context_func =
-      kGetCoeffBaseRangeContextFunc[tx_class];
   const int clamped_tx_size_context = std::min(tx_size_context, 3);
   // Read the last coefficient.
   {
@@ -1227,36 +1506,37 @@
                     .coeff_base_eob_cdf[tx_size_context][plane_type][context],
                 kCoeffBaseEobSymbolCount);
     if (level > kNumQuantizerBaseLevels) {
-      level += ReadCoeffBaseRange(clamped_tx_size_context,
-                                  (this->*get_coeff_base_range_context_func)(
-                                      quantized, adjusted_tx_width_log2, pos),
-                                  plane_type);
+      level += ReadCoeffBaseRange(
+          clamped_tx_size_context,
+          GetCoeffBaseRangeContextEob(adjusted_tx_width_log2, pos, tx_class),
+          plane_type);
     }
-    quantized[PaddedIndex(pos, adjusted_tx_width_log2)] = level;
+    residual[pos] = level;
   }
-  // Read all the other coefficients.
-  for (int i = eob - 2; i >= 0; --i) {
-    const uint16_t pos = scan[i];
-    context = (this->*get_coeff_base_context_func)(quantized, tx_size,
-                                                   adjusted_tx_width_log2, pos);
-    int level = reader_.ReadSymbol<kCoeffBaseSymbolCount>(
-        symbol_decoder_context_
-            .coeff_base_cdf[tx_size_context][plane_type][context]);
-    if (level > kNumQuantizerBaseLevels) {
-      level += ReadCoeffBaseRange(clamped_tx_size_context,
-                                  (this->*get_coeff_base_range_context_func)(
-                                      quantized, adjusted_tx_width_log2, pos),
-                                  plane_type);
-    }
-    quantized[PaddedIndex(pos, adjusted_tx_width_log2)] = level;
+  if (eob > 1) {
+    // Read all the other coefficients.
+    // Lookup used to call the right variant of ReadCoeffBase*() based on the
+    // transform class.
+    static constexpr void (Tile::*kGetCoeffBaseFunc[])(
+        const uint16_t* scan, PlaneType plane_type, TransformSize tx_size,
+        int clamped_tx_size_context, int adjusted_tx_width_log2, int eob,
+        uint16_t coeff_base_cdf[kCoeffBaseContexts][kCoeffBaseSymbolCount + 1],
+        ResidualType* quantized_buffer) = {
+        &Tile::ReadCoeffBase2D<ResidualType>,
+        &Tile::ReadCoeffBaseHorizontal<ResidualType>,
+        &Tile::ReadCoeffBaseVertical<ResidualType>};
+    (this->*kGetCoeffBaseFunc[tx_class])(
+        scan, plane_type, tx_size, clamped_tx_size_context,
+        adjusted_tx_width_log2, eob,
+        symbol_decoder_context_.coeff_base_cdf[tx_size_context][plane_type],
+        residual);
   }
-  const int min_value = -(1 << (7 + sequence_header_.color_config.bitdepth));
   const int max_value = (1 << (7 + sequence_header_.color_config.bitdepth)) - 1;
   const int current_quantizer_index = GetQIndex(
       frame_header_.segmentation, bp.segment_id, current_quantizer_index_);
   const int dc_q_value = quantizer_.GetDcValue(plane, current_quantizer_index);
   const int ac_q_value = quantizer_.GetAcValue(plane, current_quantizer_index);
-  const int shift = GetQuantizationShift(tx_size);
+  const int shift = kQuantizationShift[tx_size];
   const uint8_t* const quantizer_matrix =
       (frame_header_.quantizer.use_matrix &&
        *tx_type < kTransformTypeIdentityIdentity &&
@@ -1268,24 +1548,27 @@
   int coefficient_level = 0;
   int8_t dc_category = 0;
   uint16_t* const dc_sign_cdf =
-      (quantized[0] != 0)
+      (residual[0] != 0)
           ? symbol_decoder_context_.dc_sign_cdf[plane_type][GetDcSignContext(
                 x4, y4, w4, h4, plane)]
           : nullptr;
   assert(scan[0] == 0);
-  if (!ReadSignAndApplyDequantization</*is_dc_coefficient=*/true>(
-          block, quantized, scan, 0, adjusted_tx_width_log2, tx_width,
-          dc_q_value, quantizer_matrix, shift, min_value, max_value,
-          dc_sign_cdf, &dc_category, &coefficient_level)) {
+  if (!ReadSignAndApplyDequantization<ResidualType, /*is_dc_coefficient=*/true>(
+          scan, 0, dc_q_value, quantizer_matrix, shift, max_value, dc_sign_cdf,
+          &dc_category, &coefficient_level, residual)) {
     return -1;
   }
-  for (int i = 1; i < eob; ++i) {
-    if (!ReadSignAndApplyDequantization</*is_dc_coefficient=*/false>(
-            block, quantized, scan, i, adjusted_tx_width_log2, tx_width,
-            ac_q_value, quantizer_matrix, shift, min_value, max_value, nullptr,
-            nullptr, &coefficient_level)) {
-      return -1;
-    }
+  if (eob > 1) {
+    int i = 1;
+    do {
+      if (!ReadSignAndApplyDequantization<ResidualType,
+                                          /*is_dc_coefficient=*/false>(
+              scan, i, ac_q_value, quantizer_matrix, shift, max_value, nullptr,
+              nullptr, &coefficient_level, residual)) {
+        return -1;
+      }
+    } while (++i < eob);
+    MoveCoefficientsForTxWidth64(clamped_tx_height, tx_width, residual);
   }
   SetEntropyContexts(x4, y4, w4, h4, plane, std::min(4, coefficient_level),
                      dc_category);
@@ -1295,6 +1578,25 @@
   return eob;
 }
 
+// CALL_BITDEPTH_FUNCTION is a macro that calls the appropriate template
+// |function| depending on the value of |sequence_header_.color_config.bitdepth|
+// with the variadic arguments.
+#if LIBGAV1_MAX_BITDEPTH >= 10
+#define CALL_BITDEPTH_FUNCTION(function, ...)         \
+  do {                                                \
+    if (sequence_header_.color_config.bitdepth > 8) { \
+      function<uint16_t>(__VA_ARGS__);                \
+    } else {                                          \
+      function<uint8_t>(__VA_ARGS__);                 \
+    }                                                 \
+  } while (false)
+#else
+#define CALL_BITDEPTH_FUNCTION(function, ...) \
+  do {                                        \
+    function<uint8_t>(__VA_ARGS__);           \
+  } while (false)
+#endif
+
 bool Tile::TransformBlock(const Block& block, Plane plane, int base_x,
                           int base_y, TransformSize tx_size, int x, int y,
                           ProcessingMode mode) {
@@ -1317,15 +1619,8 @@
                          mode == kProcessingModeParseAndDecode;
   if (do_decode && !bp.is_inter) {
     if (bp.palette_mode_info.size[GetPlaneType(plane)] > 0) {
-      if (sequence_header_.color_config.bitdepth == 8) {
-        PalettePrediction<uint8_t>(block, plane, start_x, start_y, x, y,
-                                   tx_size);
-#if LIBGAV1_MAX_BITDEPTH >= 10
-      } else {
-        PalettePrediction<uint16_t>(block, plane, start_x, start_y, x, y,
-                                    tx_size);
-#endif
-      }
+      CALL_BITDEPTH_FUNCTION(PalettePrediction, block, plane, start_x, start_y,
+                             x, y, tx_size);
     } else {
       const PredictionMode mode =
           (plane == kPlaneY)
@@ -1337,37 +1632,17 @@
           (sub_block_column4x4 >> subsampling_x) + step_x + 1;
       const int bl_row4x4 = (sub_block_row4x4 >> subsampling_y) + step_y + 1;
       const int bl_column4x4 = (sub_block_column4x4 >> subsampling_x);
-      const bool has_left =
-          x > 0 || (plane == kPlaneY ? block.left_available
-                                     : block.LeftAvailableChroma());
-      const bool has_top =
-          y > 0 ||
-          (plane == kPlaneY ? block.top_available : block.TopAvailableChroma());
-      if (sequence_header_.color_config.bitdepth == 8) {
-        IntraPrediction<uint8_t>(
-            block, plane, start_x, start_y, has_left, has_top,
-            block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
-            block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
-            mode, tx_size);
-#if LIBGAV1_MAX_BITDEPTH >= 10
-      } else {
-        IntraPrediction<uint16_t>(
-            block, plane, start_x, start_y, has_left, has_top,
-            block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
-            block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
-            mode, tx_size);
-#endif
-      }
+      const bool has_left = x > 0 || block.left_available[plane];
+      const bool has_top = y > 0 || block.top_available[plane];
+
+      CALL_BITDEPTH_FUNCTION(
+          IntraPrediction, block, plane, start_x, start_y, has_left, has_top,
+          block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+          block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
+          mode, tx_size);
       if (plane != kPlaneY && bp.uv_mode == kPredictionModeChromaFromLuma) {
-        if (sequence_header_.color_config.bitdepth == 8) {
-          ChromaFromLumaPrediction<uint8_t>(block, plane, start_x, start_y,
-                                            tx_size);
-#if LIBGAV1_MAX_BITDEPTH >= 10
-        } else {
-          ChromaFromLumaPrediction<uint16_t>(block, plane, start_x, start_y,
-                                             tx_size);
-#endif
-        }
+        CALL_BITDEPTH_FUNCTION(ChromaFromLumaPrediction, block, plane, start_x,
+                               start_y, tx_size);
       }
     }
     if (plane == kPlaneY) {
@@ -1381,34 +1656,35 @@
   if (!bp.skip) {
     const int sb_row_index = SuperBlockRowIndex(block.row4x4);
     const int sb_column_index = SuperBlockColumnIndex(block.column4x4);
-    switch (mode) {
-      case kProcessingModeParseAndDecode: {
-        TransformType tx_type;
-        const int16_t non_zero_coeff_count = ReadTransformCoefficients(
+    if (mode == kProcessingModeDecodeOnly) {
+      TransformParameterQueue& tx_params =
+          *residual_buffer_threaded_[sb_row_index][sb_column_index]
+               ->transform_parameters();
+      ReconstructBlock(block, plane, start_x, start_y, tx_size,
+                       tx_params.Type(), tx_params.NonZeroCoeffCount());
+      tx_params.Pop();
+    } else {
+      TransformType tx_type;
+      int non_zero_coeff_count;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+      if (sequence_header_.color_config.bitdepth > 8) {
+        non_zero_coeff_count = ReadTransformCoefficients<int32_t>(
             block, plane, start_x, start_y, tx_size, &tx_type);
-        if (non_zero_coeff_count < 0) return false;
+      } else  // NOLINT
+#endif
+      {
+        non_zero_coeff_count = ReadTransformCoefficients<int16_t>(
+            block, plane, start_x, start_y, tx_size, &tx_type);
+      }
+      if (non_zero_coeff_count < 0) return false;
+      if (mode == kProcessingModeParseAndDecode) {
         ReconstructBlock(block, plane, start_x, start_y, tx_size, tx_type,
                          non_zero_coeff_count);
-        break;
-      }
-      case kProcessingModeParseOnly: {
-        TransformType tx_type;
-        const int16_t non_zero_coeff_count = ReadTransformCoefficients(
-            block, plane, start_x, start_y, tx_size, &tx_type);
-        if (non_zero_coeff_count < 0) return false;
+      } else {
+        assert(mode == kProcessingModeParseOnly);
         residual_buffer_threaded_[sb_row_index][sb_column_index]
             ->transform_parameters()
             ->Push(non_zero_coeff_count, tx_type);
-        break;
-      }
-      case kProcessingModeDecodeOnly: {
-        TransformParameterQueue& tx_params =
-            *residual_buffer_threaded_[sb_row_index][sb_column_index]
-                 ->transform_parameters();
-        ReconstructBlock(block, plane, start_x, start_y, tx_size,
-                         tx_params.Type(), tx_params.NonZeroCoeffCount());
-        tx_params.Pop();
-        break;
       }
     }
   }
@@ -1417,11 +1693,8 @@
         &block.scratch_buffer
              ->block_decoded[plane][(sub_block_row4x4 >> subsampling_y) + 1]
                             [(sub_block_column4x4 >> subsampling_x) + 1];
-    for (int i = 0; i < step_y; ++i) {
-      static_assert(sizeof(bool) == 1, "");
-      memset(block_decoded, 1, step_x);
-      block_decoded += DecoderScratchBuffer::kBlockDecodedStride;
-    }
+    SetBlockValues<bool>(step_y, step_x, true, block_decoded,
+                         TileScratchBuffer::kBlockDecodedStride);
   }
   return true;
 }
@@ -1437,7 +1710,7 @@
   stack.Push(TransformTreeNode(start_x, start_y,
                                static_cast<TransformSize>(plane_size)));
 
-  while (!stack.Empty()) {
+  do {
     TransformTreeNode node = stack.Pop();
     const int row = DivideBy4(node.y);
     const int column = DivideBy4(node.x);
@@ -1479,24 +1752,18 @@
     stack.Push(TransformTreeNode(node.x, node.y + half_height, split_tx_size));
     stack.Push(TransformTreeNode(node.x + half_width, node.y, split_tx_size));
     stack.Push(TransformTreeNode(node.x, node.y, split_tx_size));
-  }
+  } while (!stack.Empty());
   return true;
 }
 
 void Tile::ReconstructBlock(const Block& block, Plane plane, int start_x,
                             int start_y, TransformSize tx_size,
-                            TransformType tx_type,
-                            int16_t non_zero_coeff_count) {
+                            TransformType tx_type, int non_zero_coeff_count) {
+  // Reconstruction process. Steps 2 and 3 of Section 7.12.3 in the spec.
   assert(non_zero_coeff_count >= 0);
   if (non_zero_coeff_count == 0) return;
-  // Reconstruction process. Steps 2 and 3 of Section 7.12.3 in the spec.
-  if (sequence_header_.color_config.bitdepth == 8) {
-    Reconstruct(dsp_, tx_type, tx_size,
-                frame_header_.segmentation.lossless[block.bp->segment_id],
-                reinterpret_cast<int16_t*>(*block.residual), start_x, start_y,
-                &buffer_[plane], non_zero_coeff_count);
 #if LIBGAV1_MAX_BITDEPTH >= 10
-  } else {
+  if (sequence_header_.color_config.bitdepth > 8) {
     Array2DView<uint16_t> buffer(
         buffer_[plane].rows(), buffer_[plane].columns() / sizeof(uint16_t),
         reinterpret_cast<uint16_t*>(&buffer_[plane][0][0]));
@@ -1504,7 +1771,13 @@
                 frame_header_.segmentation.lossless[block.bp->segment_id],
                 reinterpret_cast<int32_t*>(*block.residual), start_x, start_y,
                 &buffer, non_zero_coeff_count);
+  } else  // NOLINT
 #endif
+  {
+    Reconstruct(dsp_, tx_type, tx_size,
+                frame_header_.segmentation.lossless[block.bp->segment_id],
+                reinterpret_cast<int16_t*>(*block.residual), start_x, start_y,
+                &buffer_[plane], non_zero_coeff_count);
   }
   if (split_parse_and_decode_) {
     *block.residual +=
@@ -1513,8 +1786,8 @@
 }
 
 bool Tile::Residual(const Block& block, ProcessingMode mode) {
-  const int width_chunks = std::max(1, kBlockWidthPixels[block.size] >> 6);
-  const int height_chunks = std::max(1, kBlockHeightPixels[block.size] >> 6);
+  const int width_chunks = std::max(1, block.width >> 6);
+  const int height_chunks = std::max(1, block.height >> 6);
   const BlockSize size_chunk4x4 =
       (width_chunks > 1 || height_chunks > 1) ? kBlock64x64 : block.size;
   const BlockParameters& bp = *block.bp;
@@ -1574,7 +1847,7 @@
 bool Tile::IsMvValid(const Block& block, bool is_compound) const {
   const BlockParameters& bp = *block.bp;
   for (int i = 0; i < 1 + static_cast<int>(is_compound); ++i) {
-    for (int mv_component : bp.mv[i].mv) {
+    for (int mv_component : bp.mv.mv[i].mv) {
       if (std::abs(mv_component) >= (1 << 14)) {
         return false;
       }
@@ -1583,22 +1856,20 @@
   if (!block.bp->prediction_parameters->use_intra_block_copy) {
     return true;
   }
-  const int block_width = kBlockWidthPixels[block.size];
-  const int block_height = kBlockHeightPixels[block.size];
-  if ((bp.mv[0].mv[0] & 7) != 0 || (bp.mv[0].mv[1] & 7) != 0) {
+  if ((bp.mv.mv[0].mv32 & 0x00070007) != 0) {
     return false;
   }
-  const int delta_row = bp.mv[0].mv[0] >> 3;
-  const int delta_column = bp.mv[0].mv[1] >> 3;
+  const int delta_row = bp.mv.mv[0].mv[0] >> 3;
+  const int delta_column = bp.mv.mv[0].mv[1] >> 3;
   int src_top_edge = MultiplyBy4(block.row4x4) + delta_row;
   int src_left_edge = MultiplyBy4(block.column4x4) + delta_column;
-  const int src_bottom_edge = src_top_edge + block_height;
-  const int src_right_edge = src_left_edge + block_width;
+  const int src_bottom_edge = src_top_edge + block.height;
+  const int src_right_edge = src_left_edge + block.width;
   if (block.HasChroma()) {
-    if (block_width < 8 && subsampling_x_[kPlaneU] != 0) {
+    if (block.width < 8 && subsampling_x_[kPlaneU] != 0) {
       src_left_edge -= 4;
     }
-    if (block_height < 8 && subsampling_y_[kPlaneU] != 0) {
+    if (block.height < 8 && subsampling_y_[kPlaneU] != 0) {
       src_top_edge -= 4;
     }
   }
@@ -1636,58 +1907,102 @@
                                       wavefront_offset;
 }
 
-bool Tile::AssignMv(const Block& block, bool is_compound) {
-  MotionVector predicted_mv[2] = {};
+bool Tile::AssignInterMv(const Block& block, bool is_compound) {
+  int min[2];
+  int max[2];
+  GetClampParameters(block, min, max);
   BlockParameters& bp = *block.bp;
-  for (int i = 0; i < 1 + static_cast<int>(is_compound); ++i) {
-    const PredictionParameters& prediction_parameters =
-        *block.bp->prediction_parameters;
-    const PredictionMode mode = prediction_parameters.use_intra_block_copy
-                                    ? kPredictionModeNewMv
-                                    : GetSinglePredictionMode(i, bp.y_mode);
-    if (prediction_parameters.use_intra_block_copy) {
-      predicted_mv[0] = prediction_parameters.ref_mv_stack[0].mv[0];
-      if (predicted_mv[0].mv[0] == 0 && predicted_mv[0].mv[1] == 0) {
-        predicted_mv[0] = prediction_parameters.ref_mv_stack[1].mv[0];
-      }
-      if (predicted_mv[0].mv[0] == 0 && predicted_mv[0].mv[1] == 0) {
-        const int super_block_size4x4 = kNum4x4BlocksHigh[SuperBlockSize()];
-        if (block.row4x4 - super_block_size4x4 < row4x4_start_) {
-          predicted_mv[0].mv[1] = -MultiplyBy8(
-              MultiplyBy4(super_block_size4x4) + kIntraBlockCopyDelayPixels);
-        } else {
-          predicted_mv[0].mv[0] = -MultiplyBy32(super_block_size4x4);
+  const PredictionParameters& prediction_parameters = *bp.prediction_parameters;
+  if (is_compound) {
+    for (int i = 0; i < 2; ++i) {
+      const PredictionMode mode = GetSinglePredictionMode(i, bp.y_mode);
+      MotionVector predicted_mv;
+      if (mode == kPredictionModeGlobalMv) {
+        predicted_mv = prediction_parameters.global_mv[i];
+      } else {
+        const int ref_mv_index = (mode == kPredictionModeNearestMv ||
+                                  (mode == kPredictionModeNewMv &&
+                                   prediction_parameters.ref_mv_count <= 1))
+                                     ? 0
+                                     : prediction_parameters.ref_mv_index;
+        predicted_mv = prediction_parameters.reference_mv(ref_mv_index, i);
+        if (ref_mv_index < prediction_parameters.ref_mv_count) {
+          predicted_mv.mv[0] = Clip3(predicted_mv.mv[0], min[0], max[0]);
+          predicted_mv.mv[1] = Clip3(predicted_mv.mv[1], min[1], max[1]);
         }
       }
-    } else if (mode == kPredictionModeGlobalMv) {
-      predicted_mv[i] = prediction_parameters.global_mv[i];
+      if (mode == kPredictionModeNewMv) {
+        ReadMotionVector(block, i);
+        bp.mv.mv[i].mv[0] += predicted_mv.mv[0];
+        bp.mv.mv[i].mv[1] += predicted_mv.mv[1];
+      } else {
+        bp.mv.mv[i] = predicted_mv;
+      }
+    }
+  } else {
+    const PredictionMode mode = GetSinglePredictionMode(0, bp.y_mode);
+    MotionVector predicted_mv;
+    if (mode == kPredictionModeGlobalMv) {
+      predicted_mv = prediction_parameters.global_mv[0];
     } else {
       const int ref_mv_index = (mode == kPredictionModeNearestMv ||
                                 (mode == kPredictionModeNewMv &&
                                  prediction_parameters.ref_mv_count <= 1))
                                    ? 0
                                    : prediction_parameters.ref_mv_index;
-      predicted_mv[i] = prediction_parameters.ref_mv_stack[ref_mv_index].mv[i];
+      predicted_mv = prediction_parameters.reference_mv(ref_mv_index);
+      if (ref_mv_index < prediction_parameters.ref_mv_count) {
+        predicted_mv.mv[0] = Clip3(predicted_mv.mv[0], min[0], max[0]);
+        predicted_mv.mv[1] = Clip3(predicted_mv.mv[1], min[1], max[1]);
+      }
     }
     if (mode == kPredictionModeNewMv) {
-      ReadMotionVector(block, i);
-      bp.mv[i].mv[0] += predicted_mv[i].mv[0];
-      bp.mv[i].mv[1] += predicted_mv[i].mv[1];
+      ReadMotionVector(block, 0);
+      bp.mv.mv[0].mv[0] += predicted_mv.mv[0];
+      bp.mv.mv[0].mv[1] += predicted_mv.mv[1];
     } else {
-      bp.mv[i] = predicted_mv[i];
+      bp.mv.mv[0] = predicted_mv;
     }
   }
   return IsMvValid(block, is_compound);
 }
 
+bool Tile::AssignIntraMv(const Block& block) {
+  // TODO(linfengz): Check if the clamping process is necessary.
+  int min[2];
+  int max[2];
+  GetClampParameters(block, min, max);
+  BlockParameters& bp = *block.bp;
+  const PredictionParameters& prediction_parameters = *bp.prediction_parameters;
+  const MotionVector& ref_mv_0 = prediction_parameters.reference_mv(0);
+  ReadMotionVector(block, 0);
+  if (ref_mv_0.mv32 == 0) {
+    const MotionVector& ref_mv_1 = prediction_parameters.reference_mv(1);
+    if (ref_mv_1.mv32 == 0) {
+      const int super_block_size4x4 = kNum4x4BlocksHigh[SuperBlockSize()];
+      if (block.row4x4 - super_block_size4x4 < row4x4_start_) {
+        bp.mv.mv[0].mv[1] -= MultiplyBy32(super_block_size4x4);
+        bp.mv.mv[0].mv[1] -= MultiplyBy8(kIntraBlockCopyDelayPixels);
+      } else {
+        bp.mv.mv[0].mv[0] -= MultiplyBy32(super_block_size4x4);
+      }
+    } else {
+      bp.mv.mv[0].mv[0] += Clip3(ref_mv_1.mv[0], min[0], max[0]);
+      bp.mv.mv[0].mv[1] += Clip3(ref_mv_1.mv[1], min[0], max[0]);
+    }
+  } else {
+    bp.mv.mv[0].mv[0] += Clip3(ref_mv_0.mv[0], min[0], max[0]);
+    bp.mv.mv[0].mv[1] += Clip3(ref_mv_0.mv[1], min[1], max[1]);
+  }
+  return IsMvValid(block, /*is_compound=*/false);
+}
+
 void Tile::ResetEntropyContext(const Block& block) {
-  const int block_width4x4 = kNum4x4BlocksWide[block.size];
-  const int block_height4x4 = kNum4x4BlocksHigh[block.size];
   for (int plane = 0; plane < (block.HasChroma() ? PlaneCount() : 1); ++plane) {
     const int subsampling_x = subsampling_x_[plane];
     const int start_x = block.column4x4 >> subsampling_x;
     const int end_x =
-        std::min((block.column4x4 + block_width4x4) >> subsampling_x,
+        std::min((block.column4x4 + block.width4x4) >> subsampling_x,
                  frame_header_.columns4x4);
     memset(&coefficient_levels_[kEntropyContextTop][plane][start_x], 0,
            end_x - start_x);
@@ -1696,7 +2011,7 @@
     const int subsampling_y = subsampling_y_[plane];
     const int start_y = block.row4x4 >> subsampling_y;
     const int end_y =
-        std::min((block.row4x4 + block_height4x4) >> subsampling_y,
+        std::min((block.row4x4 + block.height4x4) >> subsampling_y,
                  frame_header_.rows4x4);
     memset(&coefficient_levels_[kEntropyContextLeft][plane][start_y], 0,
            end_y - start_y);
@@ -1705,12 +2020,15 @@
   }
 }
 
-void Tile::ComputePrediction(const Block& block) {
+bool Tile::ComputePrediction(const Block& block) {
+  const BlockParameters& bp = *block.bp;
+  if (!bp.is_inter) return true;
   const int mask =
       (1 << (4 + static_cast<int>(sequence_header_.use_128x128_superblock))) -
       1;
   const int sub_block_row4x4 = block.row4x4 & mask;
   const int sub_block_column4x4 = block.column4x4 & mask;
+  const int plane_count = block.HasChroma() ? PlaneCount() : 1;
   // Returns true if this block applies local warping. The state is determined
   // in the Y plane and carried for use in the U/V planes.
   // But the U/V planes will not apply warping when the block size is smaller
@@ -1718,20 +2036,19 @@
   bool is_local_valid = false;
   // Local warping parameters, similar usage as is_local_valid.
   GlobalMotion local_warp_params;
-  for (int plane = 0; plane < (block.HasChroma() ? PlaneCount() : 1); ++plane) {
+  int plane = 0;
+  do {
     const int8_t subsampling_x = subsampling_x_[plane];
     const int8_t subsampling_y = subsampling_y_[plane];
-    const BlockSize plane_size =
-        block.residual_size[GetPlaneType(static_cast<Plane>(plane))];
+    const BlockSize plane_size = block.residual_size[plane];
     const int block_width4x4 = kNum4x4BlocksWide[plane_size];
     const int block_height4x4 = kNum4x4BlocksHigh[plane_size];
     const int block_width = MultiplyBy4(block_width4x4);
     const int block_height = MultiplyBy4(block_height4x4);
     const int base_x = MultiplyBy4(block.column4x4 >> subsampling_x);
     const int base_y = MultiplyBy4(block.row4x4 >> subsampling_y);
-    const BlockParameters& bp = *block.bp;
-    if (bp.is_inter && bp.reference_frame[1] == kReferenceFrameIntra) {
-      const int tr_row4x4 = (sub_block_row4x4 >> subsampling_y);
+    if (bp.reference_frame[1] == kReferenceFrameIntra) {
+      const int tr_row4x4 = sub_block_row4x4 >> subsampling_y;
       const int tr_column4x4 =
           (sub_block_column4x4 >> subsampling_x) + block_width4x4 + 1;
       const int bl_row4x4 =
@@ -1740,88 +2057,98 @@
       const TransformSize tx_size =
           k4x4SizeToTransformSize[k4x4WidthLog2[plane_size]]
                                  [k4x4HeightLog2[plane_size]];
-      const bool has_left =
-          plane == kPlaneY ? block.left_available : block.LeftAvailableChroma();
-      const bool has_top =
-          plane == kPlaneY ? block.top_available : block.TopAvailableChroma();
-      if (sequence_header_.color_config.bitdepth == 8) {
-        IntraPrediction<uint8_t>(
-            block, static_cast<Plane>(plane), base_x, base_y, has_left, has_top,
-            block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
-            block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
-            kInterIntraToIntraMode[block.bp->prediction_parameters
-                                       ->inter_intra_mode],
-            tx_size);
-#if LIBGAV1_MAX_BITDEPTH >= 10
-      } else {
-        IntraPrediction<uint16_t>(
-            block, static_cast<Plane>(plane), base_x, base_y, has_left, has_top,
-            block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
-            block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
-            kInterIntraToIntraMode[block.bp->prediction_parameters
-                                       ->inter_intra_mode],
-            tx_size);
-#endif
-      }
+      const bool has_left = block.left_available[plane];
+      const bool has_top = block.top_available[plane];
+      CALL_BITDEPTH_FUNCTION(
+          IntraPrediction, block, static_cast<Plane>(plane), base_x, base_y,
+          has_left, has_top,
+          block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
+          block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
+          kInterIntraToIntraMode[block.bp->prediction_parameters
+                                     ->inter_intra_mode],
+          tx_size);
     }
-    if (bp.is_inter) {
-      int candidate_row = (block.row4x4 >> subsampling_y) << subsampling_y;
-      int candidate_column = (block.column4x4 >> subsampling_x)
-                             << subsampling_x;
-      bool some_use_intra = false;
-      for (int r = 0; r < (block_height4x4 << subsampling_y); ++r) {
-        for (int c = 0; c < (block_width4x4 << subsampling_x); ++c) {
-          auto* const bp = block_parameters_holder_.Find(candidate_row + r,
-                                                         candidate_column + c);
-          if (bp != nullptr && bp->reference_frame[0] == kReferenceFrameIntra) {
-            some_use_intra = true;
-            break;
-          }
-        }
-        if (some_use_intra) break;
-      }
-      int prediction_width;
-      int prediction_height;
-      if (some_use_intra) {
-        candidate_row = block.row4x4;
-        candidate_column = block.column4x4;
-        prediction_width = block_width;
-        prediction_height = block_height;
-      } else {
-        prediction_width = kBlockWidthPixels[block.size] >> subsampling_x;
-        prediction_height = kBlockHeightPixels[block.size] >> subsampling_y;
-      }
-      for (int r = 0, y = 0; y < block_height; y += prediction_height, ++r) {
-        for (int c = 0, x = 0; x < block_width; x += prediction_width, ++c) {
-          InterPrediction(block, static_cast<Plane>(plane), base_x + x,
-                          base_y + y, prediction_width, prediction_height,
-                          candidate_row + r, candidate_column + c,
-                          &is_local_valid, &local_warp_params);
+    int candidate_row = block.row4x4;
+    int candidate_column = block.column4x4;
+    bool some_use_intra = bp.reference_frame[0] == kReferenceFrameIntra;
+    if (!some_use_intra && plane != 0) {
+      candidate_row = (candidate_row >> subsampling_y) << subsampling_y;
+      candidate_column = (candidate_column >> subsampling_x) << subsampling_x;
+      if (candidate_row != block.row4x4) {
+        // Top block.
+        const BlockParameters& bp_top =
+            *block_parameters_holder_.Find(candidate_row, block.column4x4);
+        some_use_intra = bp_top.reference_frame[0] == kReferenceFrameIntra;
+        if (!some_use_intra && candidate_column != block.column4x4) {
+          // Top-left block.
+          const BlockParameters& bp_top_left =
+              *block_parameters_holder_.Find(candidate_row, candidate_column);
+          some_use_intra =
+              bp_top_left.reference_frame[0] == kReferenceFrameIntra;
         }
       }
+      if (!some_use_intra && candidate_column != block.column4x4) {
+        // Left block.
+        const BlockParameters& bp_left =
+            *block_parameters_holder_.Find(block.row4x4, candidate_column);
+        some_use_intra = bp_left.reference_frame[0] == kReferenceFrameIntra;
+      }
     }
-  }
+    int prediction_width;
+    int prediction_height;
+    if (some_use_intra) {
+      candidate_row = block.row4x4;
+      candidate_column = block.column4x4;
+      prediction_width = block_width;
+      prediction_height = block_height;
+    } else {
+      prediction_width = block.width >> subsampling_x;
+      prediction_height = block.height >> subsampling_y;
+    }
+    int r = 0;
+    int y = 0;
+    do {
+      int c = 0;
+      int x = 0;
+      do {
+        if (!InterPrediction(block, static_cast<Plane>(plane), base_x + x,
+                             base_y + y, prediction_width, prediction_height,
+                             candidate_row + r, candidate_column + c,
+                             &is_local_valid, &local_warp_params)) {
+          return false;
+        }
+        ++c;
+        x += prediction_width;
+      } while (x < block_width);
+      ++r;
+      y += prediction_height;
+    } while (y < block_height);
+  } while (++plane < plane_count);
+  return true;
 }
 
+#undef CALL_BITDEPTH_FUNCTION
+
 void Tile::PopulateDeblockFilterLevel(const Block& block) {
   if (!post_filter_.DoDeblock()) return;
   BlockParameters& bp = *block.bp;
+  const int mode_id =
+      static_cast<int>(kPredictionModeDeltasMask.Contains(bp.y_mode));
   for (int i = 0; i < kFrameLfCount; ++i) {
     if (delta_lf_all_zero_) {
       bp.deblock_filter_level[i] = post_filter_.GetZeroDeltaDeblockFilterLevel(
-          bp.segment_id, i, bp.reference_frame[0],
-          LoopFilterMask::GetModeId(bp.y_mode));
+          bp.segment_id, i, bp.reference_frame[0], mode_id);
     } else {
       bp.deblock_filter_level[i] =
           deblock_filter_levels_[bp.segment_id][i][bp.reference_frame[0]]
-                                [LoopFilterMask::GetModeId(bp.y_mode)];
+                                [mode_id];
     }
   }
 }
 
 bool Tile::ProcessBlock(int row4x4, int column4x4, BlockSize block_size,
                         ParameterTree* const tree,
-                        DecoderScratchBuffer* const scratch_buffer,
+                        TileScratchBuffer* const scratch_buffer,
                         ResidualPtr* residual) {
   // Do not process the block if the starting point is beyond the visible frame.
   // This is equivalent to the has_row/has_column check in the
@@ -1831,34 +2158,34 @@
       column4x4 >= frame_header_.columns4x4) {
     return true;
   }
-  Block block(*this, row4x4, column4x4, block_size, scratch_buffer, residual,
-              tree->parameters());
-  block.bp->size = block_size;
-  block_parameters_holder_.FillCache(row4x4, column4x4, block_size,
-                                     tree->parameters());
-  block.bp->prediction_parameters =
+  BlockParameters& bp = *tree->parameters();
+  block_parameters_holder_.FillCache(row4x4, column4x4, block_size, &bp);
+  Block block(*this, block_size, row4x4, column4x4, scratch_buffer, residual);
+  bp.size = block_size;
+  bp.prediction_parameters =
       split_parse_and_decode_ ? std::unique_ptr<PredictionParameters>(
                                     new (std::nothrow) PredictionParameters())
                               : std::move(prediction_parameters_);
-  if (block.bp->prediction_parameters == nullptr) return false;
+  if (bp.prediction_parameters == nullptr) return false;
   if (!DecodeModeInfo(block)) return false;
+  bp.is_global_mv_block = (bp.y_mode == kPredictionModeGlobalMv ||
+                           bp.y_mode == kPredictionModeGlobalGlobalMv) &&
+                          !IsBlockDimension4(bp.size);
   PopulateDeblockFilterLevel(block);
   if (!ReadPaletteTokens(block)) return false;
   DecodeTransformSize(block);
-  BlockParameters& bp = *block.bp;
   // Part of Section 5.11.37 in the spec (implemented as a simple lookup).
-  bp.uv_transform_size =
-      frame_header_.segmentation.lossless[bp.segment_id]
-          ? kTransformSize4x4
-          : kUVTransformSize[block.residual_size[kPlaneTypeUV]];
+  bp.uv_transform_size = frame_header_.segmentation.lossless[bp.segment_id]
+                             ? kTransformSize4x4
+                             : kUVTransformSize[block.residual_size[kPlaneU]];
   if (bp.skip) ResetEntropyContext(block);
-  const int block_width4x4 = kNum4x4BlocksWide[block_size];
-  const int block_height4x4 = kNum4x4BlocksHigh[block_size];
   if (split_parse_and_decode_) {
     if (!Residual(block, kProcessingModeParseOnly)) return false;
   } else {
-    ComputePrediction(block);
-    if (!Residual(block, kProcessingModeParseAndDecode)) return false;
+    if (!ComputePrediction(block) ||
+        !Residual(block, kProcessingModeParseAndDecode)) {
+      return false;
+    }
   }
   // If frame_header_.segmentation.enabled is false, bp.segment_id is 0 for all
   // blocks. We don't need to call save bp.segment_id in the current frame
@@ -1870,25 +2197,22 @@
   // save bp.segment_id in the current frame.
   if (frame_header_.segmentation.enabled &&
       frame_header_.segmentation.update_map) {
-    const int x_limit =
-        std::min(frame_header_.columns4x4 - column4x4, block_width4x4);
-    const int y_limit =
-        std::min(frame_header_.rows4x4 - row4x4, block_height4x4);
+    const int x_limit = std::min(frame_header_.columns4x4 - column4x4,
+                                 static_cast<int>(block.width4x4));
+    const int y_limit = std::min(frame_header_.rows4x4 - row4x4,
+                                 static_cast<int>(block.height4x4));
     current_frame_.segmentation_map()->FillBlock(row4x4, column4x4, x_limit,
                                                  y_limit, bp.segment_id);
   }
-  if (build_bit_mask_when_parsing_ || !split_parse_and_decode_) {
-    BuildBitMask(row4x4, column4x4, block_size);
-  }
+  StoreMotionFieldMvsIntoCurrentFrame(block);
   if (!split_parse_and_decode_) {
-    StoreMotionFieldMvsIntoCurrentFrame(block);
-    prediction_parameters_ = std::move(block.bp->prediction_parameters);
+    prediction_parameters_ = std::move(bp.prediction_parameters);
   }
   return true;
 }
 
 bool Tile::DecodeBlock(ParameterTree* const tree,
-                       DecoderScratchBuffer* const scratch_buffer,
+                       TileScratchBuffer* const scratch_buffer,
                        ResidualPtr* residual) {
   const int row4x4 = tree->row4x4();
   const int column4x4 = tree->column4x4();
@@ -1897,21 +2221,18 @@
     return true;
   }
   const BlockSize block_size = tree->block_size();
-  Block block(*this, row4x4, column4x4, block_size, scratch_buffer, residual,
-              tree->parameters());
-  ComputePrediction(block);
-  if (!Residual(block, kProcessingModeDecodeOnly)) return false;
-  if (!build_bit_mask_when_parsing_) {
-    BuildBitMask(row4x4, column4x4, block_size);
+  Block block(*this, block_size, row4x4, column4x4, scratch_buffer, residual);
+  if (!ComputePrediction(block) ||
+      !Residual(block, kProcessingModeDecodeOnly)) {
+    return false;
   }
-  StoreMotionFieldMvsIntoCurrentFrame(block);
   block.bp->prediction_parameters.reset(nullptr);
   return true;
 }
 
 bool Tile::ProcessPartition(int row4x4_start, int column4x4_start,
                             ParameterTree* const root,
-                            DecoderScratchBuffer* const scratch_buffer,
+                            TileScratchBuffer* const scratch_buffer,
                             ResidualPtr* residual) {
   Stack<ParameterTree*, kDfsStackSize> stack;
 
@@ -2025,7 +2346,7 @@
 }
 
 void Tile::ResetCdef(const int row4x4, const int column4x4) {
-  if (cdef_index_[0] == nullptr) return;
+  if (!sequence_header_.enable_cdef) return;
   const int row = DivideBy16(row4x4);
   const int column = DivideBy16(column4x4);
   cdef_index_[row][column] = -1;
@@ -2039,7 +2360,7 @@
   }
 }
 
-void Tile::ClearBlockDecoded(DecoderScratchBuffer* const scratch_buffer,
+void Tile::ClearBlockDecoded(TileScratchBuffer* const scratch_buffer,
                              int row4x4, int column4x4) {
   // Set everything to false.
   memset(scratch_buffer->block_decoded, 0,
@@ -2075,7 +2396,7 @@
 }
 
 bool Tile::ProcessSuperBlock(int row4x4, int column4x4, int block_width4x4,
-                             DecoderScratchBuffer* const scratch_buffer,
+                             TileScratchBuffer* const scratch_buffer,
                              ProcessingMode mode) {
   const bool parsing =
       mode == kProcessingModeParseOnly || mode == kProcessingModeParseAndDecode;
@@ -2139,11 +2460,11 @@
 }
 
 bool Tile::DecodeSuperBlock(ParameterTree* const tree,
-                            DecoderScratchBuffer* const scratch_buffer,
+                            TileScratchBuffer* const scratch_buffer,
                             ResidualPtr* residual) {
   Stack<ParameterTree*, kDfsStackSize> stack;
   stack.Push(tree);
-  while (!stack.Empty()) {
+  do {
     ParameterTree* const node = stack.Pop();
     if (node->partition() != kPartitionNone) {
       for (int i = 3; i >= 0; --i) {
@@ -2157,7 +2478,7 @@
                    node->row4x4(), node->column4x4());
       return false;
     }
-  }
+  } while (!stack.Empty());
   return true;
 }
 
@@ -2189,222 +2510,87 @@
   }
 }
 
-void Tile::BuildBitMask(int row4x4, int column4x4, BlockSize block_size) {
-  if (!post_filter_.DoDeblock()) return;
-  if (block_size <= kBlock64x64) {
-    BuildBitMaskHelper(row4x4, column4x4, block_size, true, true);
-  } else {
-    const int block_width4x4 = kNum4x4BlocksWide[block_size];
-    const int block_height4x4 = kNum4x4BlocksHigh[block_size];
-    for (int y = 0; y < block_height4x4; y += 16) {
-      for (int x = 0; x < block_width4x4; x += 16) {
-        BuildBitMaskHelper(row4x4 + y, column4x4 + x, kBlock64x64, x == 0,
-                           y == 0);
-      }
-    }
-  }
-}
-
-void Tile::BuildBitMaskHelper(int row4x4, int column4x4, BlockSize block_size,
-                              const bool is_vertical_block_border,
-                              const bool is_horizontal_block_border) {
-  const int block_width4x4 = kNum4x4BlocksWide[block_size];
-  const int block_height4x4 = kNum4x4BlocksHigh[block_size];
-  BlockParameters& bp = *block_parameters_holder_.Find(row4x4, column4x4);
-  const bool skip = bp.skip && bp.is_inter;
-  LoopFilterMask* const masks = post_filter_.masks();
-  const int unit_id = DivideBy16(row4x4) * masks->num_64x64_blocks_per_row() +
-                      DivideBy16(column4x4);
-
-  for (int plane = kPlaneY; plane < PlaneCount(); ++plane) {
-    // For U and V planes, do not build bit masks if level == 0.
-    if (plane > kPlaneY && frame_header_.loop_filter.level[plane + 1] == 0) {
-      continue;
-    }
-    // Build bit mask for vertical edges.
-    const int subsampling_x = subsampling_x_[plane];
-    const int subsampling_y = subsampling_y_[plane];
-    const int plane_width =
-        RightShiftWithRounding(frame_header_.width, subsampling_x);
-    const int column_limit =
-        std::min({column4x4 + block_width4x4, frame_header_.columns4x4,
-                  DivideBy4(plane_width + 3) << subsampling_x});
-    const int plane_height =
-        RightShiftWithRounding(frame_header_.height, subsampling_y);
-    const int row_limit =
-        std::min({row4x4 + block_height4x4, frame_header_.rows4x4,
-                  DivideBy4(plane_height + 3) << subsampling_y});
-    const int row_start = GetDeblockPosition(row4x4, subsampling_y);
-    const int column_start = GetDeblockPosition(column4x4, subsampling_x);
-    if (row_start >= row_limit || column_start >= column_limit) {
-      continue;
-    }
-    const int vertical_step = 1 << subsampling_y;
-    const int horizontal_step = 1 << subsampling_x;
-    const BlockParameters& bp =
-        *block_parameters_holder_.Find(row_start, column_start);
-    const int horizontal_level_index =
-        kDeblockFilterLevelIndex[plane][kLoopFilterTypeHorizontal];
-    const int vertical_level_index =
-        kDeblockFilterLevelIndex[plane][kLoopFilterTypeVertical];
-    const uint8_t vertical_level =
-        bp.deblock_filter_level[vertical_level_index];
-
-    for (int row = row_start; row < row_limit; row += vertical_step) {
-      for (int column = column_start; column < column_limit;) {
-        const TransformSize tx_size = (plane == kPlaneY)
-                                          ? inter_transform_sizes_[row][column]
-                                          : bp.uv_transform_size;
-        // (1). Don't filter frame boundary.
-        // (2). For tile boundary, we don't know whether the previous tile is
-        // available or not, thus we handle it after all tiles are decoded.
-        const bool is_vertical_border =
-            (column == column_start) && is_vertical_block_border;
-        if (column == GetDeblockPosition(column4x4_start_, subsampling_x) ||
-            (skip && !is_vertical_border)) {
-          column += kNum4x4BlocksWide[tx_size] << subsampling_x;
-          continue;
-        }
-
-        // bp_left is the parameter of the left prediction block which
-        // is guaranteed to be inside the tile.
-        const BlockParameters& bp_left =
-            *block_parameters_holder_.Find(row, column - horizontal_step);
-        const uint8_t left_level =
-            is_vertical_border
-                ? bp_left.deblock_filter_level[vertical_level_index]
-                : vertical_level;
-        // We don't have to check if the left block is skipped or not,
-        // because if the current transform block is on the edge of the coding
-        // block, is_vertical_border is true; if it's not on the edge,
-        // left skip is equal to skip.
-        if (vertical_level != 0 || left_level != 0) {
-          const TransformSize left_tx_size =
-              (plane == kPlaneY)
-                  ? inter_transform_sizes_[row][column - horizontal_step]
-                  : bp_left.uv_transform_size;
-          const LoopFilterTransformSizeId transform_size_id =
-              GetTransformSizeIdWidth(tx_size, left_tx_size);
-          const int r = row & (kNum4x4InLoopFilterMaskUnit - 1);
-          const int c = column & (kNum4x4InLoopFilterMaskUnit - 1);
-          const int shift = LoopFilterMask::GetShift(r, c);
-          const int index = LoopFilterMask::GetIndex(r);
-          const auto mask = static_cast<uint64_t>(1) << shift;
-          masks->SetLeft(mask, unit_id, plane, transform_size_id, index);
-          const uint8_t current_level =
-              (vertical_level == 0) ? left_level : vertical_level;
-          masks->SetLevel(current_level, unit_id, plane,
-                          kLoopFilterTypeVertical,
-                          LoopFilterMask::GetLevelOffset(r, c));
-        }
-        column += kNum4x4BlocksWide[tx_size] << subsampling_x;
-      }
-    }
-
-    // Build bit mask for horizontal edges.
-    const uint8_t horizontal_level =
-        bp.deblock_filter_level[horizontal_level_index];
-    for (int column = column_start; column < column_limit;
-         column += horizontal_step) {
-      for (int row = row_start; row < row_limit;) {
-        const TransformSize tx_size = (plane == kPlaneY)
-                                          ? inter_transform_sizes_[row][column]
-                                          : bp.uv_transform_size;
-
-        // (1). Don't filter frame boundary.
-        // (2). For tile boundary, we don't know whether the previous tile is
-        // available or not, thus we handle it after all tiles are decoded.
-        const bool is_horizontal_border =
-            (row == row_start) && is_horizontal_block_border;
-        if (row == GetDeblockPosition(row4x4_start_, subsampling_y) ||
-            (skip && !is_horizontal_border)) {
-          row += kNum4x4BlocksHigh[tx_size] << subsampling_y;
-          continue;
-        }
-
-        // bp_top is the parameter of the top prediction block which is
-        // guaranteed to be inside the tile.
-        const BlockParameters& bp_top =
-            *block_parameters_holder_.Find(row - vertical_step, column);
-        const uint8_t top_level =
-            is_horizontal_border
-                ? bp_top.deblock_filter_level[horizontal_level_index]
-                : horizontal_level;
-        // We don't have to check it the top block is skippped or not,
-        // because if the current transform block is on the edge of the coding
-        // block, is_horizontal_border is true; if it's not on the edge,
-        // top skip is equal to skip.
-        if (horizontal_level != 0 || top_level != 0) {
-          const TransformSize top_tx_size =
-              (plane == kPlaneY)
-                  ? inter_transform_sizes_[row - vertical_step][column]
-                  : bp_top.uv_transform_size;
-          const LoopFilterTransformSizeId transform_size_id =
-              static_cast<LoopFilterTransformSizeId>(
-                  std::min({kTransformHeightLog2[tx_size] - 2,
-                            kTransformHeightLog2[top_tx_size] - 2, 2}));
-          const int r = row & (kNum4x4InLoopFilterMaskUnit - 1);
-          const int c = column & (kNum4x4InLoopFilterMaskUnit - 1);
-          const int shift = LoopFilterMask::GetShift(r, c);
-          const int index = LoopFilterMask::GetIndex(r);
-          const auto mask = static_cast<uint64_t>(1) << shift;
-          masks->SetTop(mask, unit_id, plane, transform_size_id, index);
-          const uint8_t current_level =
-              (horizontal_level == 0) ? top_level : horizontal_level;
-          masks->SetLevel(current_level, unit_id, plane,
-                          kLoopFilterTypeHorizontal,
-                          LoopFilterMask::GetLevelOffset(r, c));
-        }
-        row += kNum4x4BlocksHigh[tx_size] << subsampling_y;
-      }
-    }
-  }
-}
-
 void Tile::StoreMotionFieldMvsIntoCurrentFrame(const Block& block) {
-  // The largest reference MV component that can be saved.
-  constexpr int kRefMvsLimit = (1 << 12) - 1;
-  const BlockParameters& bp = *block.bp;
-  ReferenceFrameType reference_frame_to_store = kReferenceFrameNone;
-  MotionVector mv_to_store = {};
-  for (int i = 1; i >= 0; --i) {
-    if (bp.reference_frame[i] > kReferenceFrameIntra &&
-        std::abs(bp.mv[i].mv[MotionVector::kRow]) <= kRefMvsLimit &&
-        std::abs(bp.mv[i].mv[MotionVector::kColumn]) <= kRefMvsLimit &&
-        GetRelativeDistance(
-            reference_order_hint_
-                [frame_header_.reference_frame_index[bp.reference_frame[i] -
-                                                     kReferenceFrameLast]],
-            frame_header_.order_hint, sequence_header_.enable_order_hint,
-            sequence_header_.order_hint_bits) < 0) {
-      reference_frame_to_store = bp.reference_frame[i];
-      mv_to_store = bp.mv[i];
-      break;
-    }
+  if (frame_header_.refresh_frame_flags == 0 ||
+      IsIntraFrame(frame_header_.frame_type)) {
+    return;
   }
   // Iterate over odd rows/columns beginning at the first odd row/column for the
   // block. It is done this way because motion field mvs are only needed at a
   // 8x8 granularity.
-  const int row_start = block.row4x4 | 1;
-  const int row_limit = std::min(block.row4x4 + kNum4x4BlocksHigh[block.size],
-                                 frame_header_.rows4x4);
-  const int column_start = block.column4x4 | 1;
-  const int column_limit =
-      std::min(block.column4x4 + kNum4x4BlocksWide[block.size],
-               frame_header_.columns4x4);
-  for (int row = row_start; row < row_limit; row += 2) {
-    const int row_index = DivideBy2(row);
-    ReferenceFrameType* const reference_frame_row_start =
-        current_frame_.motion_field_reference_frame(row_index,
-                                                    DivideBy2(column_start));
-    static_assert(sizeof(reference_frame_to_store) == sizeof(int8_t), "");
-    memset(reference_frame_row_start, reference_frame_to_store,
-           DivideBy2(column_limit - column_start + 1));
-    if (reference_frame_to_store <= kReferenceFrameIntra) continue;
-    for (int column = column_start; column < column_limit; column += 2) {
+  const int row_start4x4 = block.row4x4 | 1;
+  const int row_limit4x4 =
+      std::min(block.row4x4 + block.height4x4, frame_header_.rows4x4);
+  if (row_start4x4 >= row_limit4x4) return;
+  const int column_start4x4 = block.column4x4 | 1;
+  const int column_limit4x4 =
+      std::min(block.column4x4 + block.width4x4, frame_header_.columns4x4);
+  if (column_start4x4 >= column_limit4x4) return;
+
+  // The largest reference MV component that can be saved.
+  constexpr int kRefMvsLimit = (1 << 12) - 1;
+  const BlockParameters& bp = *block.bp;
+  ReferenceInfo* reference_info = current_frame_.reference_info();
+  for (int i = 1; i >= 0; --i) {
+    const ReferenceFrameType reference_frame_to_store = bp.reference_frame[i];
+    // Must make a local copy so that StoreMotionFieldMvs() knows there is no
+    // overlap between load and store.
+    const MotionVector mv_to_store = bp.mv.mv[i];
+    const int mv_row = std::abs(mv_to_store.mv[MotionVector::kRow]);
+    const int mv_column = std::abs(mv_to_store.mv[MotionVector::kColumn]);
+    if (reference_frame_to_store > kReferenceFrameIntra &&
+        // kRefMvsLimit equals 0x07FF, so we can first bitwise OR the two
+        // absolute values and then compare with kRefMvsLimit to save a branch.
+        // The next line is equivalent to:
+        // mv_row <= kRefMvsLimit && mv_column <= kRefMvsLimit
+        (mv_row | mv_column) <= kRefMvsLimit &&
+        reference_info->relative_distance_from[reference_frame_to_store] < 0) {
+      const int row_start8x8 = DivideBy2(row_start4x4);
+      const int row_limit8x8 = DivideBy2(row_limit4x4);
+      const int column_start8x8 = DivideBy2(column_start4x4);
+      const int column_limit8x8 = DivideBy2(column_limit4x4);
+      const int rows = row_limit8x8 - row_start8x8;
+      const int columns = column_limit8x8 - column_start8x8;
+      const ptrdiff_t stride = DivideBy2(current_frame_.columns4x4());
+      ReferenceFrameType* const reference_frame_row_start =
+          &reference_info
+               ->motion_field_reference_frame[row_start8x8][column_start8x8];
       MotionVector* const mv =
-          current_frame_.motion_field_mv(row_index, DivideBy2(column));
-      *mv = mv_to_store;
+          &reference_info->motion_field_mv[row_start8x8][column_start8x8];
+
+      // Specialize columns cases 1, 2, 4, 8 and 16. This makes memset() inlined
+      // and simplifies std::fill() for these cases.
+      if (columns <= 1) {
+        // Don't change the above condition to (columns == 1).
+        // Condition (columns <= 1) may help the compiler simplify the inlining
+        // of the general case of StoreMotionFieldMvs() by eliminating the
+        // (columns == 0) case.
+        assert(columns == 1);
+        StoreMotionFieldMvs(reference_frame_to_store, mv_to_store, stride, rows,
+                            1, reference_frame_row_start, mv);
+      } else if (columns == 2) {
+        StoreMotionFieldMvs(reference_frame_to_store, mv_to_store, stride, rows,
+                            2, reference_frame_row_start, mv);
+      } else if (columns == 4) {
+        StoreMotionFieldMvs(reference_frame_to_store, mv_to_store, stride, rows,
+                            4, reference_frame_row_start, mv);
+      } else if (columns == 8) {
+        StoreMotionFieldMvs(reference_frame_to_store, mv_to_store, stride, rows,
+                            8, reference_frame_row_start, mv);
+      } else if (columns == 16) {
+        StoreMotionFieldMvs(reference_frame_to_store, mv_to_store, stride, rows,
+                            16, reference_frame_row_start, mv);
+      } else if (columns < 16) {
+        // This always true condition (columns < 16) may help the compiler
+        // simplify the inlining of the following function.
+        // This general case is rare and usually only happens to the blocks
+        // which contain the right boundary of the frame.
+        StoreMotionFieldMvs(reference_frame_to_store, mv_to_store, stride, rows,
+                            columns, reference_frame_row_start, mv);
+      } else {
+        assert(false);
+      }
+      return;
     }
   }
 }
diff --git a/libgav1/src/decoder_scratch_buffer.cc b/libgav1/src/tile_scratch_buffer.cc
similarity index 80%
rename from libgav1/src/decoder_scratch_buffer.cc
rename to libgav1/src/tile_scratch_buffer.cc
index bb9b5f2..0b5ac96 100644
--- a/libgav1/src/decoder_scratch_buffer.cc
+++ b/libgav1/src/tile_scratch_buffer.cc
@@ -12,12 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/decoder_scratch_buffer.h"
+#include "src/tile_scratch_buffer.h"
+
+#include "src/utils/compiler_attributes.h"
 
 namespace libgav1 {
 
+#if !LIBGAV1_CXX17
 // static
-constexpr int DecoderScratchBuffer::kBlockDecodedStride;
-constexpr int DecoderScratchBuffer::kPixelSize;
+constexpr int TileScratchBuffer::kBlockDecodedStride;
+#endif
 
 }  // namespace libgav1
diff --git a/libgav1/src/tile_scratch_buffer.h b/libgav1/src/tile_scratch_buffer.h
new file mode 100644
index 0000000..3eaf8b8
--- /dev/null
+++ b/libgav1/src/tile_scratch_buffer.h
@@ -0,0 +1,160 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_TILE_SCRATCH_BUFFER_H_
+#define LIBGAV1_SRC_TILE_SCRATCH_BUFFER_H_
+
+#include <cstdint>
+#include <mutex>  // NOLINT (unapproved c++11 header)
+
+#include "src/dsp/constants.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+#include "src/utils/memory.h"
+#include "src/utils/stack.h"
+
+namespace libgav1 {
+
+// Buffer to facilitate decoding a superblock.
+struct TileScratchBuffer : public MaxAlignedAllocable {
+  static constexpr int kBlockDecodedStride = 34;
+
+  LIBGAV1_MUST_USE_RESULT bool Init(int bitdepth) {
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    const int pixel_size = (bitdepth == 8) ? 1 : 2;
+#else
+    assert(bitdepth == 8);
+    static_cast<void>(bitdepth);
+    const int pixel_size = 1;
+#endif
+
+    constexpr int unaligned_convolve_buffer_stride =
+        kMaxScaledSuperBlockSizeInPixels + kConvolveBorderLeftTop +
+        kConvolveBorderRight;
+    convolve_block_buffer_stride = Align<ptrdiff_t>(
+        unaligned_convolve_buffer_stride * pixel_size, kMaxAlignment);
+    constexpr int convolve_buffer_height = kMaxScaledSuperBlockSizeInPixels +
+                                           kConvolveBorderLeftTop +
+                                           kConvolveBorderBottom;
+
+    convolve_block_buffer = MakeAlignedUniquePtr<uint8_t>(
+        kMaxAlignment, convolve_buffer_height * convolve_block_buffer_stride);
+    return convolve_block_buffer != nullptr;
+  }
+
+  // kCompoundPredictionTypeDiffWeighted prediction mode needs a mask of the
+  // prediction block size. This buffer is used to store that mask. The masks
+  // will be created for the Y plane and will be re-used for the U & V planes.
+  alignas(kMaxAlignment) uint8_t weight_mask[kMaxSuperBlockSizeSquareInPixels];
+
+  // For each instance of the TileScratchBuffer, only one of the following
+  // buffers will be used at any given time, so it is ok to share them in a
+  // union.
+  union {
+    // Buffers used for prediction process.
+    // Compound prediction calculations always output 16-bit values. Depending
+    // on the bitdepth the values may be treated as int16_t or uint16_t. See
+    // src/dsp/convolve.cc and src/dsp/warp.cc for explanations.
+    // Inter/intra calculations output Pixel values.
+    // These buffers always use width as the stride. This enables packing the
+    // values in and simplifies loads/stores for small values.
+
+    // 10/12 bit compound prediction and 10/12 bit inter/intra prediction.
+    alignas(kMaxAlignment) uint16_t
+        prediction_buffer[2][kMaxSuperBlockSizeSquareInPixels];
+    // 8 bit compound prediction buffer.
+    alignas(kMaxAlignment) int16_t
+        compound_prediction_buffer_8bpp[2][kMaxSuperBlockSizeSquareInPixels];
+
+    // Union usage note: This is used only by functions in the "intra"
+    // prediction path.
+    //
+    // Buffer used for storing subsampled luma samples needed for CFL
+    // prediction. This buffer is used to avoid repetition of the subsampling
+    // for the V plane when it is already done for the U plane.
+    int16_t cfl_luma_buffer[kCflLumaBufferStride][kCflLumaBufferStride];
+  };
+
+  // Buffer used for convolve. The maximum size required for this buffer is:
+  //  maximum block height (with scaling and border) = 2 * 128 + 3 + 4 = 263.
+  //  maximum block stride (with scaling and border aligned to 16) =
+  //     (2 * 128 + 3 + 8 + 5) * pixel_size = 272 * pixel_size.
+  //  Where pixel_size is (bitdepth == 8) ? 1 : 2.
+  // Has an alignment of kMaxAlignment when allocated.
+  AlignedUniquePtr<uint8_t> convolve_block_buffer;
+  ptrdiff_t convolve_block_buffer_stride;
+
+  // Flag indicating whether the data in |cfl_luma_buffer| is valid.
+  bool cfl_luma_buffer_valid;
+
+  // Equivalent to BlockDecoded array in the spec. This stores the decoded
+  // state of every 4x4 block in a superblock. It has 1 row/column border on
+  // all 4 sides (hence the 34x34 dimension instead of 32x32). Note that the
+  // spec uses "-1" as an index to access the left and top borders. In the
+  // code, we treat the index (1, 1) as equivalent to the spec's (0, 0). So
+  // all accesses into this array will be offset by +1 when compared with the
+  // spec.
+  bool block_decoded[kMaxPlanes][kBlockDecodedStride][kBlockDecodedStride];
+};
+
+class TileScratchBufferPool {
+ public:
+  void Reset(int bitdepth) {
+    if (bitdepth_ == bitdepth) return;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    if (bitdepth_ == 8 && bitdepth != 8) {
+      // We are going from a pixel size of 1 to a pixel size of 2. So invalidate
+      // the stack.
+      std::lock_guard<std::mutex> lock(mutex_);
+      while (!buffers_.Empty()) {
+        buffers_.Pop();
+      }
+    }
+#endif
+    bitdepth_ = bitdepth;
+  }
+
+  std::unique_ptr<TileScratchBuffer> Get() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (buffers_.Empty()) {
+      std::unique_ptr<TileScratchBuffer> scratch_buffer(new (std::nothrow)
+                                                            TileScratchBuffer);
+      if (scratch_buffer == nullptr || !scratch_buffer->Init(bitdepth_)) {
+        return nullptr;
+      }
+      return scratch_buffer;
+    }
+    return buffers_.Pop();
+  }
+
+  void Release(std::unique_ptr<TileScratchBuffer> scratch_buffer) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    buffers_.Push(std::move(scratch_buffer));
+  }
+
+ private:
+  std::mutex mutex_;
+  // We will never need more than kMaxThreads scratch buffers since that is the
+  // maximum amount of work that will be done at any given time.
+  Stack<std::unique_ptr<TileScratchBuffer>, kMaxThreads> buffers_
+      LIBGAV1_GUARDED_BY(mutex_);
+  int bitdepth_ = 0;
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_TILE_SCRATCH_BUFFER_H_
diff --git a/libgav1/src/utils/array_2d.h b/libgav1/src/utils/array_2d.h
index 4eaaf32..2df6241 100644
--- a/libgav1/src/utils/array_2d.h
+++ b/libgav1/src/utils/array_2d.h
@@ -17,6 +17,7 @@
 #ifndef LIBGAV1_SRC_UTILS_ARRAY_2D_H_
 #define LIBGAV1_SRC_UTILS_ARRAY_2D_H_
 
+#include <cassert>
 #include <cstddef>
 #include <cstring>
 #include <memory>
@@ -55,8 +56,9 @@
 
  private:
   const T* GetRow(int row) const {
+    assert(row < rows_);
     const ptrdiff_t offset = static_cast<ptrdiff_t>(row) * columns_;
-    return (row < rows_) ? data_ + offset : nullptr;
+    return data_ + offset;
   }
 
   int rows_ = 0;
@@ -77,10 +79,10 @@
 
   LIBGAV1_MUST_USE_RESULT bool Reset(int rows, int columns,
                                      bool zero_initialize = true) {
-    const size_t size = rows * columns;
+    size_ = rows * columns;
     // If T is not a trivial type, we should always reallocate the data_
     // buffer, so that the destructors of any existing objects are invoked.
-    if (!std::is_trivial<T>::value || size_ < size) {
+    if (!std::is_trivial<T>::value || allocated_size_ < size_) {
       // Note: This invokes the global operator new if T is a non-class type,
       // such as integer or enum types, or a class type that is not derived
       // from libgav1::Allocable, such as std::unique_ptr. If we enforce a
@@ -88,20 +90,20 @@
       // consumption, we will need to handle the allocations here that use the
       // global operator new.
       if (zero_initialize) {
-        data_.reset(new (std::nothrow) T[size]());
+        data_.reset(new (std::nothrow) T[size_]());
       } else {
-        data_.reset(new (std::nothrow) T[size]);
+        data_.reset(new (std::nothrow) T[size_]);
       }
       if (data_ == nullptr) {
-        size_ = 0;
+        allocated_size_ = 0;
         return false;
       }
-      size_ = size;
+      allocated_size_ = size_;
     } else if (zero_initialize) {
       // Cast the data_ pointer to void* to avoid the GCC -Wclass-memaccess
       // warning. The memset is safe because T is a trivial type.
       void* dest = data_.get();
-      memset(dest, 0, sizeof(T) * size);
+      memset(dest, 0, sizeof(T) * size_);
     }
     data_view_.Reset(rows, columns, data_.get());
     return true;
@@ -109,7 +111,9 @@
 
   int rows() const { return data_view_.rows(); }
   int columns() const { return data_view_.columns(); }
+  size_t size() const { return size_; }
   T* data() { return data_.get(); }
+  const T* data() const { return data_.get(); }
 
   T* operator[](int row) { return data_view_[row]; }
 
@@ -117,6 +121,7 @@
 
  private:
   std::unique_ptr<T[]> data_ = nullptr;
+  size_t allocated_size_ = 0;
   size_t size_ = 0;
   Array2DView<T> data_view_;
 };
diff --git a/libgav1/src/utils/block_parameters_holder.cc b/libgav1/src/utils/block_parameters_holder.cc
index 316510e..79bb2b8 100644
--- a/libgav1/src/utils/block_parameters_holder.cc
+++ b/libgav1/src/utils/block_parameters_holder.cc
@@ -27,7 +27,7 @@
 namespace {
 
 // Returns the number of super block rows/columns for |value4x4| where value4x4
-// is either rows4x4 or column4x4.
+// is either rows4x4 or columns4x4.
 int RowsOrColumns4x4ToSuperBlocks(int value4x4, bool use_128x128_superblock) {
   return use_128x128_superblock ? DivideBy128(MultiplyBy4(value4x4) + 127)
                                 : DivideBy64(MultiplyBy4(value4x4) + 63);
@@ -35,13 +35,11 @@
 
 }  // namespace
 
-BlockParametersHolder::BlockParametersHolder(int rows4x4, int columns4x4,
-                                             bool use_128x128_superblock)
-    : rows4x4_(rows4x4),
-      columns4x4_(columns4x4),
-      use_128x128_superblock_(use_128x128_superblock) {}
-
-bool BlockParametersHolder::Init() {
+bool BlockParametersHolder::Reset(int rows4x4, int columns4x4,
+                                  bool use_128x128_superblock) {
+  rows4x4_ = rows4x4;
+  columns4x4_ = columns4x4;
+  use_128x128_superblock_ = use_128x128_superblock;
   if (!block_parameters_cache_.Reset(rows4x4_, columns4x4_)) {
     LIBGAV1_DLOG(ERROR, "block_parameters_cache_.Reset() failed.");
     return false;
@@ -73,14 +71,36 @@
 void BlockParametersHolder::FillCache(int row4x4, int column4x4,
                                       BlockSize block_size,
                                       BlockParameters* const bp) {
-  const int row4x4_end =
-      std::min(row4x4 + kNum4x4BlocksHigh[block_size], rows4x4_);
-  const int column4x4_end =
-      std::min(column4x4 + kNum4x4BlocksWide[block_size], columns4x4_);
-  for (int y = row4x4; y < row4x4_end; ++y) {
-    for (int x = column4x4; x < column4x4_end; ++x) {
-      block_parameters_cache_[y][x] = bp;
-    }
+  int rows = std::min(static_cast<int>(kNum4x4BlocksHigh[block_size]),
+                      rows4x4_ - row4x4);
+  const int columns = std::min(static_cast<int>(kNum4x4BlocksWide[block_size]),
+                               columns4x4_ - column4x4);
+  auto* bp_dst = &block_parameters_cache_[row4x4][column4x4];
+  // Specialize columns cases (values in kNum4x4BlocksWide[]) for better
+  // performance.
+  if (columns == 1) {
+    SetBlock<BlockParameters*>(rows, 1, bp, bp_dst, columns4x4_);
+  } else if (columns == 2) {
+    SetBlock<BlockParameters*>(rows, 2, bp, bp_dst, columns4x4_);
+  } else if (columns == 4) {
+    SetBlock<BlockParameters*>(rows, 4, bp, bp_dst, columns4x4_);
+  } else if (columns == 8) {
+    SetBlock<BlockParameters*>(rows, 8, bp, bp_dst, columns4x4_);
+  } else if (columns == 16) {
+    SetBlock<BlockParameters*>(rows, 16, bp, bp_dst, columns4x4_);
+  } else if (columns == 32) {
+    SetBlock<BlockParameters*>(rows, 32, bp, bp_dst, columns4x4_);
+  } else {
+    do {
+      // The following loop has better performance than using std::fill().
+      // std::fill() has some overhead in checking zero loop count.
+      int x = columns;
+      auto* d = bp_dst;
+      do {
+        *d++ = bp;
+      } while (--x != 0);
+      bp_dst += columns4x4_;
+    } while (--rows != 0);
   }
 }
 
diff --git a/libgav1/src/utils/block_parameters_holder.h b/libgav1/src/utils/block_parameters_holder.h
index 9cf5478..35543c3 100644
--- a/libgav1/src/utils/block_parameters_holder.h
+++ b/libgav1/src/utils/block_parameters_holder.h
@@ -31,17 +31,16 @@
 // corresponding to a superblock.
 class BlockParametersHolder {
  public:
-  // If |use_128x128_superblock| is true, 128x128 superblocks will be used,
-  // otherwise 64x64 superblocks will be used.
-  BlockParametersHolder(int rows4x4, int columns4x4,
-                        bool use_128x128_superblock);
+  BlockParametersHolder() = default;
 
   // Not copyable or movable.
   BlockParametersHolder(const BlockParametersHolder&) = delete;
   BlockParametersHolder& operator=(const BlockParametersHolder&) = delete;
 
-  // Must be called first.
-  LIBGAV1_MUST_USE_RESULT bool Init();
+  // If |use_128x128_superblock| is true, 128x128 superblocks will be used,
+  // otherwise 64x64 superblocks will be used.
+  LIBGAV1_MUST_USE_RESULT bool Reset(int rows4x4, int columns4x4,
+                                     bool use_128x128_superblock);
 
   // Finds the BlockParameters corresponding to |row4x4| and |column4x4|. This
   // is done as a simple look up of the |block_parameters_cache_| matrix.
@@ -50,6 +49,16 @@
     return block_parameters_cache_[row4x4][column4x4];
   }
 
+  BlockParameters** Address(int row4x4, int column4x4) {
+    return block_parameters_cache_.data() + row4x4 * columns4x4_ + column4x4;
+  }
+
+  BlockParameters* const* Address(int row4x4, int column4x4) const {
+    return block_parameters_cache_.data() + row4x4 * columns4x4_ + column4x4;
+  }
+
+  int columns4x4() const { return columns4x4_; }
+
   // Returns the ParameterTree corresponding to superblock starting at (|row|,
   // |column|).
   ParameterTree* Tree(int row, int column) { return trees_[row][column].get(); }
@@ -60,12 +69,12 @@
                  BlockParameters* bp);
 
  private:
-  const int rows4x4_;
-  const int columns4x4_;
-  const bool use_128x128_superblock_;
+  int rows4x4_ = 0;
+  int columns4x4_ = 0;
+  bool use_128x128_superblock_ = false;
   Array2D<std::unique_ptr<ParameterTree>> trees_;
 
-  // This is a 2d array of size |rows4x4_| * |columns4x4_|.This is filled in by
+  // This is a 2d array of size |rows4x4_| * |columns4x4_|. This is filled in by
   // FillCache() and used by Find() to perform look ups using exactly one look
   // up (instead of traversing the entire tree).
   Array2D<BlockParameters*> block_parameters_cache_;
diff --git a/libgav1/src/utils/common.h b/libgav1/src/utils/common.h
index 661483d..8caad2e 100644
--- a/libgav1/src/utils/common.h
+++ b/libgav1/src/utils/common.h
@@ -19,6 +19,7 @@
 
 #if defined(_MSC_VER)
 #include <intrin.h>
+#pragma intrinsic(_BitScanForward)
 #pragma intrinsic(_BitScanReverse)
 #if defined(_M_X64) || defined(_M_ARM) || defined(_M_ARM64)
 #pragma intrinsic(_BitScanReverse64)
@@ -29,9 +30,13 @@
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <type_traits>
 
 #include "src/utils/bit_mask_set.h"
 #include "src/utils/constants.h"
+#include "src/utils/types.h"
 
 namespace libgav1 {
 
@@ -43,10 +48,43 @@
   return (value + alignment_mask) & ~alignment_mask;
 }
 
+// Aligns |addr| to the desired |alignment|. |alignment| must be a power of 2.
+inline uint8_t* AlignAddr(uint8_t* const addr, const uintptr_t alignment) {
+  const auto value = reinterpret_cast<uintptr_t>(addr);
+  return reinterpret_cast<uint8_t*>(Align(value, alignment));
+}
+
 inline int32_t Clip3(int32_t value, int32_t low, int32_t high) {
   return value < low ? low : (value > high ? high : value);
 }
 
+// The following 2 templates set a block of data with uncontiguous memory to
+// |value|. The compilers usually generate several branches to handle different
+// cases of |columns| when inlining memset() and std::fill(), and these branches
+// are unfortunately within the loop of |rows|. So calling these templates
+// directly could be inefficient. It is recommended to specialize common cases
+// of |columns|, such as 1, 2, 4, 8, 16 and 32, etc. in advance before
+// processing the generic case of |columns|. The code size may be larger, but
+// there would be big speed gains.
+// Call template MemSetBlock<> when sizeof(|T|) is 1.
+// Call template SetBlock<> when sizeof(|T|) is larger than 1.
+template <typename T>
+void MemSetBlock(int rows, int columns, T value, T* dst, ptrdiff_t stride) {
+  static_assert(sizeof(T) == 1, "");
+  do {
+    memset(dst, value, columns);
+    dst += stride;
+  } while (--rows != 0);
+}
+
+template <typename T>
+void SetBlock(int rows, int columns, T value, T* dst, ptrdiff_t stride) {
+  do {
+    std::fill(dst, dst + columns, value);
+    dst += stride;
+  } while (--rows != 0);
+}
+
 #if defined(__GNUC__)
 
 inline int CountLeadingZeros(uint32_t n) {
@@ -59,25 +97,34 @@
   return __builtin_clzll(n);
 }
 
+inline int CountTrailingZeros(uint32_t n) {
+  assert(n != 0);
+  return __builtin_ctz(n);
+}
+
 #elif defined(_MSC_VER)
 
 inline int CountLeadingZeros(uint32_t n) {
+  assert(n != 0);
   unsigned long first_set_bit;  // NOLINT(runtime/int)
-  const unsigned char bit_set = _BitScanReverse(
-      &first_set_bit, static_cast<unsigned long>(n));  // NOLINT(runtime/int)
+  const unsigned char bit_set = _BitScanReverse(&first_set_bit, n);
   assert(bit_set != 0);
   static_cast<void>(bit_set);
   return 31 - static_cast<int>(first_set_bit);
 }
 
 inline int CountLeadingZeros(uint64_t n) {
+  assert(n != 0);
   unsigned long first_set_bit;  // NOLINT(runtime/int)
 #if defined(HAVE_BITSCANREVERSE64)
   const unsigned char bit_set =
       _BitScanReverse64(&first_set_bit, static_cast<unsigned __int64>(n));
-#else   // !defined(HAVE_BITSCANREVERSE64)
+#else  // !defined(HAVE_BITSCANREVERSE64)
   const auto n_hi = static_cast<unsigned long>(n >> 32);  // NOLINT(runtime/int)
-  if (n_hi != 0 && _BitScanReverse(&first_set_bit, n_hi)) {
+  if (n_hi != 0) {
+    const unsigned char bit_set = _BitScanReverse(&first_set_bit, n_hi);
+    assert(bit_set != 0);
+    static_cast<void>(bit_set);
     return 31 - static_cast<int>(first_set_bit);
   }
   const unsigned char bit_set = _BitScanReverse(
@@ -90,6 +137,15 @@
 
 #undef HAVE_BITSCANREVERSE64
 
+inline int CountTrailingZeros(uint32_t n) {
+  assert(n != 0);
+  unsigned long first_set_bit;  // NOLINT(runtime/int)
+  const unsigned char bit_set = _BitScanForward(&first_set_bit, n);
+  assert(bit_set != 0);
+  static_cast<void>(bit_set);
+  return static_cast<int>(first_set_bit);
+}
+
 #else  // !defined(__GNUC__) && !defined(_MSC_VER)
 
 template <const int kMSB, typename T>
@@ -108,6 +164,23 @@
 
 inline int CountLeadingZeros(uint64_t n) { return CountLeadingZeros<63>(n); }
 
+// This is the algorithm on the left in Figure 5-23, Hacker's Delight, Second
+// Edition, page 109. The book says:
+//   If the number of trailing 0's is expected to be small or large, then the
+//   simple loops shown in Figure 5-23 are quite fast.
+inline int CountTrailingZeros(uint32_t n) {
+  assert(n != 0);
+  // Create a word with 1's at the positions of the trailing 0's in |n|, and
+  // 0's elsewhere (e.g., 01011000 => 00000111).
+  n = ~n & (n - 1);
+  int count = 0;
+  while (n != 0) {
+    ++count;
+    n >>= 1;
+  }
+  return count;
+}
+
 #endif  // defined(__GNUC__)
 
 inline int FloorLog2(int32_t n) {
@@ -138,6 +211,10 @@
   return (n < 2) ? 0 : FloorLog2(n - 1) + 1;
 }
 
+constexpr int Ceil(int dividend, int divisor) {
+  return dividend / divisor + static_cast<int>(dividend % divisor != 0);
+}
+
 inline int32_t RightShiftWithRounding(int32_t value, int bits) {
   assert(bits >= 0);
   return (value + ((1 << bits) >> 1)) >> bits;
@@ -156,15 +233,21 @@
 }
 
 inline int32_t RightShiftWithRoundingSigned(int32_t value, int bits) {
-  return (value >= 0) ? RightShiftWithRounding(value, bits)
-                      : -RightShiftWithRounding(-value, bits);
+  assert(bits > 0);
+  // The next line is equivalent to:
+  // return (value >= 0) ? RightShiftWithRounding(value, bits)
+  //                     : -RightShiftWithRounding(-value, bits);
+  return RightShiftWithRounding(value + (value >> 31), bits);
 }
 
 // This variant is used when |value| can exceed 32 bits. Although the final
 // result must always fit into int32_t.
 inline int32_t RightShiftWithRoundingSigned(int64_t value, int bits) {
-  return (value >= 0) ? RightShiftWithRounding(value, bits)
-                      : -RightShiftWithRounding(-value, bits);
+  assert(bits > 0);
+  // The next line is equivalent to:
+  // return (value >= 0) ? RightShiftWithRounding(value, bits)
+  //                     : -RightShiftWithRounding(-value, bits);
+  return RightShiftWithRounding(value + (value >> 63), bits);
 }
 
 constexpr int DivideBy2(int n) { return n >> 1; }
@@ -232,55 +315,123 @@
 
 // 5.9.3.
 //
-// |a| and |b| are order hints, treated as unsigned |order_hint_bits|-bit
-// integers.
+// |a| and |b| are order hints, treated as unsigned order_hint_bits-bit
+// integers. |order_hint_shift_bits| equals (32 - order_hint_bits) % 32.
+// order_hint_bits is at most 8, so |order_hint_shift_bits| is zero or a
+// value between 24 and 31 (inclusive).
 //
-// If enable_order_hint is false, returns 0. If enable_order_hint is true,
-// returns the signed difference a - b using "modular arithmetic". More
-// precisely, the signed difference a - b is treated as a signed
-// |order_hint_bits|-bit integer and cast to an int. The returned difference is
-// between -(1 << (order_hint_bits - 1)) and (1 << (order_hint_bits - 1)) - 1
+// If |order_hint_shift_bits| is zero, |a| and |b| are both zeros, and the
+// result is zero. If |order_hint_shift_bits| is not zero, returns the
+// signed difference |a| - |b| using "modular arithmetic". More precisely, the
+// signed difference |a| - |b| is treated as a signed order_hint_bits-bit
+// integer and cast to an int. The returned difference is between
+// -(1 << (order_hint_bits - 1)) and (1 << (order_hint_bits - 1)) - 1
 // (inclusive).
 //
-// NOTE: |a| and |b| are the |order_hint_bits| least significant bits of the
+// NOTE: |a| and |b| are the order_hint_bits least significant bits of the
 // actual values. This function returns the signed difference between the
 // actual values. The returned difference is correct as long as the actual
 // values are not more than 1 << (order_hint_bits - 1) - 1 apart.
 //
-// Example: Suppose |order_hint_bits| is 4. Then |a| and |b| are in the range
-// [0, 15], and the actual values for |a| and |b| must not be more than 7
-// apart. (If the actual values for |a| and |b| are exactly 8 apart, this
-// function cannot tell whether the actual value for |a| is before or after the
-// actual value for |b|.)
+// Example: Suppose order_hint_bits is 4 and |order_hint_shift_bits|
+// is 28. Then |a| and |b| are in the range [0, 15], and the actual values for
+// |a| and |b| must not be more than 7 apart. (If the actual values for |a| and
+// |b| are exactly 8 apart, this function cannot tell whether the actual value
+// for |a| is before or after the actual value for |b|.)
 //
 // First, consider the order hints 2 and 6. For this simple case, we have
-//   GetRelativeDistance(2, 6, true, 4) = 2 - 6 = -4, and
-//   GetRelativeDistance(6, 2, true, 4) = 6 - 2 = 4.
+//   GetRelativeDistance(2, 6, 28) = 2 - 6 = -4, and
+//   GetRelativeDistance(6, 2, 28) = 6 - 2 = 4.
 //
 // On the other hand, consider the order hints 2 and 14. The order hints are
 // 12 (> 7) apart, so we need to use the actual values instead. The actual
 // values may be 34 (= 2 mod 16) and 30 (= 14 mod 16), respectively. Therefore
 // we have
-//   GetRelativeDistance(2, 14, true, 4) = 34 - 30 = 4, and
-//   GetRelativeDistance(14, 2, true, 4) = 30 - 34 = -4.
-inline int GetRelativeDistance(int a, int b, bool enable_order_hint,
-                               int order_hint_bits) {
-  if (!enable_order_hint) {
-    assert(order_hint_bits == 0);
-    return 0;
-  }
-  assert(order_hint_bits > 0);
-  assert(a >= 0 && a < (1 << order_hint_bits));
-  assert(b >= 0 && b < (1 << order_hint_bits));
+//   GetRelativeDistance(2, 14, 28) = 34 - 30 = 4, and
+//   GetRelativeDistance(14, 2, 28) = 30 - 34 = -4.
+//
+// The following comments apply only to specific CPUs' SIMD implementations,
+// such as intrinsics code.
+// For the 2 shift operations in this function, if the SIMD packed data is
+// 16-bit wide, try to use |order_hint_shift_bits| - 16 as the number of bits to
+// shift; If the SIMD packed data is 8-bit wide, try to use
+// |order_hint_shift_bits| - 24 as as the number of bits to shift.
+// |order_hint_shift_bits| - 16 and |order_hint_shift_bits| - 24 could be -16 or
+// -24. In these cases diff is 0, and the behavior of left or right shifting -16
+// or -24 bits is defined for x86 SIMD instructions and ARM NEON instructions,
+// and the result of shifting 0 is still 0. There is no guarantee that this
+// behavior and result apply to other CPUs' SIMD instructions.
+inline int GetRelativeDistance(const unsigned int a, const unsigned int b,
+                               const unsigned int order_hint_shift_bits) {
   const int diff = a - b;
-  const int m = 1 << (order_hint_bits - 1);
-  return (diff & (m - 1)) - (diff & m);
+  assert(order_hint_shift_bits <= 31);
+  if (order_hint_shift_bits == 0) {
+    assert(a == 0);
+    assert(b == 0);
+  } else {
+    assert(order_hint_shift_bits >= 24);  // i.e., order_hint_bits <= 8
+    assert(a < (1u << (32 - order_hint_shift_bits)));
+    assert(b < (1u << (32 - order_hint_shift_bits)));
+    assert(diff < (1 << (32 - order_hint_shift_bits)));
+    assert(diff >= -(1 << (32 - order_hint_shift_bits)));
+  }
+  // Sign extend the result of subtracting the values.
+  // Cast to unsigned int and then left shift to avoid undefined behavior with
+  // negative values. Cast to int to do the sign extension through right shift.
+  // This requires the right shift of a signed integer be an arithmetic shift,
+  // which is true for clang, gcc, and Visual C++.
+  // These two casts do not generate extra instructions.
+  // Don't use LeftShift(diff) since a valid diff may fail its assertions.
+  // For example, GetRelativeDistance(2, 14, 28), diff equals -12 and is less
+  // than the minimum allowed value of LeftShift() which is -8.
+  // The next 3 lines are equivalent to:
+  // const int order_hint_bits = Mod32(32 - order_hint_shift_bits);
+  // const int m = (1 << order_hint_bits) >> 1;
+  // return (diff & (m - 1)) - (diff & m);
+  return static_cast<int>(static_cast<unsigned int>(diff)
+                          << order_hint_shift_bits) >>
+         order_hint_shift_bits;
+}
+
+// Applies |sign| (must be 0 or -1) to |value|, i.e.,
+//   return (sign == 0) ? value : -value;
+// and does so without a branch.
+constexpr int ApplySign(int value, int sign) { return (value ^ sign) - sign; }
+
+// 7.9.3. (without the clamp for numerator and denominator).
+inline void GetMvProjection(const MotionVector& mv, int numerator,
+                            int division_multiplier,
+                            MotionVector* projection_mv) {
+  // Allow numerator and to be 0 so that this function can be called
+  // unconditionally. When numerator is 0, |projection_mv| will be 0, and this
+  // is what we want.
+  assert(std::abs(numerator) <= kMaxFrameDistance);
+  for (int i = 0; i < 2; ++i) {
+    projection_mv->mv[i] =
+        Clip3(RightShiftWithRoundingSigned(
+                  mv.mv[i] * numerator * division_multiplier, 14),
+              -kProjectionMvClamp, kProjectionMvClamp);
+  }
+}
+
+// 7.9.4.
+constexpr int Project(int value, int delta, int dst_sign) {
+  return value + ApplySign(delta / 64, dst_sign);
 }
 
 inline bool IsBlockSmallerThan8x8(BlockSize size) {
   return size < kBlock8x8 && size != kBlock4x16;
 }
 
+// Returns true if the either the width or the height of the block is equal to
+// four.
+inline bool IsBlockDimension4(BlockSize size) {
+  return size < kBlock8x8 || size == kBlock16x4;
+}
+
+// Converts bitdepth 8, 10, and 12 to array index 0, 1, and 2, respectively.
+constexpr int BitdepthToArrayIndex(int bitdepth) { return (bitdepth - 8) >> 1; }
+
 // Maps a square transform to an index between [0, 4]. kTransformSize4x4 maps
 // to 0, kTransformSize8x8 maps to 1 and so on.
 inline int TransformSizeToSquareTransformIndex(TransformSize tx_size) {
@@ -288,7 +439,9 @@
 
   // The values of the square transform sizes happen to be in the right
   // ranges, so we can just divide them by 4 to get the indexes.
-  static_assert(0 <= kTransformSize4x4 && kTransformSize4x4 < 4, "");
+  static_assert(
+      std::is_unsigned<std::underlying_type<TransformSize>::type>::value, "");
+  static_assert(kTransformSize4x4 < 4, "");
   static_assert(4 <= kTransformSize8x8 && kTransformSize8x8 < 8, "");
   static_assert(8 <= kTransformSize16x16 && kTransformSize16x16 < 12, "");
   static_assert(12 <= kTransformSize32x32 && kTransformSize32x32 < 16, "");
@@ -319,9 +472,14 @@
   //   Both x and y are not subsampled: 3 / 1 (which is equivalent to 6 / 2).
   // So we compute the final subsampling multiplier as follows:
   //   multiplier = (2 + (4 >> subsampling_x >> subsampling_y)) / 2.
+  // Add 32 * |kResidualPaddingVertical| padding to avoid bottom boundary checks
+  // when parsing quantized coefficients.
   const int subsampling_multiplier_num =
       2 + (4 >> subsampling_x >> subsampling_y);
-  return (residual_size * rows * columns * subsampling_multiplier_num) >> 1;
+  const int number_elements =
+      (rows * columns * subsampling_multiplier_num) >> 1;
+  const int tx_padding = 32 * kResidualPaddingVertical;
+  return residual_size * (number_elements + tx_padding);
 }
 
 // This function is equivalent to:
@@ -337,8 +495,23 @@
                        left_tx_size > kTransformSize8x32));
 }
 
+// This is used for 7.11.3.4 Block Inter Prediction Process, to select convolve
+// filters.
+inline int GetFilterIndex(const int filter_index, const int length) {
+  if (length <= 4) {
+    if (filter_index == kInterpolationFilterEightTap ||
+        filter_index == kInterpolationFilterEightTapSharp) {
+      return 4;
+    }
+    if (filter_index == kInterpolationFilterEightTapSmooth) {
+      return 5;
+    }
+  }
+  return filter_index;
+}
+
 constexpr int SubsampledValue(int value, int subsampling) {
-  return (subsampling == 0) ? value : DivideBy2(value + 1);
+  return (value + subsampling) >> subsampling;
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/utils/compiler_attributes.h b/libgav1/src/utils/compiler_attributes.h
index f2bc750..e122426 100644
--- a/libgav1/src/utils/compiler_attributes.h
+++ b/libgav1/src/utils/compiler_attributes.h
@@ -20,6 +20,19 @@
 // A collection of compiler attribute checks and defines to control for
 // compatibility across toolchains.
 
+//------------------------------------------------------------------------------
+// Language version, attribute and feature helpers.
+
+// Detect c++17 support. Visual Studio sets __cplusplus to 199711L by default
+// unless compiled with /Zc:__cplusplus, use the value controlled by /std
+// instead.
+// https://docs.microsoft.com/en-us/cpp/build/reference/zc-cplusplus
+#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+#define LIBGAV1_CXX17 1
+#else
+#define LIBGAV1_CXX17 0
+#endif
+
 #if defined(__has_attribute)
 #define LIBGAV1_HAS_ATTRIBUTE __has_attribute
 #else
@@ -35,6 +48,12 @@
 //------------------------------------------------------------------------------
 // Sanitizer attributes.
 
+#if LIBGAV1_HAS_FEATURE(address_sanitizer) || defined(__SANITIZE_ADDRESS__)
+#define LIBGAV1_ASAN 1
+#else
+#define LIBGAV1_ASAN 0
+#endif
+
 #if LIBGAV1_HAS_FEATURE(memory_sanitizer)
 #define LIBGAV1_MSAN 1
 #else
@@ -48,6 +67,20 @@
 #endif
 
 //------------------------------------------------------------------------------
+// AddressSanitizer support.
+
+// Define the macros for AddressSanitizer manual memory poisoning. See
+// https://github.com/google/sanitizers/wiki/AddressSanitizerManualPoisoning.
+#if LIBGAV1_ASAN
+#include <sanitizer/asan_interface.h>
+#else
+#define ASAN_POISON_MEMORY_REGION(addr, size) \
+  (static_cast<void>(addr), static_cast<void>(size))
+#define ASAN_UNPOISON_MEMORY_REGION(addr, size) \
+  (static_cast<void>(addr), static_cast<void>(size))
+#endif
+
+//------------------------------------------------------------------------------
 // Function attributes.
 // GCC: https://gcc.gnu.org/onlinedocs/gcc/Function-Attributes.html
 // Clang: https://clang.llvm.org/docs/AttributeReference.html
diff --git a/libgav1/src/utils/constants.cc b/libgav1/src/utils/constants.cc
index 98e2fbf..97959fa 100644
--- a/libgav1/src/utils/constants.cc
+++ b/libgav1/src/utils/constants.cc
@@ -125,6 +125,11 @@
     {{kBlock128x64, kBlockInvalid}, {kBlock64x64, kBlock64x32}},
     {{kBlock128x128, kBlock128x64}, {kBlock64x128, kBlock64x64}}};
 
+const int16_t kProjectionMvDivisionLookup[kMaxFrameDistance + 1] = {
+    0,    16384, 8192, 5461, 4096, 3276, 2730, 2340, 2048, 1820, 1638,
+    1489, 1365,  1260, 1170, 1092, 1024, 963,  910,  862,  819,  780,
+    744,  712,   682,  655,  630,  606,  585,  564,  546,  528};
+
 const uint8_t kTransformWidth[kNumTransformSizes] = {
     4, 4, 4, 8, 8, 8, 8, 16, 16, 16, 16, 16, 32, 32, 32, 32, 64, 64, 64};
 
@@ -189,336 +194,644 @@
 
 const int8_t kWienerTapsMax[3] = {10, 8, 46};
 
-const int16_t kUpscaleFilter[kSuperResFilterShifts][kSuperResFilterTaps] = {
-    {0, 0, 0, 128, 0, 0, 0, 0},        {0, 0, -1, 128, 2, -1, 0, 0},
-    {0, 1, -3, 127, 4, -2, 1, 0},      {0, 1, -4, 127, 6, -3, 1, 0},
-    {0, 2, -6, 126, 8, -3, 1, 0},      {0, 2, -7, 125, 11, -4, 1, 0},
-    {-1, 2, -8, 125, 13, -5, 2, 0},    {-1, 3, -9, 124, 15, -6, 2, 0},
-    {-1, 3, -10, 123, 18, -6, 2, -1},  {-1, 3, -11, 122, 20, -7, 3, -1},
-    {-1, 4, -12, 121, 22, -8, 3, -1},  {-1, 4, -13, 120, 25, -9, 3, -1},
-    {-1, 4, -14, 118, 28, -9, 3, -1},  {-1, 4, -15, 117, 30, -10, 4, -1},
-    {-1, 5, -16, 116, 32, -11, 4, -1}, {-1, 5, -16, 114, 35, -12, 4, -1},
-    {-1, 5, -17, 112, 38, -12, 4, -1}, {-1, 5, -18, 111, 40, -13, 5, -1},
-    {-1, 5, -18, 109, 43, -14, 5, -1}, {-1, 6, -19, 107, 45, -14, 5, -1},
-    {-1, 6, -19, 105, 48, -15, 5, -1}, {-1, 6, -19, 103, 51, -16, 5, -1},
-    {-1, 6, -20, 101, 53, -16, 6, -1}, {-1, 6, -20, 99, 56, -17, 6, -1},
-    {-1, 6, -20, 97, 58, -17, 6, -1},  {-1, 6, -20, 95, 61, -18, 6, -1},
-    {-2, 7, -20, 93, 64, -18, 6, -2},  {-2, 7, -20, 91, 66, -19, 6, -1},
-    {-2, 7, -20, 88, 69, -19, 6, -1},  {-2, 7, -20, 86, 71, -19, 6, -1},
-    {-2, 7, -20, 84, 74, -20, 7, -2},  {-2, 7, -20, 81, 76, -20, 7, -1},
-    {-2, 7, -20, 79, 79, -20, 7, -2},  {-1, 7, -20, 76, 81, -20, 7, -2},
-    {-2, 7, -20, 74, 84, -20, 7, -2},  {-1, 6, -19, 71, 86, -20, 7, -2},
-    {-1, 6, -19, 69, 88, -20, 7, -2},  {-1, 6, -19, 66, 91, -20, 7, -2},
-    {-2, 6, -18, 64, 93, -20, 7, -2},  {-1, 6, -18, 61, 95, -20, 6, -1},
-    {-1, 6, -17, 58, 97, -20, 6, -1},  {-1, 6, -17, 56, 99, -20, 6, -1},
-    {-1, 6, -16, 53, 101, -20, 6, -1}, {-1, 5, -16, 51, 103, -19, 6, -1},
-    {-1, 5, -15, 48, 105, -19, 6, -1}, {-1, 5, -14, 45, 107, -19, 6, -1},
-    {-1, 5, -14, 43, 109, -18, 5, -1}, {-1, 5, -13, 40, 111, -18, 5, -1},
-    {-1, 4, -12, 38, 112, -17, 5, -1}, {-1, 4, -12, 35, 114, -16, 5, -1},
-    {-1, 4, -11, 32, 116, -16, 5, -1}, {-1, 4, -10, 30, 117, -15, 4, -1},
-    {-1, 3, -9, 28, 118, -14, 4, -1},  {-1, 3, -9, 25, 120, -13, 4, -1},
-    {-1, 3, -8, 22, 121, -12, 4, -1},  {-1, 3, -7, 20, 122, -11, 3, -1},
-    {-1, 2, -6, 18, 123, -10, 3, -1},  {0, 2, -6, 15, 124, -9, 3, -1},
-    {0, 2, -5, 13, 125, -8, 2, -1},    {0, 1, -4, 11, 125, -7, 2, 0},
-    {0, 1, -3, 8, 126, -6, 2, 0},      {0, 1, -3, 6, 127, -4, 1, 0},
-    {0, 1, -2, 4, 127, -3, 1, 0},      {0, 0, -1, 2, 128, -1, 0, 0},
+// This was modified from Upscale_Filter as defined in AV1 Section 7.16, in
+// order to support 16-bit packed NEON operations.
+// The sign of each tap is: - + - + + - + -
+alignas(16) const uint8_t
+    kUpscaleFilterUnsigned[kSuperResFilterShifts][kSuperResFilterTaps] = {
+        {0, 0, 0, 128, 0, 0, 0, 0},    {0, 0, 1, 128, 2, 1, 0, 0},
+        {0, 1, 3, 127, 4, 2, 1, 0},    {0, 1, 4, 127, 6, 3, 1, 0},
+        {0, 2, 6, 126, 8, 3, 1, 0},    {0, 2, 7, 125, 11, 4, 1, 0},
+        {1, 2, 8, 125, 13, 5, 2, 0},   {1, 3, 9, 124, 15, 6, 2, 0},
+        {1, 3, 10, 123, 18, 6, 2, 1},  {1, 3, 11, 122, 20, 7, 3, 1},
+        {1, 4, 12, 121, 22, 8, 3, 1},  {1, 4, 13, 120, 25, 9, 3, 1},
+        {1, 4, 14, 118, 28, 9, 3, 1},  {1, 4, 15, 117, 30, 10, 4, 1},
+        {1, 5, 16, 116, 32, 11, 4, 1}, {1, 5, 16, 114, 35, 12, 4, 1},
+        {1, 5, 17, 112, 38, 12, 4, 1}, {1, 5, 18, 111, 40, 13, 5, 1},
+        {1, 5, 18, 109, 43, 14, 5, 1}, {1, 6, 19, 107, 45, 14, 5, 1},
+        {1, 6, 19, 105, 48, 15, 5, 1}, {1, 6, 19, 103, 51, 16, 5, 1},
+        {1, 6, 20, 101, 53, 16, 6, 1}, {1, 6, 20, 99, 56, 17, 6, 1},
+        {1, 6, 20, 97, 58, 17, 6, 1},  {1, 6, 20, 95, 61, 18, 6, 1},
+        {2, 7, 20, 93, 64, 18, 6, 2},  {2, 7, 20, 91, 66, 19, 6, 1},
+        {2, 7, 20, 88, 69, 19, 6, 1},  {2, 7, 20, 86, 71, 19, 6, 1},
+        {2, 7, 20, 84, 74, 20, 7, 2},  {2, 7, 20, 81, 76, 20, 7, 1},
+        {2, 7, 20, 79, 79, 20, 7, 2},  {1, 7, 20, 76, 81, 20, 7, 2},
+        {2, 7, 20, 74, 84, 20, 7, 2},  {1, 6, 19, 71, 86, 20, 7, 2},
+        {1, 6, 19, 69, 88, 20, 7, 2},  {1, 6, 19, 66, 91, 20, 7, 2},
+        {2, 6, 18, 64, 93, 20, 7, 2},  {1, 6, 18, 61, 95, 20, 6, 1},
+        {1, 6, 17, 58, 97, 20, 6, 1},  {1, 6, 17, 56, 99, 20, 6, 1},
+        {1, 6, 16, 53, 101, 20, 6, 1}, {1, 5, 16, 51, 103, 19, 6, 1},
+        {1, 5, 15, 48, 105, 19, 6, 1}, {1, 5, 14, 45, 107, 19, 6, 1},
+        {1, 5, 14, 43, 109, 18, 5, 1}, {1, 5, 13, 40, 111, 18, 5, 1},
+        {1, 4, 12, 38, 112, 17, 5, 1}, {1, 4, 12, 35, 114, 16, 5, 1},
+        {1, 4, 11, 32, 116, 16, 5, 1}, {1, 4, 10, 30, 117, 15, 4, 1},
+        {1, 3, 9, 28, 118, 14, 4, 1},  {1, 3, 9, 25, 120, 13, 4, 1},
+        {1, 3, 8, 22, 121, 12, 4, 1},  {1, 3, 7, 20, 122, 11, 3, 1},
+        {1, 2, 6, 18, 123, 10, 3, 1},  {0, 2, 6, 15, 124, 9, 3, 1},
+        {0, 2, 5, 13, 125, 8, 2, 1},   {0, 1, 4, 11, 125, 7, 2, 0},
+        {0, 1, 3, 8, 126, 6, 2, 0},    {0, 1, 3, 6, 127, 4, 1, 0},
+        {0, 1, 2, 4, 127, 3, 1, 0},    {0, 0, 1, 2, 128, 1, 0, 0},
 };
 
-const int16_t kWarpedFilters[3 * kWarpedPixelPrecisionShifts + 1][8] = {
-    // [-1, 0).
-    {0, 0, 127, 1, 0, 0, 0, 0},
-    {0, -1, 127, 2, 0, 0, 0, 0},
-    {1, -3, 127, 4, -1, 0, 0, 0},
-    {1, -4, 126, 6, -2, 1, 0, 0},
-    {1, -5, 126, 8, -3, 1, 0, 0},
-    {1, -6, 125, 11, -4, 1, 0, 0},
-    {1, -7, 124, 13, -4, 1, 0, 0},
-    {2, -8, 123, 15, -5, 1, 0, 0},
-    {2, -9, 122, 18, -6, 1, 0, 0},
-    {2, -10, 121, 20, -6, 1, 0, 0},
-    {2, -11, 120, 22, -7, 2, 0, 0},
-    {2, -12, 119, 25, -8, 2, 0, 0},
-    {3, -13, 117, 27, -8, 2, 0, 0},
-    {3, -13, 116, 29, -9, 2, 0, 0},
-    {3, -14, 114, 32, -10, 3, 0, 0},
-    {3, -15, 113, 35, -10, 2, 0, 0},
-    {3, -15, 111, 37, -11, 3, 0, 0},
-    {3, -16, 109, 40, -11, 3, 0, 0},
-    {3, -16, 108, 42, -12, 3, 0, 0},
-    {4, -17, 106, 45, -13, 3, 0, 0},
-    {4, -17, 104, 47, -13, 3, 0, 0},
-    {4, -17, 102, 50, -14, 3, 0, 0},
-    {4, -17, 100, 52, -14, 3, 0, 0},
-    {4, -18, 98, 55, -15, 4, 0, 0},
-    {4, -18, 96, 58, -15, 3, 0, 0},
-    {4, -18, 94, 60, -16, 4, 0, 0},
-    {4, -18, 91, 63, -16, 4, 0, 0},
-    {4, -18, 89, 65, -16, 4, 0, 0},
-    {4, -18, 87, 68, -17, 4, 0, 0},
-    {4, -18, 85, 70, -17, 4, 0, 0},
-    {4, -18, 82, 73, -17, 4, 0, 0},
-    {4, -18, 80, 75, -17, 4, 0, 0},
-    {4, -18, 78, 78, -18, 4, 0, 0},
-    {4, -17, 75, 80, -18, 4, 0, 0},
-    {4, -17, 73, 82, -18, 4, 0, 0},
-    {4, -17, 70, 85, -18, 4, 0, 0},
-    {4, -17, 68, 87, -18, 4, 0, 0},
-    {4, -16, 65, 89, -18, 4, 0, 0},
-    {4, -16, 63, 91, -18, 4, 0, 0},
-    {4, -16, 60, 94, -18, 4, 0, 0},
-    {3, -15, 58, 96, -18, 4, 0, 0},
-    {4, -15, 55, 98, -18, 4, 0, 0},
-    {3, -14, 52, 100, -17, 4, 0, 0},
-    {3, -14, 50, 102, -17, 4, 0, 0},
-    {3, -13, 47, 104, -17, 4, 0, 0},
-    {3, -13, 45, 106, -17, 4, 0, 0},
-    {3, -12, 42, 108, -16, 3, 0, 0},
-    {3, -11, 40, 109, -16, 3, 0, 0},
-    {3, -11, 37, 111, -15, 3, 0, 0},
-    {2, -10, 35, 113, -15, 3, 0, 0},
-    {3, -10, 32, 114, -14, 3, 0, 0},
-    {2, -9, 29, 116, -13, 3, 0, 0},
-    {2, -8, 27, 117, -13, 3, 0, 0},
-    {2, -8, 25, 119, -12, 2, 0, 0},
-    {2, -7, 22, 120, -11, 2, 0, 0},
-    {1, -6, 20, 121, -10, 2, 0, 0},
-    {1, -6, 18, 122, -9, 2, 0, 0},
-    {1, -5, 15, 123, -8, 2, 0, 0},
-    {1, -4, 13, 124, -7, 1, 0, 0},
-    {1, -4, 11, 125, -6, 1, 0, 0},
-    {1, -3, 8, 126, -5, 1, 0, 0},
-    {1, -2, 6, 126, -4, 1, 0, 0},
-    {0, -1, 4, 127, -3, 1, 0, 0},
-    {0, 0, 2, 127, -1, 0, 0, 0},
-    // [0, 1).
-    {0, 0, 0, 127, 1, 0, 0, 0},
-    {0, 0, -1, 127, 2, 0, 0, 0},
-    {0, 1, -3, 127, 4, -2, 1, 0},
-    {0, 1, -5, 127, 6, -2, 1, 0},
-    {0, 2, -6, 126, 8, -3, 1, 0},
-    {-1, 2, -7, 126, 11, -4, 2, -1},
-    {-1, 3, -8, 125, 13, -5, 2, -1},
-    {-1, 3, -10, 124, 16, -6, 3, -1},
-    {-1, 4, -11, 123, 18, -7, 3, -1},
-    {-1, 4, -12, 122, 20, -7, 3, -1},
-    {-1, 4, -13, 121, 23, -8, 3, -1},
-    {-2, 5, -14, 120, 25, -9, 4, -1},
-    {-1, 5, -15, 119, 27, -10, 4, -1},
-    {-1, 5, -16, 118, 30, -11, 4, -1},
-    {-2, 6, -17, 116, 33, -12, 5, -1},
-    {-2, 6, -17, 114, 35, -12, 5, -1},
-    {-2, 6, -18, 113, 38, -13, 5, -1},
-    {-2, 7, -19, 111, 41, -14, 6, -2},
-    {-2, 7, -19, 110, 43, -15, 6, -2},
-    {-2, 7, -20, 108, 46, -15, 6, -2},
-    {-2, 7, -20, 106, 49, -16, 6, -2},
-    {-2, 7, -21, 104, 51, -16, 7, -2},
-    {-2, 7, -21, 102, 54, -17, 7, -2},
-    {-2, 8, -21, 100, 56, -18, 7, -2},
-    {-2, 8, -22, 98, 59, -18, 7, -2},
-    {-2, 8, -22, 96, 62, -19, 7, -2},
-    {-2, 8, -22, 94, 64, -19, 7, -2},
-    {-2, 8, -22, 91, 67, -20, 8, -2},
-    {-2, 8, -22, 89, 69, -20, 8, -2},
-    {-2, 8, -22, 87, 72, -21, 8, -2},
-    {-2, 8, -21, 84, 74, -21, 8, -2},
-    {-2, 8, -22, 82, 77, -21, 8, -2},
-    {-2, 8, -21, 79, 79, -21, 8, -2},
-    {-2, 8, -21, 77, 82, -22, 8, -2},
-    {-2, 8, -21, 74, 84, -21, 8, -2},
-    {-2, 8, -21, 72, 87, -22, 8, -2},
-    {-2, 8, -20, 69, 89, -22, 8, -2},
-    {-2, 8, -20, 67, 91, -22, 8, -2},
-    {-2, 7, -19, 64, 94, -22, 8, -2},
-    {-2, 7, -19, 62, 96, -22, 8, -2},
-    {-2, 7, -18, 59, 98, -22, 8, -2},
-    {-2, 7, -18, 56, 100, -21, 8, -2},
-    {-2, 7, -17, 54, 102, -21, 7, -2},
-    {-2, 7, -16, 51, 104, -21, 7, -2},
-    {-2, 6, -16, 49, 106, -20, 7, -2},
-    {-2, 6, -15, 46, 108, -20, 7, -2},
-    {-2, 6, -15, 43, 110, -19, 7, -2},
-    {-2, 6, -14, 41, 111, -19, 7, -2},
-    {-1, 5, -13, 38, 113, -18, 6, -2},
-    {-1, 5, -12, 35, 114, -17, 6, -2},
-    {-1, 5, -12, 33, 116, -17, 6, -2},
-    {-1, 4, -11, 30, 118, -16, 5, -1},
-    {-1, 4, -10, 27, 119, -15, 5, -1},
-    {-1, 4, -9, 25, 120, -14, 5, -2},
-    {-1, 3, -8, 23, 121, -13, 4, -1},
-    {-1, 3, -7, 20, 122, -12, 4, -1},
-    {-1, 3, -7, 18, 123, -11, 4, -1},
-    {-1, 3, -6, 16, 124, -10, 3, -1},
-    {-1, 2, -5, 13, 125, -8, 3, -1},
-    {-1, 2, -4, 11, 126, -7, 2, -1},
-    {0, 1, -3, 8, 126, -6, 2, 0},
-    {0, 1, -2, 6, 127, -5, 1, 0},
-    {0, 1, -2, 4, 127, -3, 1, 0},
-    {0, 0, 0, 2, 127, -1, 0, 0},
-    // [1, 2).
-    {0, 0, 0, 1, 127, 0, 0, 0},
-    {0, 0, 0, -1, 127, 2, 0, 0},
-    {0, 0, 1, -3, 127, 4, -1, 0},
-    {0, 0, 1, -4, 126, 6, -2, 1},
-    {0, 0, 1, -5, 126, 8, -3, 1},
-    {0, 0, 1, -6, 125, 11, -4, 1},
-    {0, 0, 1, -7, 124, 13, -4, 1},
-    {0, 0, 2, -8, 123, 15, -5, 1},
-    {0, 0, 2, -9, 122, 18, -6, 1},
-    {0, 0, 2, -10, 121, 20, -6, 1},
-    {0, 0, 2, -11, 120, 22, -7, 2},
-    {0, 0, 2, -12, 119, 25, -8, 2},
-    {0, 0, 3, -13, 117, 27, -8, 2},
-    {0, 0, 3, -13, 116, 29, -9, 2},
-    {0, 0, 3, -14, 114, 32, -10, 3},
-    {0, 0, 3, -15, 113, 35, -10, 2},
-    {0, 0, 3, -15, 111, 37, -11, 3},
-    {0, 0, 3, -16, 109, 40, -11, 3},
-    {0, 0, 3, -16, 108, 42, -12, 3},
-    {0, 0, 4, -17, 106, 45, -13, 3},
-    {0, 0, 4, -17, 104, 47, -13, 3},
-    {0, 0, 4, -17, 102, 50, -14, 3},
-    {0, 0, 4, -17, 100, 52, -14, 3},
-    {0, 0, 4, -18, 98, 55, -15, 4},
-    {0, 0, 4, -18, 96, 58, -15, 3},
-    {0, 0, 4, -18, 94, 60, -16, 4},
-    {0, 0, 4, -18, 91, 63, -16, 4},
-    {0, 0, 4, -18, 89, 65, -16, 4},
-    {0, 0, 4, -18, 87, 68, -17, 4},
-    {0, 0, 4, -18, 85, 70, -17, 4},
-    {0, 0, 4, -18, 82, 73, -17, 4},
-    {0, 0, 4, -18, 80, 75, -17, 4},
-    {0, 0, 4, -18, 78, 78, -18, 4},
-    {0, 0, 4, -17, 75, 80, -18, 4},
-    {0, 0, 4, -17, 73, 82, -18, 4},
-    {0, 0, 4, -17, 70, 85, -18, 4},
-    {0, 0, 4, -17, 68, 87, -18, 4},
-    {0, 0, 4, -16, 65, 89, -18, 4},
-    {0, 0, 4, -16, 63, 91, -18, 4},
-    {0, 0, 4, -16, 60, 94, -18, 4},
-    {0, 0, 3, -15, 58, 96, -18, 4},
-    {0, 0, 4, -15, 55, 98, -18, 4},
-    {0, 0, 3, -14, 52, 100, -17, 4},
-    {0, 0, 3, -14, 50, 102, -17, 4},
-    {0, 0, 3, -13, 47, 104, -17, 4},
-    {0, 0, 3, -13, 45, 106, -17, 4},
-    {0, 0, 3, -12, 42, 108, -16, 3},
-    {0, 0, 3, -11, 40, 109, -16, 3},
-    {0, 0, 3, -11, 37, 111, -15, 3},
-    {0, 0, 2, -10, 35, 113, -15, 3},
-    {0, 0, 3, -10, 32, 114, -14, 3},
-    {0, 0, 2, -9, 29, 116, -13, 3},
-    {0, 0, 2, -8, 27, 117, -13, 3},
-    {0, 0, 2, -8, 25, 119, -12, 2},
-    {0, 0, 2, -7, 22, 120, -11, 2},
-    {0, 0, 1, -6, 20, 121, -10, 2},
-    {0, 0, 1, -6, 18, 122, -9, 2},
-    {0, 0, 1, -5, 15, 123, -8, 2},
-    {0, 0, 1, -4, 13, 124, -7, 1},
-    {0, 0, 1, -4, 11, 125, -6, 1},
-    {0, 0, 1, -3, 8, 126, -5, 1},
-    {0, 0, 1, -2, 6, 126, -4, 1},
-    {0, 0, 0, -1, 4, 127, -3, 1},
-    {0, 0, 0, 0, 2, 127, -1, 0},
-    // dummy, replicate row index 191.
-    {0, 0, 0, 0, 2, 127, -1, 0}};
+alignas(8) const int8_t
+    kWarpedFilters8[3 * kWarpedPixelPrecisionShifts + 1][8] = {
+        // [-1, 0).
+        {0, 0, 127, 1, 0, 0, 0, 0},
+        {0, -1, 127, 2, 0, 0, 0, 0},
+        {1, -3, 127, 4, -1, 0, 0, 0},
+        {1, -4, 126, 6, -2, 1, 0, 0},
+        {1, -5, 126, 8, -3, 1, 0, 0},
+        {1, -6, 125, 11, -4, 1, 0, 0},
+        {1, -7, 124, 13, -4, 1, 0, 0},
+        {2, -8, 123, 15, -5, 1, 0, 0},
+        {2, -9, 122, 18, -6, 1, 0, 0},
+        {2, -10, 121, 20, -6, 1, 0, 0},
+        {2, -11, 120, 22, -7, 2, 0, 0},
+        {2, -12, 119, 25, -8, 2, 0, 0},
+        {3, -13, 117, 27, -8, 2, 0, 0},
+        {3, -13, 116, 29, -9, 2, 0, 0},
+        {3, -14, 114, 32, -10, 3, 0, 0},
+        {3, -15, 113, 35, -10, 2, 0, 0},
+        {3, -15, 111, 37, -11, 3, 0, 0},
+        {3, -16, 109, 40, -11, 3, 0, 0},
+        {3, -16, 108, 42, -12, 3, 0, 0},
+        {4, -17, 106, 45, -13, 3, 0, 0},
+        {4, -17, 104, 47, -13, 3, 0, 0},
+        {4, -17, 102, 50, -14, 3, 0, 0},
+        {4, -17, 100, 52, -14, 3, 0, 0},
+        {4, -18, 98, 55, -15, 4, 0, 0},
+        {4, -18, 96, 58, -15, 3, 0, 0},
+        {4, -18, 94, 60, -16, 4, 0, 0},
+        {4, -18, 91, 63, -16, 4, 0, 0},
+        {4, -18, 89, 65, -16, 4, 0, 0},
+        {4, -18, 87, 68, -17, 4, 0, 0},
+        {4, -18, 85, 70, -17, 4, 0, 0},
+        {4, -18, 82, 73, -17, 4, 0, 0},
+        {4, -18, 80, 75, -17, 4, 0, 0},
+        {4, -18, 78, 78, -18, 4, 0, 0},
+        {4, -17, 75, 80, -18, 4, 0, 0},
+        {4, -17, 73, 82, -18, 4, 0, 0},
+        {4, -17, 70, 85, -18, 4, 0, 0},
+        {4, -17, 68, 87, -18, 4, 0, 0},
+        {4, -16, 65, 89, -18, 4, 0, 0},
+        {4, -16, 63, 91, -18, 4, 0, 0},
+        {4, -16, 60, 94, -18, 4, 0, 0},
+        {3, -15, 58, 96, -18, 4, 0, 0},
+        {4, -15, 55, 98, -18, 4, 0, 0},
+        {3, -14, 52, 100, -17, 4, 0, 0},
+        {3, -14, 50, 102, -17, 4, 0, 0},
+        {3, -13, 47, 104, -17, 4, 0, 0},
+        {3, -13, 45, 106, -17, 4, 0, 0},
+        {3, -12, 42, 108, -16, 3, 0, 0},
+        {3, -11, 40, 109, -16, 3, 0, 0},
+        {3, -11, 37, 111, -15, 3, 0, 0},
+        {2, -10, 35, 113, -15, 3, 0, 0},
+        {3, -10, 32, 114, -14, 3, 0, 0},
+        {2, -9, 29, 116, -13, 3, 0, 0},
+        {2, -8, 27, 117, -13, 3, 0, 0},
+        {2, -8, 25, 119, -12, 2, 0, 0},
+        {2, -7, 22, 120, -11, 2, 0, 0},
+        {1, -6, 20, 121, -10, 2, 0, 0},
+        {1, -6, 18, 122, -9, 2, 0, 0},
+        {1, -5, 15, 123, -8, 2, 0, 0},
+        {1, -4, 13, 124, -7, 1, 0, 0},
+        {1, -4, 11, 125, -6, 1, 0, 0},
+        {1, -3, 8, 126, -5, 1, 0, 0},
+        {1, -2, 6, 126, -4, 1, 0, 0},
+        {0, -1, 4, 127, -3, 1, 0, 0},
+        {0, 0, 2, 127, -1, 0, 0, 0},
+        // [0, 1).
+        {0, 0, 0, 127, 1, 0, 0, 0},
+        {0, 0, -1, 127, 2, 0, 0, 0},
+        {0, 1, -3, 127, 4, -2, 1, 0},
+        {0, 1, -5, 127, 6, -2, 1, 0},
+        {0, 2, -6, 126, 8, -3, 1, 0},
+        {-1, 2, -7, 126, 11, -4, 2, -1},
+        {-1, 3, -8, 125, 13, -5, 2, -1},
+        {-1, 3, -10, 124, 16, -6, 3, -1},
+        {-1, 4, -11, 123, 18, -7, 3, -1},
+        {-1, 4, -12, 122, 20, -7, 3, -1},
+        {-1, 4, -13, 121, 23, -8, 3, -1},
+        {-2, 5, -14, 120, 25, -9, 4, -1},
+        {-1, 5, -15, 119, 27, -10, 4, -1},
+        {-1, 5, -16, 118, 30, -11, 4, -1},
+        {-2, 6, -17, 116, 33, -12, 5, -1},
+        {-2, 6, -17, 114, 35, -12, 5, -1},
+        {-2, 6, -18, 113, 38, -13, 5, -1},
+        {-2, 7, -19, 111, 41, -14, 6, -2},
+        {-2, 7, -19, 110, 43, -15, 6, -2},
+        {-2, 7, -20, 108, 46, -15, 6, -2},
+        {-2, 7, -20, 106, 49, -16, 6, -2},
+        {-2, 7, -21, 104, 51, -16, 7, -2},
+        {-2, 7, -21, 102, 54, -17, 7, -2},
+        {-2, 8, -21, 100, 56, -18, 7, -2},
+        {-2, 8, -22, 98, 59, -18, 7, -2},
+        {-2, 8, -22, 96, 62, -19, 7, -2},
+        {-2, 8, -22, 94, 64, -19, 7, -2},
+        {-2, 8, -22, 91, 67, -20, 8, -2},
+        {-2, 8, -22, 89, 69, -20, 8, -2},
+        {-2, 8, -22, 87, 72, -21, 8, -2},
+        {-2, 8, -21, 84, 74, -21, 8, -2},
+        {-2, 8, -22, 82, 77, -21, 8, -2},
+        {-2, 8, -21, 79, 79, -21, 8, -2},
+        {-2, 8, -21, 77, 82, -22, 8, -2},
+        {-2, 8, -21, 74, 84, -21, 8, -2},
+        {-2, 8, -21, 72, 87, -22, 8, -2},
+        {-2, 8, -20, 69, 89, -22, 8, -2},
+        {-2, 8, -20, 67, 91, -22, 8, -2},
+        {-2, 7, -19, 64, 94, -22, 8, -2},
+        {-2, 7, -19, 62, 96, -22, 8, -2},
+        {-2, 7, -18, 59, 98, -22, 8, -2},
+        {-2, 7, -18, 56, 100, -21, 8, -2},
+        {-2, 7, -17, 54, 102, -21, 7, -2},
+        {-2, 7, -16, 51, 104, -21, 7, -2},
+        {-2, 6, -16, 49, 106, -20, 7, -2},
+        {-2, 6, -15, 46, 108, -20, 7, -2},
+        {-2, 6, -15, 43, 110, -19, 7, -2},
+        {-2, 6, -14, 41, 111, -19, 7, -2},
+        {-1, 5, -13, 38, 113, -18, 6, -2},
+        {-1, 5, -12, 35, 114, -17, 6, -2},
+        {-1, 5, -12, 33, 116, -17, 6, -2},
+        {-1, 4, -11, 30, 118, -16, 5, -1},
+        {-1, 4, -10, 27, 119, -15, 5, -1},
+        {-1, 4, -9, 25, 120, -14, 5, -2},
+        {-1, 3, -8, 23, 121, -13, 4, -1},
+        {-1, 3, -7, 20, 122, -12, 4, -1},
+        {-1, 3, -7, 18, 123, -11, 4, -1},
+        {-1, 3, -6, 16, 124, -10, 3, -1},
+        {-1, 2, -5, 13, 125, -8, 3, -1},
+        {-1, 2, -4, 11, 126, -7, 2, -1},
+        {0, 1, -3, 8, 126, -6, 2, 0},
+        {0, 1, -2, 6, 127, -5, 1, 0},
+        {0, 1, -2, 4, 127, -3, 1, 0},
+        {0, 0, 0, 2, 127, -1, 0, 0},
+        // [1, 2).
+        {0, 0, 0, 1, 127, 0, 0, 0},
+        {0, 0, 0, -1, 127, 2, 0, 0},
+        {0, 0, 1, -3, 127, 4, -1, 0},
+        {0, 0, 1, -4, 126, 6, -2, 1},
+        {0, 0, 1, -5, 126, 8, -3, 1},
+        {0, 0, 1, -6, 125, 11, -4, 1},
+        {0, 0, 1, -7, 124, 13, -4, 1},
+        {0, 0, 2, -8, 123, 15, -5, 1},
+        {0, 0, 2, -9, 122, 18, -6, 1},
+        {0, 0, 2, -10, 121, 20, -6, 1},
+        {0, 0, 2, -11, 120, 22, -7, 2},
+        {0, 0, 2, -12, 119, 25, -8, 2},
+        {0, 0, 3, -13, 117, 27, -8, 2},
+        {0, 0, 3, -13, 116, 29, -9, 2},
+        {0, 0, 3, -14, 114, 32, -10, 3},
+        {0, 0, 3, -15, 113, 35, -10, 2},
+        {0, 0, 3, -15, 111, 37, -11, 3},
+        {0, 0, 3, -16, 109, 40, -11, 3},
+        {0, 0, 3, -16, 108, 42, -12, 3},
+        {0, 0, 4, -17, 106, 45, -13, 3},
+        {0, 0, 4, -17, 104, 47, -13, 3},
+        {0, 0, 4, -17, 102, 50, -14, 3},
+        {0, 0, 4, -17, 100, 52, -14, 3},
+        {0, 0, 4, -18, 98, 55, -15, 4},
+        {0, 0, 4, -18, 96, 58, -15, 3},
+        {0, 0, 4, -18, 94, 60, -16, 4},
+        {0, 0, 4, -18, 91, 63, -16, 4},
+        {0, 0, 4, -18, 89, 65, -16, 4},
+        {0, 0, 4, -18, 87, 68, -17, 4},
+        {0, 0, 4, -18, 85, 70, -17, 4},
+        {0, 0, 4, -18, 82, 73, -17, 4},
+        {0, 0, 4, -18, 80, 75, -17, 4},
+        {0, 0, 4, -18, 78, 78, -18, 4},
+        {0, 0, 4, -17, 75, 80, -18, 4},
+        {0, 0, 4, -17, 73, 82, -18, 4},
+        {0, 0, 4, -17, 70, 85, -18, 4},
+        {0, 0, 4, -17, 68, 87, -18, 4},
+        {0, 0, 4, -16, 65, 89, -18, 4},
+        {0, 0, 4, -16, 63, 91, -18, 4},
+        {0, 0, 4, -16, 60, 94, -18, 4},
+        {0, 0, 3, -15, 58, 96, -18, 4},
+        {0, 0, 4, -15, 55, 98, -18, 4},
+        {0, 0, 3, -14, 52, 100, -17, 4},
+        {0, 0, 3, -14, 50, 102, -17, 4},
+        {0, 0, 3, -13, 47, 104, -17, 4},
+        {0, 0, 3, -13, 45, 106, -17, 4},
+        {0, 0, 3, -12, 42, 108, -16, 3},
+        {0, 0, 3, -11, 40, 109, -16, 3},
+        {0, 0, 3, -11, 37, 111, -15, 3},
+        {0, 0, 2, -10, 35, 113, -15, 3},
+        {0, 0, 3, -10, 32, 114, -14, 3},
+        {0, 0, 2, -9, 29, 116, -13, 3},
+        {0, 0, 2, -8, 27, 117, -13, 3},
+        {0, 0, 2, -8, 25, 119, -12, 2},
+        {0, 0, 2, -7, 22, 120, -11, 2},
+        {0, 0, 1, -6, 20, 121, -10, 2},
+        {0, 0, 1, -6, 18, 122, -9, 2},
+        {0, 0, 1, -5, 15, 123, -8, 2},
+        {0, 0, 1, -4, 13, 124, -7, 1},
+        {0, 0, 1, -4, 11, 125, -6, 1},
+        {0, 0, 1, -3, 8, 126, -5, 1},
+        {0, 0, 1, -2, 6, 126, -4, 1},
+        {0, 0, 0, -1, 4, 127, -3, 1},
+        {0, 0, 0, 0, 2, 127, -1, 0},
+        // dummy, replicate row index 191.
+        {0, 0, 0, 0, 2, 127, -1, 0}};
 
-const int16_t kSubPixelFilters[6][16][8] = {{{0, 0, 0, 128, 0, 0, 0, 0},
-                                             {0, 2, -6, 126, 8, -2, 0, 0},
-                                             {0, 2, -10, 122, 18, -4, 0, 0},
-                                             {0, 2, -12, 116, 28, -8, 2, 0},
-                                             {0, 2, -14, 110, 38, -10, 2, 0},
-                                             {0, 2, -14, 102, 48, -12, 2, 0},
-                                             {0, 2, -16, 94, 58, -12, 2, 0},
-                                             {0, 2, -14, 84, 66, -12, 2, 0},
-                                             {0, 2, -14, 76, 76, -14, 2, 0},
-                                             {0, 2, -12, 66, 84, -14, 2, 0},
-                                             {0, 2, -12, 58, 94, -16, 2, 0},
-                                             {0, 2, -12, 48, 102, -14, 2, 0},
-                                             {0, 2, -10, 38, 110, -14, 2, 0},
-                                             {0, 2, -8, 28, 116, -12, 2, 0},
-                                             {0, 0, -4, 18, 122, -10, 2, 0},
-                                             {0, 0, -2, 8, 126, -6, 2, 0}},
-                                            {{0, 0, 0, 128, 0, 0, 0, 0},
-                                             {0, 2, 28, 62, 34, 2, 0, 0},
-                                             {0, 0, 26, 62, 36, 4, 0, 0},
-                                             {0, 0, 22, 62, 40, 4, 0, 0},
-                                             {0, 0, 20, 60, 42, 6, 0, 0},
-                                             {0, 0, 18, 58, 44, 8, 0, 0},
-                                             {0, 0, 16, 56, 46, 10, 0, 0},
-                                             {0, -2, 16, 54, 48, 12, 0, 0},
-                                             {0, -2, 14, 52, 52, 14, -2, 0},
-                                             {0, 0, 12, 48, 54, 16, -2, 0},
-                                             {0, 0, 10, 46, 56, 16, 0, 0},
-                                             {0, 0, 8, 44, 58, 18, 0, 0},
-                                             {0, 0, 6, 42, 60, 20, 0, 0},
-                                             {0, 0, 4, 40, 62, 22, 0, 0},
-                                             {0, 0, 4, 36, 62, 26, 0, 0},
-                                             {0, 0, 2, 34, 62, 28, 2, 0}},
-                                            {{0, 0, 0, 128, 0, 0, 0, 0},
-                                             {-2, 2, -6, 126, 8, -2, 2, 0},
-                                             {-2, 6, -12, 124, 16, -6, 4, -2},
-                                             {-2, 8, -18, 120, 26, -10, 6, -2},
-                                             {-4, 10, -22, 116, 38, -14, 6, -2},
-                                             {-4, 10, -22, 108, 48, -18, 8, -2},
-                                             {-4, 10, -24, 100, 60, -20, 8, -2},
-                                             {-4, 10, -24, 90, 70, -22, 10, -2},
-                                             {-4, 12, -24, 80, 80, -24, 12, -4},
-                                             {-2, 10, -22, 70, 90, -24, 10, -4},
-                                             {-2, 8, -20, 60, 100, -24, 10, -4},
-                                             {-2, 8, -18, 48, 108, -22, 10, -4},
-                                             {-2, 6, -14, 38, 116, -22, 10, -4},
-                                             {-2, 6, -10, 26, 120, -18, 8, -2},
-                                             {-2, 4, -6, 16, 124, -12, 6, -2},
-                                             {0, 2, -2, 8, 126, -6, 2, -2}},
-                                            {{0, 0, 0, 128, 0, 0, 0, 0},
-                                             {0, 0, 0, 120, 8, 0, 0, 0},
-                                             {0, 0, 0, 112, 16, 0, 0, 0},
-                                             {0, 0, 0, 104, 24, 0, 0, 0},
-                                             {0, 0, 0, 96, 32, 0, 0, 0},
-                                             {0, 0, 0, 88, 40, 0, 0, 0},
-                                             {0, 0, 0, 80, 48, 0, 0, 0},
-                                             {0, 0, 0, 72, 56, 0, 0, 0},
-                                             {0, 0, 0, 64, 64, 0, 0, 0},
-                                             {0, 0, 0, 56, 72, 0, 0, 0},
-                                             {0, 0, 0, 48, 80, 0, 0, 0},
-                                             {0, 0, 0, 40, 88, 0, 0, 0},
-                                             {0, 0, 0, 32, 96, 0, 0, 0},
-                                             {0, 0, 0, 24, 104, 0, 0, 0},
-                                             {0, 0, 0, 16, 112, 0, 0, 0},
-                                             {0, 0, 0, 8, 120, 0, 0, 0}},
-                                            {{0, 0, 0, 128, 0, 0, 0, 0},
-                                             {0, 0, -4, 126, 8, -2, 0, 0},
-                                             {0, 0, -8, 122, 18, -4, 0, 0},
-                                             {0, 0, -10, 116, 28, -6, 0, 0},
-                                             {0, 0, -12, 110, 38, -8, 0, 0},
-                                             {0, 0, -12, 102, 48, -10, 0, 0},
-                                             {0, 0, -14, 94, 58, -10, 0, 0},
-                                             {0, 0, -12, 84, 66, -10, 0, 0},
-                                             {0, 0, -12, 76, 76, -12, 0, 0},
-                                             {0, 0, -10, 66, 84, -12, 0, 0},
-                                             {0, 0, -10, 58, 94, -14, 0, 0},
-                                             {0, 0, -10, 48, 102, -12, 0, 0},
-                                             {0, 0, -8, 38, 110, -12, 0, 0},
-                                             {0, 0, -6, 28, 116, -10, 0, 0},
-                                             {0, 0, -4, 18, 122, -8, 0, 0},
-                                             {0, 0, -2, 8, 126, -4, 0, 0}},
-                                            {{0, 0, 0, 128, 0, 0, 0, 0},
-                                             {0, 0, 30, 62, 34, 2, 0, 0},
-                                             {0, 0, 26, 62, 36, 4, 0, 0},
-                                             {0, 0, 22, 62, 40, 4, 0, 0},
-                                             {0, 0, 20, 60, 42, 6, 0, 0},
-                                             {0, 0, 18, 58, 44, 8, 0, 0},
-                                             {0, 0, 16, 56, 46, 10, 0, 0},
-                                             {0, 0, 14, 54, 48, 12, 0, 0},
-                                             {0, 0, 12, 52, 52, 12, 0, 0},
-                                             {0, 0, 12, 48, 54, 14, 0, 0},
-                                             {0, 0, 10, 46, 56, 16, 0, 0},
-                                             {0, 0, 8, 44, 58, 18, 0, 0},
-                                             {0, 0, 6, 42, 60, 20, 0, 0},
-                                             {0, 0, 4, 40, 62, 22, 0, 0},
-                                             {0, 0, 4, 36, 62, 26, 0, 0},
-                                             {0, 0, 2, 34, 62, 30, 0, 0}}};
+alignas(16) const int16_t
+    kWarpedFilters[3 * kWarpedPixelPrecisionShifts + 1][8] = {
+        // [-1, 0).
+        {0, 0, 127, 1, 0, 0, 0, 0},
+        {0, -1, 127, 2, 0, 0, 0, 0},
+        {1, -3, 127, 4, -1, 0, 0, 0},
+        {1, -4, 126, 6, -2, 1, 0, 0},
+        {1, -5, 126, 8, -3, 1, 0, 0},
+        {1, -6, 125, 11, -4, 1, 0, 0},
+        {1, -7, 124, 13, -4, 1, 0, 0},
+        {2, -8, 123, 15, -5, 1, 0, 0},
+        {2, -9, 122, 18, -6, 1, 0, 0},
+        {2, -10, 121, 20, -6, 1, 0, 0},
+        {2, -11, 120, 22, -7, 2, 0, 0},
+        {2, -12, 119, 25, -8, 2, 0, 0},
+        {3, -13, 117, 27, -8, 2, 0, 0},
+        {3, -13, 116, 29, -9, 2, 0, 0},
+        {3, -14, 114, 32, -10, 3, 0, 0},
+        {3, -15, 113, 35, -10, 2, 0, 0},
+        {3, -15, 111, 37, -11, 3, 0, 0},
+        {3, -16, 109, 40, -11, 3, 0, 0},
+        {3, -16, 108, 42, -12, 3, 0, 0},
+        {4, -17, 106, 45, -13, 3, 0, 0},
+        {4, -17, 104, 47, -13, 3, 0, 0},
+        {4, -17, 102, 50, -14, 3, 0, 0},
+        {4, -17, 100, 52, -14, 3, 0, 0},
+        {4, -18, 98, 55, -15, 4, 0, 0},
+        {4, -18, 96, 58, -15, 3, 0, 0},
+        {4, -18, 94, 60, -16, 4, 0, 0},
+        {4, -18, 91, 63, -16, 4, 0, 0},
+        {4, -18, 89, 65, -16, 4, 0, 0},
+        {4, -18, 87, 68, -17, 4, 0, 0},
+        {4, -18, 85, 70, -17, 4, 0, 0},
+        {4, -18, 82, 73, -17, 4, 0, 0},
+        {4, -18, 80, 75, -17, 4, 0, 0},
+        {4, -18, 78, 78, -18, 4, 0, 0},
+        {4, -17, 75, 80, -18, 4, 0, 0},
+        {4, -17, 73, 82, -18, 4, 0, 0},
+        {4, -17, 70, 85, -18, 4, 0, 0},
+        {4, -17, 68, 87, -18, 4, 0, 0},
+        {4, -16, 65, 89, -18, 4, 0, 0},
+        {4, -16, 63, 91, -18, 4, 0, 0},
+        {4, -16, 60, 94, -18, 4, 0, 0},
+        {3, -15, 58, 96, -18, 4, 0, 0},
+        {4, -15, 55, 98, -18, 4, 0, 0},
+        {3, -14, 52, 100, -17, 4, 0, 0},
+        {3, -14, 50, 102, -17, 4, 0, 0},
+        {3, -13, 47, 104, -17, 4, 0, 0},
+        {3, -13, 45, 106, -17, 4, 0, 0},
+        {3, -12, 42, 108, -16, 3, 0, 0},
+        {3, -11, 40, 109, -16, 3, 0, 0},
+        {3, -11, 37, 111, -15, 3, 0, 0},
+        {2, -10, 35, 113, -15, 3, 0, 0},
+        {3, -10, 32, 114, -14, 3, 0, 0},
+        {2, -9, 29, 116, -13, 3, 0, 0},
+        {2, -8, 27, 117, -13, 3, 0, 0},
+        {2, -8, 25, 119, -12, 2, 0, 0},
+        {2, -7, 22, 120, -11, 2, 0, 0},
+        {1, -6, 20, 121, -10, 2, 0, 0},
+        {1, -6, 18, 122, -9, 2, 0, 0},
+        {1, -5, 15, 123, -8, 2, 0, 0},
+        {1, -4, 13, 124, -7, 1, 0, 0},
+        {1, -4, 11, 125, -6, 1, 0, 0},
+        {1, -3, 8, 126, -5, 1, 0, 0},
+        {1, -2, 6, 126, -4, 1, 0, 0},
+        {0, -1, 4, 127, -3, 1, 0, 0},
+        {0, 0, 2, 127, -1, 0, 0, 0},
+        // [0, 1).
+        {0, 0, 0, 127, 1, 0, 0, 0},
+        {0, 0, -1, 127, 2, 0, 0, 0},
+        {0, 1, -3, 127, 4, -2, 1, 0},
+        {0, 1, -5, 127, 6, -2, 1, 0},
+        {0, 2, -6, 126, 8, -3, 1, 0},
+        {-1, 2, -7, 126, 11, -4, 2, -1},
+        {-1, 3, -8, 125, 13, -5, 2, -1},
+        {-1, 3, -10, 124, 16, -6, 3, -1},
+        {-1, 4, -11, 123, 18, -7, 3, -1},
+        {-1, 4, -12, 122, 20, -7, 3, -1},
+        {-1, 4, -13, 121, 23, -8, 3, -1},
+        {-2, 5, -14, 120, 25, -9, 4, -1},
+        {-1, 5, -15, 119, 27, -10, 4, -1},
+        {-1, 5, -16, 118, 30, -11, 4, -1},
+        {-2, 6, -17, 116, 33, -12, 5, -1},
+        {-2, 6, -17, 114, 35, -12, 5, -1},
+        {-2, 6, -18, 113, 38, -13, 5, -1},
+        {-2, 7, -19, 111, 41, -14, 6, -2},
+        {-2, 7, -19, 110, 43, -15, 6, -2},
+        {-2, 7, -20, 108, 46, -15, 6, -2},
+        {-2, 7, -20, 106, 49, -16, 6, -2},
+        {-2, 7, -21, 104, 51, -16, 7, -2},
+        {-2, 7, -21, 102, 54, -17, 7, -2},
+        {-2, 8, -21, 100, 56, -18, 7, -2},
+        {-2, 8, -22, 98, 59, -18, 7, -2},
+        {-2, 8, -22, 96, 62, -19, 7, -2},
+        {-2, 8, -22, 94, 64, -19, 7, -2},
+        {-2, 8, -22, 91, 67, -20, 8, -2},
+        {-2, 8, -22, 89, 69, -20, 8, -2},
+        {-2, 8, -22, 87, 72, -21, 8, -2},
+        {-2, 8, -21, 84, 74, -21, 8, -2},
+        {-2, 8, -22, 82, 77, -21, 8, -2},
+        {-2, 8, -21, 79, 79, -21, 8, -2},
+        {-2, 8, -21, 77, 82, -22, 8, -2},
+        {-2, 8, -21, 74, 84, -21, 8, -2},
+        {-2, 8, -21, 72, 87, -22, 8, -2},
+        {-2, 8, -20, 69, 89, -22, 8, -2},
+        {-2, 8, -20, 67, 91, -22, 8, -2},
+        {-2, 7, -19, 64, 94, -22, 8, -2},
+        {-2, 7, -19, 62, 96, -22, 8, -2},
+        {-2, 7, -18, 59, 98, -22, 8, -2},
+        {-2, 7, -18, 56, 100, -21, 8, -2},
+        {-2, 7, -17, 54, 102, -21, 7, -2},
+        {-2, 7, -16, 51, 104, -21, 7, -2},
+        {-2, 6, -16, 49, 106, -20, 7, -2},
+        {-2, 6, -15, 46, 108, -20, 7, -2},
+        {-2, 6, -15, 43, 110, -19, 7, -2},
+        {-2, 6, -14, 41, 111, -19, 7, -2},
+        {-1, 5, -13, 38, 113, -18, 6, -2},
+        {-1, 5, -12, 35, 114, -17, 6, -2},
+        {-1, 5, -12, 33, 116, -17, 6, -2},
+        {-1, 4, -11, 30, 118, -16, 5, -1},
+        {-1, 4, -10, 27, 119, -15, 5, -1},
+        {-1, 4, -9, 25, 120, -14, 5, -2},
+        {-1, 3, -8, 23, 121, -13, 4, -1},
+        {-1, 3, -7, 20, 122, -12, 4, -1},
+        {-1, 3, -7, 18, 123, -11, 4, -1},
+        {-1, 3, -6, 16, 124, -10, 3, -1},
+        {-1, 2, -5, 13, 125, -8, 3, -1},
+        {-1, 2, -4, 11, 126, -7, 2, -1},
+        {0, 1, -3, 8, 126, -6, 2, 0},
+        {0, 1, -2, 6, 127, -5, 1, 0},
+        {0, 1, -2, 4, 127, -3, 1, 0},
+        {0, 0, 0, 2, 127, -1, 0, 0},
+        // [1, 2).
+        {0, 0, 0, 1, 127, 0, 0, 0},
+        {0, 0, 0, -1, 127, 2, 0, 0},
+        {0, 0, 1, -3, 127, 4, -1, 0},
+        {0, 0, 1, -4, 126, 6, -2, 1},
+        {0, 0, 1, -5, 126, 8, -3, 1},
+        {0, 0, 1, -6, 125, 11, -4, 1},
+        {0, 0, 1, -7, 124, 13, -4, 1},
+        {0, 0, 2, -8, 123, 15, -5, 1},
+        {0, 0, 2, -9, 122, 18, -6, 1},
+        {0, 0, 2, -10, 121, 20, -6, 1},
+        {0, 0, 2, -11, 120, 22, -7, 2},
+        {0, 0, 2, -12, 119, 25, -8, 2},
+        {0, 0, 3, -13, 117, 27, -8, 2},
+        {0, 0, 3, -13, 116, 29, -9, 2},
+        {0, 0, 3, -14, 114, 32, -10, 3},
+        {0, 0, 3, -15, 113, 35, -10, 2},
+        {0, 0, 3, -15, 111, 37, -11, 3},
+        {0, 0, 3, -16, 109, 40, -11, 3},
+        {0, 0, 3, -16, 108, 42, -12, 3},
+        {0, 0, 4, -17, 106, 45, -13, 3},
+        {0, 0, 4, -17, 104, 47, -13, 3},
+        {0, 0, 4, -17, 102, 50, -14, 3},
+        {0, 0, 4, -17, 100, 52, -14, 3},
+        {0, 0, 4, -18, 98, 55, -15, 4},
+        {0, 0, 4, -18, 96, 58, -15, 3},
+        {0, 0, 4, -18, 94, 60, -16, 4},
+        {0, 0, 4, -18, 91, 63, -16, 4},
+        {0, 0, 4, -18, 89, 65, -16, 4},
+        {0, 0, 4, -18, 87, 68, -17, 4},
+        {0, 0, 4, -18, 85, 70, -17, 4},
+        {0, 0, 4, -18, 82, 73, -17, 4},
+        {0, 0, 4, -18, 80, 75, -17, 4},
+        {0, 0, 4, -18, 78, 78, -18, 4},
+        {0, 0, 4, -17, 75, 80, -18, 4},
+        {0, 0, 4, -17, 73, 82, -18, 4},
+        {0, 0, 4, -17, 70, 85, -18, 4},
+        {0, 0, 4, -17, 68, 87, -18, 4},
+        {0, 0, 4, -16, 65, 89, -18, 4},
+        {0, 0, 4, -16, 63, 91, -18, 4},
+        {0, 0, 4, -16, 60, 94, -18, 4},
+        {0, 0, 3, -15, 58, 96, -18, 4},
+        {0, 0, 4, -15, 55, 98, -18, 4},
+        {0, 0, 3, -14, 52, 100, -17, 4},
+        {0, 0, 3, -14, 50, 102, -17, 4},
+        {0, 0, 3, -13, 47, 104, -17, 4},
+        {0, 0, 3, -13, 45, 106, -17, 4},
+        {0, 0, 3, -12, 42, 108, -16, 3},
+        {0, 0, 3, -11, 40, 109, -16, 3},
+        {0, 0, 3, -11, 37, 111, -15, 3},
+        {0, 0, 2, -10, 35, 113, -15, 3},
+        {0, 0, 3, -10, 32, 114, -14, 3},
+        {0, 0, 2, -9, 29, 116, -13, 3},
+        {0, 0, 2, -8, 27, 117, -13, 3},
+        {0, 0, 2, -8, 25, 119, -12, 2},
+        {0, 0, 2, -7, 22, 120, -11, 2},
+        {0, 0, 1, -6, 20, 121, -10, 2},
+        {0, 0, 1, -6, 18, 122, -9, 2},
+        {0, 0, 1, -5, 15, 123, -8, 2},
+        {0, 0, 1, -4, 13, 124, -7, 1},
+        {0, 0, 1, -4, 11, 125, -6, 1},
+        {0, 0, 1, -3, 8, 126, -5, 1},
+        {0, 0, 1, -2, 6, 126, -4, 1},
+        {0, 0, 0, -1, 4, 127, -3, 1},
+        {0, 0, 0, 0, 2, 127, -1, 0},
+        // dummy, replicate row index 191.
+        {0, 0, 0, 0, 2, 127, -1, 0}};
+
+// Every value in |kSubPixelFilters| is even. Divide by 2 to simplify
+// calculations by reducing the range by 1 bit.
+alignas(8) const int8_t kHalfSubPixelFilters[6][16][8] = {
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 1, -3, 63, 4, -1, 0, 0},
+     {0, 1, -5, 61, 9, -2, 0, 0},
+     {0, 1, -6, 58, 14, -4, 1, 0},
+     {0, 1, -7, 55, 19, -5, 1, 0},
+     {0, 1, -7, 51, 24, -6, 1, 0},
+     {0, 1, -8, 47, 29, -6, 1, 0},
+     {0, 1, -7, 42, 33, -6, 1, 0},
+     {0, 1, -7, 38, 38, -7, 1, 0},
+     {0, 1, -6, 33, 42, -7, 1, 0},
+     {0, 1, -6, 29, 47, -8, 1, 0},
+     {0, 1, -6, 24, 51, -7, 1, 0},
+     {0, 1, -5, 19, 55, -7, 1, 0},
+     {0, 1, -4, 14, 58, -6, 1, 0},
+     {0, 0, -2, 9, 61, -5, 1, 0},
+     {0, 0, -1, 4, 63, -3, 1, 0}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 1, 14, 31, 17, 1, 0, 0},
+     {0, 0, 13, 31, 18, 2, 0, 0},
+     {0, 0, 11, 31, 20, 2, 0, 0},
+     {0, 0, 10, 30, 21, 3, 0, 0},
+     {0, 0, 9, 29, 22, 4, 0, 0},
+     {0, 0, 8, 28, 23, 5, 0, 0},
+     {0, -1, 8, 27, 24, 6, 0, 0},
+     {0, -1, 7, 26, 26, 7, -1, 0},
+     {0, 0, 6, 24, 27, 8, -1, 0},
+     {0, 0, 5, 23, 28, 8, 0, 0},
+     {0, 0, 4, 22, 29, 9, 0, 0},
+     {0, 0, 3, 21, 30, 10, 0, 0},
+     {0, 0, 2, 20, 31, 11, 0, 0},
+     {0, 0, 2, 18, 31, 13, 0, 0},
+     {0, 0, 1, 17, 31, 14, 1, 0}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {-1, 1, -3, 63, 4, -1, 1, 0},
+     {-1, 3, -6, 62, 8, -3, 2, -1},
+     {-1, 4, -9, 60, 13, -5, 3, -1},
+     {-2, 5, -11, 58, 19, -7, 3, -1},
+     {-2, 5, -11, 54, 24, -9, 4, -1},
+     {-2, 5, -12, 50, 30, -10, 4, -1},
+     {-2, 5, -12, 45, 35, -11, 5, -1},
+     {-2, 6, -12, 40, 40, -12, 6, -2},
+     {-1, 5, -11, 35, 45, -12, 5, -2},
+     {-1, 4, -10, 30, 50, -12, 5, -2},
+     {-1, 4, -9, 24, 54, -11, 5, -2},
+     {-1, 3, -7, 19, 58, -11, 5, -2},
+     {-1, 3, -5, 13, 60, -9, 4, -1},
+     {-1, 2, -3, 8, 62, -6, 3, -1},
+     {0, 1, -1, 4, 63, -3, 1, -1}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 0, 0, 60, 4, 0, 0, 0},
+     {0, 0, 0, 56, 8, 0, 0, 0},
+     {0, 0, 0, 52, 12, 0, 0, 0},
+     {0, 0, 0, 48, 16, 0, 0, 0},
+     {0, 0, 0, 44, 20, 0, 0, 0},
+     {0, 0, 0, 40, 24, 0, 0, 0},
+     {0, 0, 0, 36, 28, 0, 0, 0},
+     {0, 0, 0, 32, 32, 0, 0, 0},
+     {0, 0, 0, 28, 36, 0, 0, 0},
+     {0, 0, 0, 24, 40, 0, 0, 0},
+     {0, 0, 0, 20, 44, 0, 0, 0},
+     {0, 0, 0, 16, 48, 0, 0, 0},
+     {0, 0, 0, 12, 52, 0, 0, 0},
+     {0, 0, 0, 8, 56, 0, 0, 0},
+     {0, 0, 0, 4, 60, 0, 0, 0}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 0, -2, 63, 4, -1, 0, 0},
+     {0, 0, -4, 61, 9, -2, 0, 0},
+     {0, 0, -5, 58, 14, -3, 0, 0},
+     {0, 0, -6, 55, 19, -4, 0, 0},
+     {0, 0, -6, 51, 24, -5, 0, 0},
+     {0, 0, -7, 47, 29, -5, 0, 0},
+     {0, 0, -6, 42, 33, -5, 0, 0},
+     {0, 0, -6, 38, 38, -6, 0, 0},
+     {0, 0, -5, 33, 42, -6, 0, 0},
+     {0, 0, -5, 29, 47, -7, 0, 0},
+     {0, 0, -5, 24, 51, -6, 0, 0},
+     {0, 0, -4, 19, 55, -6, 0, 0},
+     {0, 0, -3, 14, 58, -5, 0, 0},
+     {0, 0, -2, 9, 61, -4, 0, 0},
+     {0, 0, -1, 4, 63, -2, 0, 0}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 0, 15, 31, 17, 1, 0, 0},
+     {0, 0, 13, 31, 18, 2, 0, 0},
+     {0, 0, 11, 31, 20, 2, 0, 0},
+     {0, 0, 10, 30, 21, 3, 0, 0},
+     {0, 0, 9, 29, 22, 4, 0, 0},
+     {0, 0, 8, 28, 23, 5, 0, 0},
+     {0, 0, 7, 27, 24, 6, 0, 0},
+     {0, 0, 6, 26, 26, 6, 0, 0},
+     {0, 0, 6, 24, 27, 7, 0, 0},
+     {0, 0, 5, 23, 28, 8, 0, 0},
+     {0, 0, 4, 22, 29, 9, 0, 0},
+     {0, 0, 3, 21, 30, 10, 0, 0},
+     {0, 0, 2, 20, 31, 11, 0, 0},
+     {0, 0, 2, 18, 31, 13, 0, 0},
+     {0, 0, 1, 17, 31, 15, 0, 0}}};
+
+// Absolute values of |kHalfSubPixelFilters|. Used in situations where we know
+// the pattern of the signs and account for it in other ways.
+const uint8_t kAbsHalfSubPixelFilters[6][16][8] = {
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 1, 3, 63, 4, 1, 0, 0},
+     {0, 1, 5, 61, 9, 2, 0, 0},
+     {0, 1, 6, 58, 14, 4, 1, 0},
+     {0, 1, 7, 55, 19, 5, 1, 0},
+     {0, 1, 7, 51, 24, 6, 1, 0},
+     {0, 1, 8, 47, 29, 6, 1, 0},
+     {0, 1, 7, 42, 33, 6, 1, 0},
+     {0, 1, 7, 38, 38, 7, 1, 0},
+     {0, 1, 6, 33, 42, 7, 1, 0},
+     {0, 1, 6, 29, 47, 8, 1, 0},
+     {0, 1, 6, 24, 51, 7, 1, 0},
+     {0, 1, 5, 19, 55, 7, 1, 0},
+     {0, 1, 4, 14, 58, 6, 1, 0},
+     {0, 0, 2, 9, 61, 5, 1, 0},
+     {0, 0, 1, 4, 63, 3, 1, 0}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 1, 14, 31, 17, 1, 0, 0},
+     {0, 0, 13, 31, 18, 2, 0, 0},
+     {0, 0, 11, 31, 20, 2, 0, 0},
+     {0, 0, 10, 30, 21, 3, 0, 0},
+     {0, 0, 9, 29, 22, 4, 0, 0},
+     {0, 0, 8, 28, 23, 5, 0, 0},
+     {0, 1, 8, 27, 24, 6, 0, 0},
+     {0, 1, 7, 26, 26, 7, 1, 0},
+     {0, 0, 6, 24, 27, 8, 1, 0},
+     {0, 0, 5, 23, 28, 8, 0, 0},
+     {0, 0, 4, 22, 29, 9, 0, 0},
+     {0, 0, 3, 21, 30, 10, 0, 0},
+     {0, 0, 2, 20, 31, 11, 0, 0},
+     {0, 0, 2, 18, 31, 13, 0, 0},
+     {0, 0, 1, 17, 31, 14, 1, 0}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {1, 1, 3, 63, 4, 1, 1, 0},
+     {1, 3, 6, 62, 8, 3, 2, 1},
+     {1, 4, 9, 60, 13, 5, 3, 1},
+     {2, 5, 11, 58, 19, 7, 3, 1},
+     {2, 5, 11, 54, 24, 9, 4, 1},
+     {2, 5, 12, 50, 30, 10, 4, 1},
+     {2, 5, 12, 45, 35, 11, 5, 1},
+     {2, 6, 12, 40, 40, 12, 6, 2},
+     {1, 5, 11, 35, 45, 12, 5, 2},
+     {1, 4, 10, 30, 50, 12, 5, 2},
+     {1, 4, 9, 24, 54, 11, 5, 2},
+     {1, 3, 7, 19, 58, 11, 5, 2},
+     {1, 3, 5, 13, 60, 9, 4, 1},
+     {1, 2, 3, 8, 62, 6, 3, 1},
+     {0, 1, 1, 4, 63, 3, 1, 1}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 0, 0, 60, 4, 0, 0, 0},
+     {0, 0, 0, 56, 8, 0, 0, 0},
+     {0, 0, 0, 52, 12, 0, 0, 0},
+     {0, 0, 0, 48, 16, 0, 0, 0},
+     {0, 0, 0, 44, 20, 0, 0, 0},
+     {0, 0, 0, 40, 24, 0, 0, 0},
+     {0, 0, 0, 36, 28, 0, 0, 0},
+     {0, 0, 0, 32, 32, 0, 0, 0},
+     {0, 0, 0, 28, 36, 0, 0, 0},
+     {0, 0, 0, 24, 40, 0, 0, 0},
+     {0, 0, 0, 20, 44, 0, 0, 0},
+     {0, 0, 0, 16, 48, 0, 0, 0},
+     {0, 0, 0, 12, 52, 0, 0, 0},
+     {0, 0, 0, 8, 56, 0, 0, 0},
+     {0, 0, 0, 4, 60, 0, 0, 0}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 0, 2, 63, 4, 1, 0, 0},
+     {0, 0, 4, 61, 9, 2, 0, 0},
+     {0, 0, 5, 58, 14, 3, 0, 0},
+     {0, 0, 6, 55, 19, 4, 0, 0},
+     {0, 0, 6, 51, 24, 5, 0, 0},
+     {0, 0, 7, 47, 29, 5, 0, 0},
+     {0, 0, 6, 42, 33, 5, 0, 0},
+     {0, 0, 6, 38, 38, 6, 0, 0},
+     {0, 0, 5, 33, 42, 6, 0, 0},
+     {0, 0, 5, 29, 47, 7, 0, 0},
+     {0, 0, 5, 24, 51, 6, 0, 0},
+     {0, 0, 4, 19, 55, 6, 0, 0},
+     {0, 0, 3, 14, 58, 5, 0, 0},
+     {0, 0, 2, 9, 61, 4, 0, 0},
+     {0, 0, 1, 4, 63, 2, 0, 0}},
+    {{0, 0, 0, 64, 0, 0, 0, 0},
+     {0, 0, 15, 31, 17, 1, 0, 0},
+     {0, 0, 13, 31, 18, 2, 0, 0},
+     {0, 0, 11, 31, 20, 2, 0, 0},
+     {0, 0, 10, 30, 21, 3, 0, 0},
+     {0, 0, 9, 29, 22, 4, 0, 0},
+     {0, 0, 8, 28, 23, 5, 0, 0},
+     {0, 0, 7, 27, 24, 6, 0, 0},
+     {0, 0, 6, 26, 26, 6, 0, 0},
+     {0, 0, 6, 24, 27, 7, 0, 0},
+     {0, 0, 5, 23, 28, 8, 0, 0},
+     {0, 0, 4, 22, 29, 9, 0, 0},
+     {0, 0, 3, 21, 30, 10, 0, 0},
+     {0, 0, 2, 20, 31, 11, 0, 0},
+     {0, 0, 2, 18, 31, 13, 0, 0},
+     {0, 0, 1, 17, 31, 15, 0, 0}}};
 
 // 9.3 -- Dr_Intra_Derivative[]
 // This is a more compact version of the table from the spec. angle / 2 - 1 is
diff --git a/libgav1/src/utils/constants.h b/libgav1/src/utils/constants.h
index c08b253..ce987b4 100644
--- a/libgav1/src/utils/constants.h
+++ b/libgav1/src/utils/constants.h
@@ -28,6 +28,16 @@
 constexpr int EnumRangeLength(int begin, int end) { return end - begin + 1; }
 
 enum {
+// Maximum number of threads that the library will ever create.
+#if defined(LIBGAV1_MAX_THREADS) && LIBGAV1_MAX_THREADS > 0
+  kMaxThreads = LIBGAV1_MAX_THREADS
+#else
+  kMaxThreads = 128
+#endif
+};  // anonymous enum
+
+enum {
+  kInvalidMvValue = -32768,
   kCdfMaxProbability = 32768,
   kBlockWidthCount = 5,
   kMaxSegments = 8,
@@ -37,7 +47,6 @@
   kFrameLfCount = 4,
   kMaxLoopFilterValue = 63,
   kNum4x4In64x64 = 256,
-  kNumLoopFilterMasks = 4,
   kMaxAngleDelta = 3,
   kDirectionalIntraModes = 8,
   kMaxSuperBlockSizeLog2 = 7,
@@ -48,56 +57,80 @@
   kRestorationTypeSymbolCount = 3,
   kSgrProjParamsBits = 4,
   kSgrProjPrecisionBits = 7,
-  kRestorationBorder = 3,      // Padding on each side of a restoration block.
+  // Padding on left and right side of a restoration block.
+  // 3 is enough, but padding to 4 is more efficient, and makes the temporary
+  // source buffer 8-pixel aligned.
+  kRestorationHorizontalBorder = 4,
+  // Padding on top and bottom side of a restoration block.
+  kRestorationVerticalBorder = 2,
   kCdefBorder = 2,             // Padding on each side of a cdef block.
   kConvolveBorderLeftTop = 3,  // Left/top padding of a convolve block.
-  kConvolveBorderRightBottom = 4,  // Right/bottom padding of a convolve block.
+  // Right/bottom padding of a convolve block. This needs to be 4 at minimum,
+  // but was increased to simplify the SIMD loads in
+  // ConvolveCompoundScale2D_NEON() and ConvolveScale2D_NEON().
+  kConvolveBorderRight = 8,
+  kConvolveBorderBottom = 4,
   kSubPixelTaps = 8,
   kWienerFilterBits = 7,
+  kWienerFilterTaps = 7,
   kMaxPaletteSize = 8,
   kMinPaletteSize = 2,
   kMaxPaletteSquare = 64,
   kBorderPixels = 64,
-  // Although the left and right borders of a frame start with kBorderPixels,
-  // they may change if YuvBuffer::ShiftBuffer() is called. These constants
-  // are the minimum left and right border sizes in pixels as an extension of
-  // the frame boundary. The minimum border sizes are derived from the
-  // following requirements:
+  // The final blending process for film grain needs room to overwrite and read
+  // with SIMD instructions. The maximum overwrite is 7 pixels, but the border
+  // is required to be a multiple of 32 by YuvBuffer::Realloc, so that
+  // subsampled chroma borders are 16-aligned.
+  kBorderPixelsFilmGrain = 32,
+  // These constants are the minimum left, right, top, and bottom border sizes
+  // in pixels as an extension of the frame boundary. The minimum border sizes
+  // are derived from the following requirements:
   // - Warp_C() may read up to 13 pixels before or after a row.
   // - Warp_NEON() may read up to 13 pixels before a row. It may read up to 14
   //   pixels after a row, but the value of the last read pixel is not used.
+  // - Warp_C() and Warp_NEON() may read up to 13 pixels above the top row and
+  //   13 pixels below the bottom row.
   kMinLeftBorderPixels = 13,
   kMinRightBorderPixels = 13,
+  kMinTopBorderPixels = 13,
+  kMinBottomBorderPixels = 13,
   kWarpedModelPrecisionBits = 16,
   kMaxRefMvStackSize = 8,
-  kExtraWeightForNearestMvs = 640,
   kMaxLeastSquaresSamples = 8,
+  kMaxTemporalMvCandidates = 19,
+  // The SIMD implementations of motion vection projection functions always
+  // process 2 or 4 elements together, so we pad the corresponding buffers to
+  // size 20.
+  kMaxTemporalMvCandidatesWithPadding = 20,
   kMaxSuperBlockSizeInPixels = 128,
+  kMaxScaledSuperBlockSizeInPixels = 128 * 2,
   kMaxSuperBlockSizeSquareInPixels = 128 * 128,
-  kNum4x4InLoopFilterMaskUnit = 16,
+  kNum4x4InLoopFilterUnit = 16,
+  kProjectionMvClamp = (1 << 14) - 1,  // == 16383
+  kProjectionMvMaxHorizontalOffset = 8,
+  kCdefUnitSize = 64,
+  kCdefUnitSizeWithBorders = kCdefUnitSize + 2 * kCdefBorder,
   kRestorationUnitOffset = 8,
-  // 2 pixel padding for 5x5 box sum on each side.
-  kRestorationPadding = 4,
   // Loop restoration's processing unit size is fixed as 64x64.
-  kRestorationProcessingUnitSize = 64,
-  kRestorationProcessingUnitSizeWithBorders =
-      kRestorationProcessingUnitSize + 2 * kRestorationBorder,
-  // The max size of a box filter process output buffer.
-  kMaxBoxFilterProcessOutputPixels = kRestorationProcessingUnitSize *
-                                     kRestorationProcessingUnitSize,  // == 4096
-  // The max size of a box filter process intermediate buffer.
-  kBoxFilterProcessIntermediatePixels =
-      (kRestorationProcessingUnitSizeWithBorders + kRestorationPadding) *
-      (kRestorationProcessingUnitSizeWithBorders +
-       kRestorationPadding),  // == 5476
+  kRestorationUnitHeight = 64,
+  kRestorationUnitWidth = 256,
+  kRestorationUnitHeightWithBorders =
+      kRestorationUnitHeight + 2 * kRestorationVerticalBorder,
+  kRestorationUnitWidthWithBorders =
+      kRestorationUnitWidth + 2 * kRestorationHorizontalBorder,
   kSuperResFilterBits = 6,
   kSuperResFilterShifts = 1 << kSuperResFilterBits,
   kSuperResFilterTaps = 8,
   kSuperResScaleBits = 14,
   kSuperResExtraBits = kSuperResScaleBits - kSuperResFilterBits,
   kSuperResScaleMask = (1 << 14) - 1,
+  kSuperResHorizontalBorder = 8,
+  kSuperResVerticalBorder = 1,
+  // The SIMD implementations of superres calculate up to 4 extra upscaled
+  // pixels which will over-read 2 downscaled pixels in the end of each row.
+  kSuperResHorizontalPadding = 2,
   // TODO(chengchen): consider merging these constants:
-  // kFilterbits, kWienerFilterBits, and kSgrProjPrecisionBits, which are all 7,
+  // kFilterBits, kWienerFilterBits, and kSgrProjPrecisionBits, which are all 7,
   // They are designed to match AV1 convolution, which increases coeff
   // values up to 7 bits. We could consider to combine them and use kFilterBits
   // only.
@@ -107,6 +140,7 @@
   // integer pixel. Sub pixel values are interpolated using adjacent integer
   // pixel values. The interpolation is a filtering process.
   kSubPixelBits = 4,
+  kSubPixelMask = (1 << kSubPixelBits) - 1,
   // Precision bits when computing inter prediction locations.
   kScaleSubPixelBits = 10,
   kWarpParamRoundingBits = 6,
@@ -116,22 +150,28 @@
   kDivisorLookupPrecisionBits = 14,
   // Number of phases used in warped filtering.
   kWarpedPixelPrecisionShifts = 1 << 6,
-  kQuantizedCoefficientBufferPadding = 4,
-  // Maximum number of quantized coefficients that can be read from the
-  // bitstream. This comes from the definition of segEob in section 5.11.39.
-  // Size of the quantized coefficients buffer. This comes from the definition
-  // of segEob in section 5.11.39 (with 4 bytes padded to each row and 4 rows
-  // padded in the end to avoid boundary checks).
-  kQuantizedCoefficientBufferSize = (32 + kQuantizedCoefficientBufferPadding) *
-                                    (32 + kQuantizedCoefficientBufferPadding),
+  kResidualPaddingVertical = 4,
   kWedgeMaskMasterSize = 64,
-  kMaxMaskBlockSize = kWedgeMaskMasterSize * kWedgeMaskMasterSize,
-  kWedgeMaskSize = 9 * 2 * 16 * kWedgeMaskMasterSize * kWedgeMaskMasterSize,
   kMaxFrameDistance = 31,
   kReferenceFrameScalePrecision = 14,
   kNumWienerCoefficients = 3,
-  // Maximum number of threads that the library will ever create.
-  kMaxThreads = 32,
+  kLoopFilterMaxModeDeltas = 2,
+  kMaxCdefStrengths = 8,
+  kCdefLargeValue = 0x4000,  // Used to indicate where CDEF is not available.
+  kMaxTileColumns = 64,
+  kMaxTileRows = 64,
+  kMaxOperatingPoints = 32,
+  // There can be a maximum of 4 spatial layers and 8 temporal layers.
+  kMaxLayers = 32,
+  // The cache line size should ideally be queried at run time. 64 is a common
+  // cache line size of x86 CPUs. Web searches showed the cache line size of ARM
+  // CPUs is 32 or 64 bytes. So aligning to 64-byte boundary will work for all
+  // CPUs that we care about, even though it is excessive for some ARM
+  // CPUs.
+  //
+  // On Linux, the cache line size can be looked up with the command:
+  //   getconf LEVEL1_DCACHE_LINESIZE
+  kCacheLineSize = 64,
 };  // anonymous enum
 
 enum FrameType : uint8_t {
@@ -385,6 +425,7 @@
   kTransformClass2D,
   kTransformClassHorizontal,
   kTransformClassVertical,
+  kNumTransformClasses
 };
 
 enum FilterIntraPredictor : uint8_t {
@@ -643,6 +684,8 @@
 
 extern const BlockSize kPlaneResidualSize[kMaxBlockSizes][2][2];
 
+extern const int16_t kProjectionMvDivisionLookup[kMaxFrameDistance + 1];
+
 extern const uint8_t kTransformWidth[kNumTransformSizes];
 
 extern const uint8_t kTransformHeight[kNumTransformSizes];
@@ -675,11 +718,18 @@
 
 extern const int8_t kWienerTapsMax[3];
 
-extern const int16_t kUpscaleFilter[kSuperResFilterShifts][kSuperResFilterTaps];
+extern const uint8_t kUpscaleFilterUnsigned[kSuperResFilterShifts]
+                                           [kSuperResFilterTaps];
+
+// An int8_t version of the kWarpedFilters array.
+// Note: The array could be removed with a performance penalty.
+extern const int8_t kWarpedFilters8[3 * kWarpedPixelPrecisionShifts + 1][8];
 
 extern const int16_t kWarpedFilters[3 * kWarpedPixelPrecisionShifts + 1][8];
 
-extern const int16_t kSubPixelFilters[6][16][8];
+extern const int8_t kHalfSubPixelFilters[6][16][8];
+
+extern const uint8_t kAbsHalfSubPixelFilters[6][16][8];
 
 extern const int16_t kDirectionalIntraPredictorDerivative[44];
 
diff --git a/libgav1/src/dsp/cpu.cc b/libgav1/src/utils/cpu.cc
similarity index 95%
rename from libgav1/src/dsp/cpu.cc
rename to libgav1/src/utils/cpu.cc
index 38c4a0b..a6b7057 100644
--- a/libgav1/src/dsp/cpu.cc
+++ b/libgav1/src/utils/cpu.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "src/dsp/cpu.h"
+#include "src/utils/cpu.h"
 
 #if defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))
 #include <cpuid.h>
@@ -22,7 +22,6 @@
 #endif
 
 namespace libgav1 {
-namespace dsp {
 
 #if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || \
     defined(_M_X64)
@@ -35,11 +34,12 @@
 
 uint64_t Xgetbv() {
   const uint32_t ecx = 0;  // ecx specifies the extended control register
-  uint32_t eax, edx;
+  uint32_t eax;
+  uint32_t edx;
   __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(ecx));
   return (static_cast<uint64_t>(edx) << 32) | eax;
 }
-#else   // _MSC_VER
+#else  // _MSC_VER
 void CpuId(int leaf, uint32_t info[4]) {
   __cpuidex(reinterpret_cast<int*>(info), leaf, 0 /*ecx=subleaf*/);
 }
@@ -81,5 +81,4 @@
 uint32_t GetCpuInfo() { return 0; }
 #endif  // x86 || x86_64
 
-}  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/utils/cpu.h b/libgav1/src/utils/cpu.h
new file mode 100644
index 0000000..d098f1d
--- /dev/null
+++ b/libgav1/src/utils/cpu.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2019 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_UTILS_CPU_H_
+#define LIBGAV1_SRC_UTILS_CPU_H_
+
+#include <cstdint>
+
+namespace libgav1 {
+
+#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))
+#define LIBGAV1_X86_MSVC
+#endif
+
+#if !defined(LIBGAV1_ENABLE_SSE4_1)
+#if defined(__SSE4_1__) || defined(LIBGAV1_X86_MSVC)
+#define LIBGAV1_ENABLE_SSE4_1 1
+#else
+#define LIBGAV1_ENABLE_SSE4_1 0
+#endif
+#endif  // !defined(LIBGAV1_ENABLE_SSE4_1)
+
+#undef LIBGAV1_X86_MSVC
+
+#if !defined(LIBGAV1_ENABLE_NEON)
+// TODO(jzern): add support for _M_ARM64.
+#if defined(__ARM_NEON__) || defined(__aarch64__) || \
+    (defined(_MSC_VER) && defined(_M_ARM))
+#define LIBGAV1_ENABLE_NEON 1
+#else
+#define LIBGAV1_ENABLE_NEON 0
+#endif
+#endif  // !defined(LIBGAV1_ENABLE_NEON)
+
+enum CpuFeatures : uint8_t {
+  kSSE2 = 1 << 0,
+#define LIBGAV1_CPU_SSE2 (1 << 0)
+  kSSSE3 = 1 << 1,
+#define LIBGAV1_CPU_SSSE3 (1 << 1)
+  kSSE4_1 = 1 << 2,
+#define LIBGAV1_CPU_SSE4_1 (1 << 2)
+  kAVX = 1 << 3,
+#define LIBGAV1_CPU_AVX (1 << 3)
+  kAVX2 = 1 << 4,
+#define LIBGAV1_CPU_AVX2 (1 << 4)
+  kNEON = 1 << 5,
+#define LIBGAV1_CPU_NEON (1 << 5)
+};
+
+// Returns a bit-wise OR of CpuFeatures supported by this platform.
+uint32_t GetCpuInfo();
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_UTILS_CPU_H_
diff --git a/libgav1/src/utils/dynamic_buffer.h b/libgav1/src/utils/dynamic_buffer.h
new file mode 100644
index 0000000..5e2f644
--- /dev/null
+++ b/libgav1/src/utils/dynamic_buffer.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_UTILS_DYNAMIC_BUFFER_H_
+#define LIBGAV1_SRC_UTILS_DYNAMIC_BUFFER_H_
+
+#include <memory>
+#include <new>
+
+#include "src/utils/memory.h"
+
+namespace libgav1 {
+
+template <typename T>
+class DynamicBuffer {
+ public:
+  T* get() { return buffer_.get(); }
+
+  // Resizes the buffer so that it can hold at least |size| elements. Existing
+  // contents will be destroyed when resizing to a larger size.
+  //
+  // Returns true on success. If Resize() returns false, then subsequent calls
+  // to get() will return nullptr.
+  bool Resize(size_t size) {
+    if (size <= size_) return true;
+    buffer_.reset(new (std::nothrow) T[size]);
+    if (buffer_ == nullptr) {
+      size_ = 0;
+      return false;
+    }
+    size_ = size;
+    return true;
+  }
+
+ private:
+  std::unique_ptr<T[]> buffer_;
+  size_t size_ = 0;
+};
+
+template <typename T, int alignment>
+class AlignedDynamicBuffer {
+ public:
+  T* get() { return buffer_.get(); }
+
+  // Resizes the buffer so that it can hold at least |size| elements. Existing
+  // contents will be destroyed when resizing to a larger size.
+  //
+  // Returns true on success. If Resize() returns false, then subsequent calls
+  // to get() will return nullptr.
+  bool Resize(size_t size) {
+    if (size <= size_) return true;
+    buffer_ = MakeAlignedUniquePtr<T>(alignment, size);
+    if (buffer_ == nullptr) {
+      size_ = 0;
+      return false;
+    }
+    size_ = size;
+    return true;
+  }
+
+ private:
+  AlignedUniquePtr<T> buffer_;
+  size_t size_ = 0;
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_UTILS_DYNAMIC_BUFFER_H_
diff --git a/libgav1/src/utils/entropy_decoder.cc b/libgav1/src/utils/entropy_decoder.cc
index acf0b08..dfe3bba 100644
--- a/libgav1/src/utils/entropy_decoder.cc
+++ b/libgav1/src/utils/entropy_decoder.cc
@@ -15,10 +15,34 @@
 #include "src/utils/entropy_decoder.h"
 
 #include <cassert>
+#include <cstring>
 
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
 
+#if defined(__ARM_NEON__) || defined(__aarch64__) || \
+    (defined(_MSC_VER) && defined(_M_ARM))
+#define LIBGAV1_ENTROPY_DECODER_ENABLE_NEON 1
+#else
+#define LIBGAV1_ENTROPY_DECODER_ENABLE_NEON 0
+#endif
+
+#if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
+#include <arm_neon.h>
+#endif
+
+#if defined(__SSE4_1__) || defined(LIBGAV1_X86_MSVC)
+#define LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4 1
+#else
+#define LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4 0
+#endif
+
+#if LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4
+#include <smmintrin.h>
+#endif
+
+namespace libgav1 {
 namespace {
 
 constexpr uint32_t kReadBitMask = ~255;
@@ -33,13 +57,13 @@
 // loop in Section 8.2.6 of the spec. This function is monotonically
 // decreasing as the values of index increases (note that the |cdf| array is
 // sorted in decreasing order).
-uint32_t ScaleCdf(uint16_t values_in_range_shifted, const uint16_t* const cdf,
+uint32_t ScaleCdf(uint32_t values_in_range_shifted, const uint16_t* const cdf,
                   int index, int symbol_count) {
   return ((values_in_range_shifted * (cdf[index] >> kCdfPrecision)) >> 1) +
          (kMinimumProbabilityPerSymbol * (symbol_count - index));
 }
 
-void UpdateCdf(uint16_t* const cdf, int symbol_count, int symbol) {
+void UpdateCdf(uint16_t* const cdf, const int symbol_count, const int symbol) {
   const uint16_t count = cdf[symbol_count];
   // rate is computed in the spec as:
   //  3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
@@ -68,9 +92,9 @@
   // 2. The for loop can be rewritten in the following form, which would enable
   // clang to vectorize the loop with width 8:
   //
-  //   const int mask = (1 << rate) - 1;
+  //   const int rounding = (1 << rate) - 1;
   //   for (int i = 0; i < symbol_count - 1; ++i) {
-  //     const uint16_t a = (i < symbol) ? kCdfMaxProbability : mask;
+  //     const uint16_t a = (i < symbol) ? kCdfMaxProbability : rounding;
   //     cdf[i] += static_cast<int16_t>(a - cdf[i]) >> rate;
   //   }
   //
@@ -81,7 +105,7 @@
   // Visual C++.
   for (int i = 0; i < symbol_count - 1; ++i) {
     if (i < symbol) {
-      cdf[i] += (libgav1::kCdfMaxProbability - cdf[i]) >> rate;
+      cdf[i] += (kCdfMaxProbability - cdf[i]) >> rate;
     } else {
       cdf[i] -= cdf[i] >> rate;
     }
@@ -89,11 +113,409 @@
   cdf[symbol_count] += static_cast<uint16_t>(count < 32);
 }
 
+// Define the UpdateCdfN functions. UpdateCdfN is a specialized implementation
+// of UpdateCdf based on the fact that symbol_count == N. UpdateCdfN uses the
+// SIMD instruction sets if available.
+
+#if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
+
+// The UpdateCdf() method contains the following for loop:
+//
+//   for (int i = 0; i < symbol_count - 1; ++i) {
+//     if (i < symbol) {
+//       cdf[i] += (kCdfMaxProbability - cdf[i]) >> rate;
+//     } else {
+//       cdf[i] -= cdf[i] >> rate;
+//     }
+//   }
+//
+// It can be rewritten in the following two forms, which are amenable to SIMD
+// implementations:
+//
+//   const int rounding = (1 << rate) - 1;
+//   for (int i = 0; i < symbol_count - 1; ++i) {
+//     const uint16_t a = (i < symbol) ? kCdfMaxProbability : rounding;
+//     cdf[i] += static_cast<int16_t>(a - cdf[i]) >> rate;
+//   }
+//
+// or:
+//
+//   const int rounding = (1 << rate) - 1;
+//   for (int i = 0; i < symbol_count - 1; ++i) {
+//     const uint16_t a = (i < symbol) ? (kCdfMaxProbability - rounding) : 0;
+//     cdf[i] -= static_cast<int16_t>(cdf[i] - a) >> rate;
+//   }
+//
+// The following ARM NEON implementations use the second form, which seems
+// slightly faster.
+//
+// The cdf array has symbol_count + 1 elements. The first symbol_count elements
+// are the CDF. The last element is a count that is initialized to 0 and may
+// grow up to 32. The for loop in UpdateCdf updates the CDF in the array. Since
+// cdf[symbol_count - 1] is always 0, the for loop does not update
+// cdf[symbol_count - 1]. However, it would be correct to have the for loop
+// update cdf[symbol_count - 1] anyway: since symbol_count - 1 >= symbol, the
+// for loop would take the else branch when i is symbol_count - 1:
+//      cdf[i] -= cdf[i] >> rate;
+// Since cdf[symbol_count - 1] is 0, cdf[symbol_count - 1] would still be 0
+// after the update. The ARM NEON implementations take advantage of this in the
+// following two cases:
+// 1. When symbol_count is 8 or 16, the vectorized code updates the first
+//    symbol_count elements in the array.
+// 2. When symbol_count is 7, the vectorized code updates all the 8 elements in
+//    the cdf array. Since an invalid CDF value is written into cdf[7], the
+//    count in cdf[7] needs to be fixed up after the vectorized code.
+
+void UpdateCdf5(uint16_t* const cdf, const int symbol) {
+  uint16x4_t cdf_vec = vld1_u16(cdf);
+  const uint16_t count = cdf[5];
+  const int rate = (4 | (count >> 4)) + 1;
+  const uint16x4_t zero = vdup_n_u16(0);
+  const uint16x4_t cdf_max_probability =
+      vdup_n_u16(kCdfMaxProbability + 1 - (1 << rate));
+  const uint16x4_t index = vcreate_u16(0x0003000200010000);
+  const uint16x4_t symbol_vec = vdup_n_u16(symbol);
+  const uint16x4_t mask = vclt_u16(index, symbol_vec);
+  const uint16x4_t a = vbsl_u16(mask, cdf_max_probability, zero);
+  const int16x4_t diff = vreinterpret_s16_u16(vsub_u16(cdf_vec, a));
+  const int16x4_t negative_rate = vdup_n_s16(-rate);
+  const uint16x4_t delta = vreinterpret_u16_s16(vshl_s16(diff, negative_rate));
+  cdf_vec = vsub_u16(cdf_vec, delta);
+  vst1_u16(cdf, cdf_vec);
+  cdf[5] = count + static_cast<uint16_t>(count < 32);
+}
+
+// This version works for |symbol_count| = 7, 8, or 9.
+template <int symbol_count>
+void UpdateCdf7To9(uint16_t* const cdf, const int symbol) {
+  static_assert(symbol_count >= 7 && symbol_count <= 9, "");
+  uint16x8_t cdf_vec = vld1q_u16(cdf);
+  const uint16_t count = cdf[symbol_count];
+  const int rate = (4 | (count >> 4)) + 1;
+  const uint16x8_t zero = vdupq_n_u16(0);
+  const uint16x8_t cdf_max_probability =
+      vdupq_n_u16(kCdfMaxProbability + 1 - (1 << rate));
+  const uint16x8_t index = vcombine_u16(vcreate_u16(0x0003000200010000),
+                                        vcreate_u16(0x0007000600050004));
+  const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
+  const uint16x8_t mask = vcltq_u16(index, symbol_vec);
+  const uint16x8_t a = vbslq_u16(mask, cdf_max_probability, zero);
+  const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(cdf_vec, a));
+  const int16x8_t negative_rate = vdupq_n_s16(-rate);
+  const uint16x8_t delta =
+      vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
+  cdf_vec = vsubq_u16(cdf_vec, delta);
+  vst1q_u16(cdf, cdf_vec);
+  cdf[symbol_count] = count + static_cast<uint16_t>(count < 32);
+}
+
+void UpdateCdf7(uint16_t* const cdf, const int symbol) {
+  UpdateCdf7To9<7>(cdf, symbol);
+}
+
+void UpdateCdf8(uint16_t* const cdf, const int symbol) {
+  UpdateCdf7To9<8>(cdf, symbol);
+}
+
+void UpdateCdf11(uint16_t* const cdf, const int symbol) {
+  uint16x8_t cdf_vec = vld1q_u16(cdf + 2);
+  const uint16_t count = cdf[11];
+  cdf[11] = count + static_cast<uint16_t>(count < 32);
+  const int rate = (4 | (count >> 4)) + 1;
+  if (symbol > 1) {
+    cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
+    cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
+    const uint16x8_t zero = vdupq_n_u16(0);
+    const uint16x8_t cdf_max_probability =
+        vdupq_n_u16(kCdfMaxProbability + 1 - (1 << rate));
+    const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
+    const int16x8_t negative_rate = vdupq_n_s16(-rate);
+    const uint16x8_t index = vcombine_u16(vcreate_u16(0x0005000400030002),
+                                          vcreate_u16(0x0009000800070006));
+    const uint16x8_t mask = vcltq_u16(index, symbol_vec);
+    const uint16x8_t a = vbslq_u16(mask, cdf_max_probability, zero);
+    const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(cdf_vec, a));
+    const uint16x8_t delta =
+        vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
+    cdf_vec = vsubq_u16(cdf_vec, delta);
+    vst1q_u16(cdf + 2, cdf_vec);
+  } else {
+    if (symbol != 0) {
+      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
+      cdf[1] -= cdf[1] >> rate;
+    } else {
+      cdf[0] -= cdf[0] >> rate;
+      cdf[1] -= cdf[1] >> rate;
+    }
+    const int16x8_t negative_rate = vdupq_n_s16(-rate);
+    const uint16x8_t delta = vshlq_u16(cdf_vec, negative_rate);
+    cdf_vec = vsubq_u16(cdf_vec, delta);
+    vst1q_u16(cdf + 2, cdf_vec);
+  }
+}
+
+void UpdateCdf13(uint16_t* const cdf, const int symbol) {
+  uint16x8_t cdf_vec0 = vld1q_u16(cdf);
+  uint16x8_t cdf_vec1 = vld1q_u16(cdf + 4);
+  const uint16_t count = cdf[13];
+  const int rate = (4 | (count >> 4)) + 1;
+  const uint16x8_t zero = vdupq_n_u16(0);
+  const uint16x8_t cdf_max_probability =
+      vdupq_n_u16(kCdfMaxProbability + 1 - (1 << rate));
+  const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
+  const int16x8_t negative_rate = vdupq_n_s16(-rate);
+
+  uint16x8_t index = vcombine_u16(vcreate_u16(0x0003000200010000),
+                                  vcreate_u16(0x0007000600050004));
+  uint16x8_t mask = vcltq_u16(index, symbol_vec);
+  uint16x8_t a = vbslq_u16(mask, cdf_max_probability, zero);
+  int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(cdf_vec0, a));
+  uint16x8_t delta = vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
+  cdf_vec0 = vsubq_u16(cdf_vec0, delta);
+  vst1q_u16(cdf, cdf_vec0);
+
+  index = vcombine_u16(vcreate_u16(0x0007000600050004),
+                       vcreate_u16(0x000b000a00090008));
+  mask = vcltq_u16(index, symbol_vec);
+  a = vbslq_u16(mask, cdf_max_probability, zero);
+  diff = vreinterpretq_s16_u16(vsubq_u16(cdf_vec1, a));
+  delta = vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
+  cdf_vec1 = vsubq_u16(cdf_vec1, delta);
+  vst1q_u16(cdf + 4, cdf_vec1);
+
+  cdf[13] = count + static_cast<uint16_t>(count < 32);
+}
+
+void UpdateCdf16(uint16_t* const cdf, const int symbol) {
+  uint16x8_t cdf_vec = vld1q_u16(cdf);
+  const uint16_t count = cdf[16];
+  const int rate = (4 | (count >> 4)) + 1;
+  const uint16x8_t zero = vdupq_n_u16(0);
+  const uint16x8_t cdf_max_probability =
+      vdupq_n_u16(kCdfMaxProbability + 1 - (1 << rate));
+  const uint16x8_t symbol_vec = vdupq_n_u16(symbol);
+  const int16x8_t negative_rate = vdupq_n_s16(-rate);
+
+  uint16x8_t index = vcombine_u16(vcreate_u16(0x0003000200010000),
+                                  vcreate_u16(0x0007000600050004));
+  uint16x8_t mask = vcltq_u16(index, symbol_vec);
+  uint16x8_t a = vbslq_u16(mask, cdf_max_probability, zero);
+  int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(cdf_vec, a));
+  uint16x8_t delta = vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
+  cdf_vec = vsubq_u16(cdf_vec, delta);
+  vst1q_u16(cdf, cdf_vec);
+
+  cdf_vec = vld1q_u16(cdf + 8);
+  index = vcombine_u16(vcreate_u16(0x000b000a00090008),
+                       vcreate_u16(0x000f000e000d000c));
+  mask = vcltq_u16(index, symbol_vec);
+  a = vbslq_u16(mask, cdf_max_probability, zero);
+  diff = vreinterpretq_s16_u16(vsubq_u16(cdf_vec, a));
+  delta = vreinterpretq_u16_s16(vshlq_s16(diff, negative_rate));
+  cdf_vec = vsubq_u16(cdf_vec, delta);
+  vst1q_u16(cdf + 8, cdf_vec);
+
+  cdf[16] = count + static_cast<uint16_t>(count < 32);
+}
+
+#else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
+
+#if LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4
+
+inline __m128i LoadLo8(const void* a) {
+  return _mm_loadl_epi64(static_cast<const __m128i*>(a));
+}
+
+inline __m128i LoadUnaligned16(const void* a) {
+  return _mm_loadu_si128(static_cast<const __m128i*>(a));
+}
+
+inline void StoreLo8(void* a, const __m128i v) {
+  _mm_storel_epi64(static_cast<__m128i*>(a), v);
+}
+
+inline void StoreUnaligned16(void* a, const __m128i v) {
+  _mm_storeu_si128(static_cast<__m128i*>(a), v);
+}
+
+void UpdateCdf5(uint16_t* const cdf, const int symbol) {
+  __m128i cdf_vec = LoadLo8(cdf);
+  const uint16_t count = cdf[5];
+  const int rate = (4 | (count >> 4)) + 1;
+  const __m128i zero = _mm_setzero_si128();
+  const __m128i cdf_max_probability = _mm_shufflelo_epi16(
+      _mm_cvtsi32_si128(kCdfMaxProbability + 1 - (1 << rate)), 0);
+  const __m128i index = _mm_set_epi32(0x0, 0x0, 0x00030002, 0x00010000);
+  const __m128i symbol_vec = _mm_shufflelo_epi16(_mm_cvtsi32_si128(symbol), 0);
+  const __m128i mask = _mm_cmplt_epi16(index, symbol_vec);
+  const __m128i a = _mm_blendv_epi8(zero, cdf_max_probability, mask);
+  const __m128i diff = _mm_sub_epi16(cdf_vec, a);
+  const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
+  cdf_vec = _mm_sub_epi16(cdf_vec, delta);
+  StoreLo8(cdf, cdf_vec);
+  cdf[5] = count + static_cast<uint16_t>(count < 32);
+}
+
+// This version works for |symbol_count| = 7, 8, or 9.
+template <int symbol_count>
+void UpdateCdf7To9(uint16_t* const cdf, const int symbol) {
+  static_assert(symbol_count >= 7 && symbol_count <= 9, "");
+  __m128i cdf_vec = LoadUnaligned16(cdf);
+  const uint16_t count = cdf[symbol_count];
+  const int rate = (4 | (count >> 4)) + 1;
+  const __m128i zero = _mm_setzero_si128();
+  const __m128i cdf_max_probability =
+      _mm_set1_epi16(kCdfMaxProbability + 1 - (1 << rate));
+  const __m128i index =
+      _mm_set_epi32(0x00070006, 0x00050004, 0x00030002, 0x00010000);
+  const __m128i symbol_vec = _mm_set1_epi16(symbol);
+  const __m128i mask = _mm_cmplt_epi16(index, symbol_vec);
+  const __m128i a = _mm_blendv_epi8(zero, cdf_max_probability, mask);
+  const __m128i diff = _mm_sub_epi16(cdf_vec, a);
+  const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
+  cdf_vec = _mm_sub_epi16(cdf_vec, delta);
+  StoreUnaligned16(cdf, cdf_vec);
+  cdf[symbol_count] = count + static_cast<uint16_t>(count < 32);
+}
+
+void UpdateCdf7(uint16_t* const cdf, const int symbol) {
+  UpdateCdf7To9<7>(cdf, symbol);
+}
+
+void UpdateCdf8(uint16_t* const cdf, const int symbol) {
+  UpdateCdf7To9<8>(cdf, symbol);
+}
+
+void UpdateCdf11(uint16_t* const cdf, const int symbol) {
+  __m128i cdf_vec = LoadUnaligned16(cdf + 2);
+  const uint16_t count = cdf[11];
+  cdf[11] = count + static_cast<uint16_t>(count < 32);
+  const int rate = (4 | (count >> 4)) + 1;
+  if (symbol > 1) {
+    cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
+    cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
+    const __m128i zero = _mm_setzero_si128();
+    const __m128i cdf_max_probability =
+        _mm_set1_epi16(kCdfMaxProbability + 1 - (1 << rate));
+    const __m128i index =
+        _mm_set_epi32(0x00090008, 0x00070006, 0x00050004, 0x00030002);
+    const __m128i symbol_vec = _mm_set1_epi16(symbol);
+    const __m128i mask = _mm_cmplt_epi16(index, symbol_vec);
+    const __m128i a = _mm_blendv_epi8(zero, cdf_max_probability, mask);
+    const __m128i diff = _mm_sub_epi16(cdf_vec, a);
+    const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
+    cdf_vec = _mm_sub_epi16(cdf_vec, delta);
+    StoreUnaligned16(cdf + 2, cdf_vec);
+  } else {
+    if (symbol != 0) {
+      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
+      cdf[1] -= cdf[1] >> rate;
+    } else {
+      cdf[0] -= cdf[0] >> rate;
+      cdf[1] -= cdf[1] >> rate;
+    }
+    const __m128i delta = _mm_sra_epi16(cdf_vec, _mm_cvtsi32_si128(rate));
+    cdf_vec = _mm_sub_epi16(cdf_vec, delta);
+    StoreUnaligned16(cdf + 2, cdf_vec);
+  }
+}
+
+void UpdateCdf13(uint16_t* const cdf, const int symbol) {
+  __m128i cdf_vec0 = LoadUnaligned16(cdf);
+  __m128i cdf_vec1 = LoadUnaligned16(cdf + 4);
+  const uint16_t count = cdf[13];
+  const int rate = (4 | (count >> 4)) + 1;
+  const __m128i zero = _mm_setzero_si128();
+  const __m128i cdf_max_probability =
+      _mm_set1_epi16(kCdfMaxProbability + 1 - (1 << rate));
+  const __m128i symbol_vec = _mm_set1_epi16(symbol);
+
+  const __m128i index =
+      _mm_set_epi32(0x00070006, 0x00050004, 0x00030002, 0x00010000);
+  const __m128i mask = _mm_cmplt_epi16(index, symbol_vec);
+  const __m128i a = _mm_blendv_epi8(zero, cdf_max_probability, mask);
+  const __m128i diff = _mm_sub_epi16(cdf_vec0, a);
+  const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
+  cdf_vec0 = _mm_sub_epi16(cdf_vec0, delta);
+  StoreUnaligned16(cdf, cdf_vec0);
+
+  const __m128i index1 =
+      _mm_set_epi32(0x000b000a, 0x00090008, 0x00070006, 0x00050004);
+  const __m128i mask1 = _mm_cmplt_epi16(index1, symbol_vec);
+  const __m128i a1 = _mm_blendv_epi8(zero, cdf_max_probability, mask1);
+  const __m128i diff1 = _mm_sub_epi16(cdf_vec1, a1);
+  const __m128i delta1 = _mm_sra_epi16(diff1, _mm_cvtsi32_si128(rate));
+  cdf_vec1 = _mm_sub_epi16(cdf_vec1, delta1);
+  StoreUnaligned16(cdf + 4, cdf_vec1);
+
+  cdf[13] = count + static_cast<uint16_t>(count < 32);
+}
+
+void UpdateCdf16(uint16_t* const cdf, const int symbol) {
+  __m128i cdf_vec0 = LoadUnaligned16(cdf);
+  const uint16_t count = cdf[16];
+  const int rate = (4 | (count >> 4)) + 1;
+  const __m128i zero = _mm_setzero_si128();
+  const __m128i cdf_max_probability =
+      _mm_set1_epi16(kCdfMaxProbability + 1 - (1 << rate));
+  const __m128i symbol_vec = _mm_set1_epi16(symbol);
+
+  const __m128i index =
+      _mm_set_epi32(0x00070006, 0x00050004, 0x00030002, 0x00010000);
+  const __m128i mask = _mm_cmplt_epi16(index, symbol_vec);
+  const __m128i a = _mm_blendv_epi8(zero, cdf_max_probability, mask);
+  const __m128i diff = _mm_sub_epi16(cdf_vec0, a);
+  const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
+  cdf_vec0 = _mm_sub_epi16(cdf_vec0, delta);
+  StoreUnaligned16(cdf, cdf_vec0);
+
+  __m128i cdf_vec1 = LoadUnaligned16(cdf + 8);
+  const __m128i index1 =
+      _mm_set_epi32(0x000f000e, 0x000d000c, 0x000b000a, 0x00090008);
+  const __m128i mask1 = _mm_cmplt_epi16(index1, symbol_vec);
+  const __m128i a1 = _mm_blendv_epi8(zero, cdf_max_probability, mask1);
+  const __m128i diff1 = _mm_sub_epi16(cdf_vec1, a1);
+  const __m128i delta1 = _mm_sra_epi16(diff1, _mm_cvtsi32_si128(rate));
+  cdf_vec1 = _mm_sub_epi16(cdf_vec1, delta1);
+  StoreUnaligned16(cdf + 8, cdf_vec1);
+
+  cdf[16] = count + static_cast<uint16_t>(count < 32);
+}
+
+#else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4
+
+void UpdateCdf5(uint16_t* const cdf, const int symbol) {
+  UpdateCdf(cdf, 5, symbol);
+}
+
+void UpdateCdf7(uint16_t* const cdf, const int symbol) {
+  UpdateCdf(cdf, 7, symbol);
+}
+
+void UpdateCdf8(uint16_t* const cdf, const int symbol) {
+  UpdateCdf(cdf, 8, symbol);
+}
+
+void UpdateCdf11(uint16_t* const cdf, const int symbol) {
+  UpdateCdf(cdf, 11, symbol);
+}
+
+void UpdateCdf13(uint16_t* const cdf, const int symbol) {
+  UpdateCdf(cdf, 13, symbol);
+}
+
+void UpdateCdf16(uint16_t* const cdf, const int symbol) {
+  UpdateCdf(cdf, 16, symbol);
+}
+
+#endif  // LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4
+#endif  // LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
+
 }  // namespace
 
-namespace libgav1 {
-
-constexpr uint32_t DaalaBitReader::kWindowSize;  // static.
+#if !LIBGAV1_CXX17
+constexpr int DaalaBitReader::kWindowSize;  // static.
+#endif
 
 DaalaBitReader::DaalaBitReader(const uint8_t* data, size_t size,
                                bool allow_update_cdf)
@@ -131,10 +553,21 @@
 
 int64_t DaalaBitReader::ReadLiteral(int num_bits) {
   assert(num_bits <= 32);
+  assert(num_bits > 0);
   uint32_t literal = 0;
-  for (int bit = num_bits - 1; bit >= 0; --bit) {
-    literal |= static_cast<uint32_t>(ReadBit()) << bit;
-  }
+  int bit = num_bits - 1;
+  do {
+    // ARM can combine a shift operation with a constant number of bits with
+    // some other operations, such as the OR operation.
+    // Here is an ARM disassembly example:
+    // orr w1, w0, w1, lsl #1
+    // which left shifts register w1 by 1 bit and OR the shift result with
+    // register w0.
+    // The next 2 lines are equivalent to:
+    // literal |= static_cast<uint32_t>(ReadBit()) << bit;
+    literal <<= 1;
+    literal |= static_cast<uint32_t>(ReadBit());
+  } while (--bit >= 0);
   return literal;
 }
 
@@ -158,7 +591,7 @@
     // 32). So using that information:
     //  count >> 4 is 0 for count from 0 to 15.
     //  count >> 4 is 1 for count from 16 to 31.
-    //  count >> 4 is 2 for count == 31.
+    //  count >> 4 is 2 for count == 32.
     // Now, the equation becomes:
     //  4 + (count >> 4).
     // Since (count >> 4) can only be 0 or 1 or 2, the addition can be replaced
@@ -182,11 +615,33 @@
 template <int symbol_count>
 int DaalaBitReader::ReadSymbol(uint16_t* const cdf) {
   static_assert(symbol_count >= 3 && symbol_count <= 16, "");
-  const int symbol = (symbol_count <= 13)
-                         ? ReadSymbolImpl(cdf, symbol_count)
-                         : ReadSymbolImplBinarySearch(cdf, symbol_count);
+  if (symbol_count == 4) {
+    return ReadSymbol4(cdf);
+  }
+  int symbol;
+  if (symbol_count == 8) {
+    symbol = ReadSymbolImpl8(cdf);
+  } else if (symbol_count <= 13) {
+    symbol = ReadSymbolImpl(cdf, symbol_count);
+  } else {
+    symbol = ReadSymbolImplBinarySearch(cdf, symbol_count);
+  }
   if (allow_update_cdf_) {
-    UpdateCdf(cdf, symbol_count, symbol);
+    if (symbol_count == 5) {
+      UpdateCdf5(cdf, symbol);
+    } else if (symbol_count == 7) {
+      UpdateCdf7(cdf, symbol);
+    } else if (symbol_count == 8) {
+      UpdateCdf8(cdf, symbol);
+    } else if (symbol_count == 11) {
+      UpdateCdf11(cdf, symbol);
+    } else if (symbol_count == 13) {
+      UpdateCdf13(cdf, symbol);
+    } else if (symbol_count == 16) {
+      UpdateCdf16(cdf, symbol);
+    } else {
+      UpdateCdf(cdf, symbol_count, symbol);
+    }
   }
   return symbol;
 }
@@ -238,7 +693,7 @@
   // and |curr| is the scaled cdf value for |symbol|.
   uint32_t prev = values_in_range_;
   uint32_t curr = 0;
-  const uint16_t values_in_range_shifted = values_in_range_ >> 8;
+  const uint32_t values_in_range_shifted = values_in_range_ >> 8;
   do {
     const int mid = DivideBy2(low + high);
     const uint32_t scaled_cdf =
@@ -275,15 +730,279 @@
   return symbol;
 }
 
+// Equivalent to ReadSymbol(cdf, 4), with the ReadSymbolImpl and UpdateCdf
+// calls inlined.
+int DaalaBitReader::ReadSymbol4(uint16_t* const cdf) {
+  assert(cdf[3] == 0);
+  uint32_t curr = values_in_range_;
+  uint32_t prev;
+  const auto symbol_value =
+      static_cast<uint32_t>(window_diff_ >> (kWindowSize - 16));
+  uint32_t delta = kMinimumProbabilityPerSymbol * 3;
+  const uint32_t values_in_range_shifted = values_in_range_ >> 8;
+
+  // Search through the |cdf| array to determine where the scaled cdf value and
+  // |symbol_value| cross over. If allow_update_cdf_ is true, update the |cdf|
+  // array.
+  //
+  // The original code is:
+  //
+  //  int symbol = -1;
+  //  do {
+  //    prev = curr;
+  //    curr =
+  //        ((values_in_range_shifted * (cdf[++symbol] >> kCdfPrecision)) >> 1)
+  //        + delta;
+  //    delta -= kMinimumProbabilityPerSymbol;
+  //  } while (symbol_value < curr);
+  //  if (allow_update_cdf_) {
+  //    UpdateCdf(cdf, 4, symbol);
+  //  }
+  //
+  // The do-while loop is unrolled with four iterations, and the UpdateCdf call
+  // is inlined and merged into the four iterations.
+  int symbol = 0;
+  // Iteration 0.
+  prev = curr;
+  curr =
+      ((values_in_range_shifted * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
+  if (symbol_value >= curr) {
+    // symbol == 0.
+    if (allow_update_cdf_) {
+      // Inlined version of UpdateCdf(cdf, 4, /*symbol=*/0).
+      const uint16_t count = cdf[4];
+      cdf[4] += static_cast<uint16_t>(count < 32);
+      const int rate = (4 | (count >> 4)) + 1;
+#if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
+      // 1. On Motorola Moto G5 Plus (running 32-bit Android 8.1.0), the ARM
+      // NEON code is slower. Consider using the C version if __arm__ is
+      // defined.
+      // 2. The ARM NEON code (compiled for arm64) is slightly slower on
+      // Samsung Galaxy S8+ (SM-G955FD).
+      uint16x4_t cdf_vec = vld1_u16(cdf);
+      const int16x4_t negative_rate = vdup_n_s16(-rate);
+      const uint16x4_t delta = vshl_u16(cdf_vec, negative_rate);
+      cdf_vec = vsub_u16(cdf_vec, delta);
+      vst1_u16(cdf, cdf_vec);
+#elif LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4
+      __m128i cdf_vec = LoadLo8(cdf);
+      const __m128i delta = _mm_sra_epi16(cdf_vec, _mm_cvtsi32_si128(rate));
+      cdf_vec = _mm_sub_epi16(cdf_vec, delta);
+      StoreLo8(cdf, cdf_vec);
+#else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4
+      cdf[0] -= cdf[0] >> rate;
+      cdf[1] -= cdf[1] >> rate;
+      cdf[2] -= cdf[2] >> rate;
+#endif
+    }
+    goto found;
+  }
+  ++symbol;
+  delta -= kMinimumProbabilityPerSymbol;
+  // Iteration 1.
+  prev = curr;
+  curr =
+      ((values_in_range_shifted * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
+  if (symbol_value >= curr) {
+    // symbol == 1.
+    if (allow_update_cdf_) {
+      // Inlined version of UpdateCdf(cdf, 4, /*symbol=*/1).
+      const uint16_t count = cdf[4];
+      cdf[4] += static_cast<uint16_t>(count < 32);
+      const int rate = (4 | (count >> 4)) + 1;
+      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
+      cdf[1] -= cdf[1] >> rate;
+      cdf[2] -= cdf[2] >> rate;
+    }
+    goto found;
+  }
+  ++symbol;
+  delta -= kMinimumProbabilityPerSymbol;
+  // Iteration 2.
+  prev = curr;
+  curr =
+      ((values_in_range_shifted * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
+  if (symbol_value >= curr) {
+    // symbol == 2.
+    if (allow_update_cdf_) {
+      // Inlined version of UpdateCdf(cdf, 4, /*symbol=*/2).
+      const uint16_t count = cdf[4];
+      cdf[4] += static_cast<uint16_t>(count < 32);
+      const int rate = (4 | (count >> 4)) + 1;
+      cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
+      cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
+      cdf[2] -= cdf[2] >> rate;
+    }
+    goto found;
+  }
+  ++symbol;
+  // |delta| is 0 for the last iteration.
+  // Iteration 3.
+  prev = curr;
+  // Since cdf[3] is 0 and |delta| is 0, |curr| is also 0.
+  curr = 0;
+  // symbol == 3.
+  if (allow_update_cdf_) {
+    // Inlined version of UpdateCdf(cdf, 4, /*symbol=*/3).
+    const uint16_t count = cdf[4];
+    cdf[4] += static_cast<uint16_t>(count < 32);
+    const int rate = (4 | (count >> 4)) + 1;
+#if LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
+    // On Motorola Moto G5 Plus (running 32-bit Android 8.1.0), the ARM NEON
+    // code is a tiny bit slower. Consider using the C version if __arm__ is
+    // defined.
+    uint16x4_t cdf_vec = vld1_u16(cdf);
+    const uint16x4_t cdf_max_probability = vdup_n_u16(kCdfMaxProbability);
+    const int16x4_t diff =
+        vreinterpret_s16_u16(vsub_u16(cdf_max_probability, cdf_vec));
+    const int16x4_t negative_rate = vdup_n_s16(-rate);
+    const uint16x4_t delta =
+        vreinterpret_u16_s16(vshl_s16(diff, negative_rate));
+    cdf_vec = vadd_u16(cdf_vec, delta);
+    vst1_u16(cdf, cdf_vec);
+    cdf[3] = 0;
+#elif LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4
+    __m128i cdf_vec = LoadLo8(cdf);
+    const __m128i cdf_max_probability =
+        _mm_shufflelo_epi16(_mm_cvtsi32_si128(kCdfMaxProbability), 0);
+    const __m128i diff = _mm_sub_epi16(cdf_max_probability, cdf_vec);
+    const __m128i delta = _mm_sra_epi16(diff, _mm_cvtsi32_si128(rate));
+    cdf_vec = _mm_add_epi16(cdf_vec, delta);
+    StoreLo8(cdf, cdf_vec);
+    cdf[3] = 0;
+#else  // !LIBGAV1_ENTROPY_DECODER_ENABLE_SSE4
+    cdf[0] += (kCdfMaxProbability - cdf[0]) >> rate;
+    cdf[1] += (kCdfMaxProbability - cdf[1]) >> rate;
+    cdf[2] += (kCdfMaxProbability - cdf[2]) >> rate;
+#endif
+  }
+found:
+  // End of unrolled do-while loop.
+
+  values_in_range_ = prev - curr;
+  window_diff_ -= static_cast<WindowSize>(curr) << (kWindowSize - 16);
+  NormalizeRange();
+  return symbol;
+}
+
+int DaalaBitReader::ReadSymbolImpl8(const uint16_t* const cdf) {
+  assert(cdf[7] == 0);
+  uint32_t curr = values_in_range_;
+  uint32_t prev;
+  const auto symbol_value =
+      static_cast<uint32_t>(window_diff_ >> (kWindowSize - 16));
+  uint32_t delta = kMinimumProbabilityPerSymbol * 7;
+  // Search through the |cdf| array to determine where the scaled cdf value and
+  // |symbol_value| cross over.
+  //
+  // The original code is:
+  //
+  // int symbol = -1;
+  // do {
+  //   prev = curr;
+  //   curr =
+  //       (((values_in_range_ >> 8) * (cdf[++symbol] >> kCdfPrecision)) >> 1)
+  //       + delta;
+  //   delta -= kMinimumProbabilityPerSymbol;
+  // } while (symbol_value < curr);
+  //
+  // The do-while loop is unrolled with eight iterations.
+  int symbol = 0;
+
+#define READ_SYMBOL_ITERATION                                                \
+  prev = curr;                                                               \
+  curr = (((values_in_range_ >> 8) * (cdf[symbol] >> kCdfPrecision)) >> 1) + \
+         delta;                                                              \
+  if (symbol_value >= curr) goto found;                                      \
+  ++symbol;                                                                  \
+  delta -= kMinimumProbabilityPerSymbol
+
+  READ_SYMBOL_ITERATION;  // Iteration 0.
+  READ_SYMBOL_ITERATION;  // Iteration 1.
+  READ_SYMBOL_ITERATION;  // Iteration 2.
+  READ_SYMBOL_ITERATION;  // Iteration 3.
+  READ_SYMBOL_ITERATION;  // Iteration 4.
+  READ_SYMBOL_ITERATION;  // Iteration 5.
+
+  // The last two iterations can be simplified, so they don't use the
+  // READ_SYMBOL_ITERATION macro.
+#undef READ_SYMBOL_ITERATION
+
+  // Iteration 6.
+  prev = curr;
+  curr =
+      (((values_in_range_ >> 8) * (cdf[symbol] >> kCdfPrecision)) >> 1) + delta;
+  if (symbol_value >= curr) goto found;  // symbol == 6.
+  ++symbol;
+  // |delta| is 0 for the last iteration.
+  // Iteration 7.
+  prev = curr;
+  // Since cdf[7] is 0 and |delta| is 0, |curr| is also 0.
+  curr = 0;
+  // symbol == 7.
+found:
+  // End of unrolled do-while loop.
+
+  values_in_range_ = prev - curr;
+  window_diff_ -= static_cast<WindowSize>(curr) << (kWindowSize - 16);
+  NormalizeRange();
+  return symbol;
+}
+
 void DaalaBitReader::PopulateBits() {
-  int shift = kWindowSize - 9 - (bits_ + 15);
-  for (; shift >= 0 && data_index_ < size_; shift -= 8) {
-    window_diff_ ^= static_cast<WindowSize>(data_[data_index_++]) << shift;
-    bits_ += 8;
+#if defined(__aarch64__)
+  // Fast path: read eight bytes and add the first six bytes to window_diff_.
+  // This fast path makes the following assumptions.
+  // 1. We assume that unaligned load of uint64_t is fast.
+  // 2. When there are enough bytes in data_, the for loop below reads 6 or 7
+  //    bytes depending on the value of bits_. This fast path always reads 6
+  //    bytes, which results in more calls to PopulateBits(). We assume that
+  //    making more calls to a faster PopulateBits() is overall a win.
+  // NOTE: Although this fast path could also be used on x86_64, it hurts
+  // performance (measured on Lenovo ThinkStation P920 running Linux). (The
+  // reason is still unknown.) Therefore this fast path is only used on arm64.
+  static_assert(kWindowSize == 64, "");
+  if (size_ - data_index_ >= 8) {
+    uint64_t value;
+    // arm64 supports unaligned loads, so this memcpy call is compiled to a
+    // single ldr instruction.
+    memcpy(&value, &data_[data_index_], 8);
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+    value = __builtin_bswap64(value);
+#endif
+    value &= 0xffffffffffff0000;
+    window_diff_ ^= static_cast<WindowSize>(value) >> (bits_ + 16);
+    data_index_ += 6;
+    bits_ += 6 * 8;
+    return;
   }
-  if (data_index_ >= size_) {
-    bits_ = kLargeBitCount;
+#endif
+
+  size_t data_index = data_index_;
+  int bits = bits_;
+  WindowSize window_diff = window_diff_;
+
+  int shift = kWindowSize - 9 - (bits + 15);
+  // The fast path above, if compiled, would cause clang 8.0.7 to vectorize
+  // this loop. Since -15 <= bits_ <= -1, this loop has at most 6 or 7
+  // iterations when WindowSize is 64 bits. So it is not profitable to
+  // vectorize this loop. Note that clang 8.0.7 does not vectorize this loop if
+  // the fast path above is not compiled.
+
+#ifdef __clang__
+#pragma clang loop vectorize(disable) interleave(disable)
+#endif
+  for (; shift >= 0 && data_index < size_; shift -= 8) {
+    window_diff ^= static_cast<WindowSize>(data_[data_index++]) << shift;
+    bits += 8;
   }
+  if (data_index >= size_) {
+    bits = kLargeBitCount;
+  }
+
+  data_index_ = data_index;
+  bits_ = bits;
+  window_diff_ = window_diff;
 }
 
 void DaalaBitReader::NormalizeRange() {
@@ -300,8 +1019,10 @@
 template int DaalaBitReader::ReadSymbol<5>(uint16_t* cdf);
 template int DaalaBitReader::ReadSymbol<7>(uint16_t* cdf);
 template int DaalaBitReader::ReadSymbol<8>(uint16_t* cdf);
+template int DaalaBitReader::ReadSymbol<10>(uint16_t* cdf);
 template int DaalaBitReader::ReadSymbol<11>(uint16_t* cdf);
 template int DaalaBitReader::ReadSymbol<13>(uint16_t* cdf);
+template int DaalaBitReader::ReadSymbol<14>(uint16_t* cdf);
 template int DaalaBitReader::ReadSymbol<16>(uint16_t* cdf);
 
 }  // namespace libgav1
diff --git a/libgav1/src/utils/entropy_decoder.h b/libgav1/src/utils/entropy_decoder.h
index 5867a61..75c633b 100644
--- a/libgav1/src/utils/entropy_decoder.h
+++ b/libgav1/src/utils/entropy_decoder.h
@@ -21,6 +21,7 @@
 #include <cstdint>
 
 #include "src/utils/bit_reader.h"
+#include "src/utils/compiler_attributes.h"
 
 namespace libgav1 {
 
@@ -33,7 +34,7 @@
   DaalaBitReader(DaalaBitReader&& rhs) noexcept;
   DaalaBitReader& operator=(DaalaBitReader&& rhs) noexcept;
 
-  int ReadBit() override;
+  int ReadBit() final;
   int64_t ReadLiteral(int num_bits) override;
   // ReadSymbol() calls for which the |symbol_count| is only known at runtime
   // will use this variant.
@@ -49,9 +50,12 @@
   int ReadSymbol(uint16_t* cdf);
 
  private:
-  using WindowSize = uint32_t;
-  static constexpr uint32_t kWindowSize =
-      static_cast<uint32_t>(sizeof(WindowSize)) * 8;
+  // WindowSize must be an unsigned integer type with at least 32 bits. Use the
+  // largest type with fast arithmetic. size_t should meet these requirements.
+  static_assert(sizeof(size_t) == sizeof(void*), "");
+  using WindowSize = size_t;
+  static constexpr int kWindowSize = static_cast<int>(sizeof(WindowSize)) * 8;
+  static_assert(kWindowSize >= 32, "");
 
   // Reads a symbol using the |cdf| table which contains the probabilities of
   // each symbol. On a high level, this function does the following:
@@ -61,33 +65,50 @@
   //   3) That index is the symbol that has been decoded.
   //   4) Update |window_diff_| and |values_in_range_| based on the symbol that
   //   has been decoded.
-  int ReadSymbolImpl(const uint16_t* cdf, int symbol_count);
+  inline int ReadSymbolImpl(const uint16_t* cdf, int symbol_count);
   // Similar to ReadSymbolImpl but it uses binary search to perform step 2 in
   // the comment above. As of now, this function is called when |symbol_count|
-  // is greater than or equal to 8.
-  int ReadSymbolImplBinarySearch(const uint16_t* cdf, int symbol_count);
+  // is greater than or equal to 14.
+  inline int ReadSymbolImplBinarySearch(const uint16_t* cdf, int symbol_count);
   // Specialized implementation of ReadSymbolImpl based on the fact that
   // symbol_count == 2.
-  int ReadSymbolImpl(const uint16_t* cdf);
-  void PopulateBits();
+  inline int ReadSymbolImpl(const uint16_t* cdf);
+  // ReadSymbolN is a specialization of ReadSymbol for symbol_count == N.
+  LIBGAV1_ALWAYS_INLINE int ReadSymbol4(uint16_t* cdf);
+  // ReadSymbolImplN is a specialization of ReadSymbolImpl for
+  // symbol_count == N.
+  LIBGAV1_ALWAYS_INLINE int ReadSymbolImpl8(const uint16_t* cdf);
+  inline void PopulateBits();
   // Normalizes the range so that 32768 <= |values_in_range_| < 65536. Also
   // calls PopulateBits() if necessary.
-  void NormalizeRange();
+  inline void NormalizeRange();
 
-  const uint8_t* data_;
+  const uint8_t* const data_;
   const size_t size_;
   size_t data_index_;
   const bool allow_update_cdf_;
   // Number of bits of data in the current value.
   int bits_;
-  // Number of values in the current range.
-  uint16_t values_in_range_;
+  // Number of values in the current range. Declared as uint32_t for better
+  // performance but only the lower 16 bits are used.
+  uint32_t values_in_range_;
   // The difference between the high end of the current range and the coded
-  // value minus 1. The 16 least significant bits of this variable is used to
+  // value minus 1. The 16 most significant bits of this variable is used to
   // decode the next symbol. It is filled in whenever |bits_| is less than 0.
   WindowSize window_diff_;
 };
 
+extern template int DaalaBitReader::ReadSymbol<3>(uint16_t* cdf);
+extern template int DaalaBitReader::ReadSymbol<4>(uint16_t* cdf);
+extern template int DaalaBitReader::ReadSymbol<5>(uint16_t* cdf);
+extern template int DaalaBitReader::ReadSymbol<7>(uint16_t* cdf);
+extern template int DaalaBitReader::ReadSymbol<8>(uint16_t* cdf);
+extern template int DaalaBitReader::ReadSymbol<10>(uint16_t* cdf);
+extern template int DaalaBitReader::ReadSymbol<11>(uint16_t* cdf);
+extern template int DaalaBitReader::ReadSymbol<13>(uint16_t* cdf);
+extern template int DaalaBitReader::ReadSymbol<14>(uint16_t* cdf);
+extern template int DaalaBitReader::ReadSymbol<16>(uint16_t* cdf);
+
 }  // namespace libgav1
 
 #endif  // LIBGAV1_SRC_UTILS_ENTROPY_DECODER_H_
diff --git a/libgav1/src/utils/libgav1_utils.cmake b/libgav1/src/utils/libgav1_utils.cmake
new file mode 100644
index 0000000..8b6ec4b
--- /dev/null
+++ b/libgav1/src/utils/libgav1_utils.cmake
@@ -0,0 +1,72 @@
+# Copyright 2019 The libgav1 Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(LIBGAV1_UTILS_LIBGAV1_UTILS_CMAKE_)
+  return()
+endif() # LIBGAV1_UTILS_LIBGAV1_UTILS_CMAKE_
+set(LIBGAV1_UTILS_LIBGAV1_UTILS_CMAKE_ 1)
+
+list(APPEND libgav1_utils_sources
+            "${libgav1_source}/utils/array_2d.h"
+            "${libgav1_source}/utils/bit_mask_set.h"
+            "${libgav1_source}/utils/bit_reader.cc"
+            "${libgav1_source}/utils/bit_reader.h"
+            "${libgav1_source}/utils/block_parameters_holder.cc"
+            "${libgav1_source}/utils/block_parameters_holder.h"
+            "${libgav1_source}/utils/blocking_counter.h"
+            "${libgav1_source}/utils/common.h"
+            "${libgav1_source}/utils/compiler_attributes.h"
+            "${libgav1_source}/utils/constants.cc"
+            "${libgav1_source}/utils/constants.h"
+            "${libgav1_source}/utils/cpu.cc"
+            "${libgav1_source}/utils/cpu.h"
+            "${libgav1_source}/utils/dynamic_buffer.h"
+            "${libgav1_source}/utils/entropy_decoder.cc"
+            "${libgav1_source}/utils/entropy_decoder.h"
+            "${libgav1_source}/utils/executor.cc"
+            "${libgav1_source}/utils/executor.h"
+            "${libgav1_source}/utils/logging.cc"
+            "${libgav1_source}/utils/logging.h"
+            "${libgav1_source}/utils/memory.h"
+            "${libgav1_source}/utils/parameter_tree.cc"
+            "${libgav1_source}/utils/parameter_tree.h"
+            "${libgav1_source}/utils/queue.h"
+            "${libgav1_source}/utils/raw_bit_reader.cc"
+            "${libgav1_source}/utils/raw_bit_reader.h"
+            "${libgav1_source}/utils/reference_info.h"
+            "${libgav1_source}/utils/segmentation.cc"
+            "${libgav1_source}/utils/segmentation.h"
+            "${libgav1_source}/utils/segmentation_map.cc"
+            "${libgav1_source}/utils/segmentation_map.h"
+            "${libgav1_source}/utils/stack.h"
+            "${libgav1_source}/utils/threadpool.cc"
+            "${libgav1_source}/utils/threadpool.h"
+            "${libgav1_source}/utils/types.h"
+            "${libgav1_source}/utils/unbounded_queue.h"
+            "${libgav1_source}/utils/vector.h")
+
+macro(libgav1_add_utils_targets)
+  libgav1_add_library(NAME
+                      libgav1_utils
+                      TYPE
+                      OBJECT
+                      SOURCES
+                      ${libgav1_utils_sources}
+                      DEFINES
+                      ${libgav1_defines}
+                      INCLUDES
+                      ${libgav1_include_paths}
+                      ${libgav1_gtest_include_paths})
+
+endmacro()
diff --git a/libgav1/src/utils/logging.cc b/libgav1/src/utils/logging.cc
index 26e3e15..9a43c22 100644
--- a/libgav1/src/utils/logging.cc
+++ b/libgav1/src/utils/logging.cc
@@ -56,7 +56,7 @@
   va_end(ap);
   fprintf(stderr, "\n");
 }
-#else   // !LIBGAV1_ENABLE_LOGGING
+#else  // !LIBGAV1_ENABLE_LOGGING
 void Log(LogSeverity /*severity*/, const char* /*file*/, int /*line*/,
          const char* /*format*/, ...) {}
 #endif  // LIBGAV1_ENABLE_LOGGING
diff --git a/libgav1/src/utils/logging.h b/libgav1/src/utils/logging.h
index 378d369..48928db 100644
--- a/libgav1/src/utils/logging.h
+++ b/libgav1/src/utils/logging.h
@@ -22,7 +22,7 @@
 #include "src/utils/compiler_attributes.h"
 
 #if !defined(LIBGAV1_ENABLE_LOGGING)
-#if defined(NDEBUG)
+#if defined(NDEBUG) || defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION)
 #define LIBGAV1_ENABLE_LOGGING 0
 #else
 #define LIBGAV1_ENABLE_LOGGING 1
diff --git a/libgav1/src/utils/memory.h b/libgav1/src/utils/memory.h
index f9d921a..80c1d8c 100644
--- a/libgav1/src/utils/memory.h
+++ b/libgav1/src/utils/memory.h
@@ -70,7 +70,7 @@
   // more convenient to use memalign(). Unlike glibc, Android does not consider
   // memalign() an obsolete function.
   return memalign(alignment, size);
-#else   // !defined(__ANDROID__)
+#else  // !defined(__ANDROID__)
   void* ptr = nullptr;
   // posix_memalign requires that the requested alignment be at least
   // sizeof(void*). In this case, fall back on malloc which should return
@@ -139,10 +139,48 @@
   // Class-specific non-throwing allocation functions
   static void* operator new(size_t size, const std::nothrow_t& tag) noexcept {
     if (size > 0x40000000) return nullptr;
+    return ::operator new(size, tag);
+  }
+  static void* operator new[](size_t size, const std::nothrow_t& tag) noexcept {
+    if (size > 0x40000000) return nullptr;
+    return ::operator new[](size, tag);
+  }
+
+  // Class-specific deallocation functions.
+  static void operator delete(void* ptr) noexcept { ::operator delete(ptr); }
+  static void operator delete[](void* ptr) noexcept {
+    ::operator delete[](ptr);
+  }
+
+  // Only called if new (std::nothrow) is used and the constructor throws an
+  // exception.
+  static void operator delete(void* ptr, const std::nothrow_t& tag) noexcept {
+    ::operator delete(ptr, tag);
+  }
+  // Only called if new[] (std::nothrow) is used and the constructor throws an
+  // exception.
+  static void operator delete[](void* ptr, const std::nothrow_t& tag) noexcept {
+    ::operator delete[](ptr, tag);
+  }
+};
+
+// A variant of Allocable that forces allocations to be aligned to
+// kMaxAlignment bytes. This is intended for use with classes that use
+// alignas() with this value. C++17 aligned new/delete are used if available,
+// otherwise we use AlignedAlloc/Free.
+struct MaxAlignedAllocable {
+  // Class-specific allocation functions.
+  static void* operator new(size_t size) = delete;
+  static void* operator new[](size_t size) = delete;
+
+  // Class-specific non-throwing allocation functions
+  static void* operator new(size_t size, const std::nothrow_t& tag) noexcept {
+    if (size > 0x40000000) return nullptr;
 #ifdef __cpp_aligned_new
     return ::operator new(size, std::align_val_t(kMaxAlignment), tag);
 #else
-    return ::operator new(size, tag);
+    static_cast<void>(tag);
+    return AlignedAlloc(kMaxAlignment, size);
 #endif
   }
   static void* operator new[](size_t size, const std::nothrow_t& tag) noexcept {
@@ -150,7 +188,8 @@
 #ifdef __cpp_aligned_new
     return ::operator new[](size, std::align_val_t(kMaxAlignment), tag);
 #else
-    return ::operator new[](size, tag);
+    static_cast<void>(tag);
+    return AlignedAlloc(kMaxAlignment, size);
 #endif
   }
 
@@ -159,14 +198,14 @@
 #ifdef __cpp_aligned_new
     ::operator delete(ptr, std::align_val_t(kMaxAlignment));
 #else
-    ::operator delete(ptr);
+    AlignedFree(ptr);
 #endif
   }
   static void operator delete[](void* ptr) noexcept {
 #ifdef __cpp_aligned_new
     ::operator delete[](ptr, std::align_val_t(kMaxAlignment));
 #else
-    ::operator delete[](ptr);
+    AlignedFree(ptr);
 #endif
   }
 
@@ -176,7 +215,8 @@
 #ifdef __cpp_aligned_new
     ::operator delete(ptr, std::align_val_t(kMaxAlignment), tag);
 #else
-    ::operator delete(ptr, tag);
+    static_cast<void>(tag);
+    AlignedFree(ptr);
 #endif
   }
   // Only called if new[] (std::nothrow) is used and the constructor throws an
@@ -185,7 +225,8 @@
 #ifdef __cpp_aligned_new
     ::operator delete[](ptr, std::align_val_t(kMaxAlignment), tag);
 #else
-    ::operator delete[](ptr, tag);
+    static_cast<void>(tag);
+    AlignedFree(ptr);
 #endif
   }
 };
diff --git a/libgav1/src/utils/queue.h b/libgav1/src/utils/queue.h
index 41ecda9..cffb9ca 100644
--- a/libgav1/src/utils/queue.h
+++ b/libgav1/src/utils/queue.h
@@ -26,8 +26,7 @@
 
 namespace libgav1 {
 
-// A FIFO queue of a fixed capacity. The elements are copied, so the element
-// type T should be small.
+// A FIFO queue of a fixed capacity.
 //
 // WARNING: No error checking is performed.
 template <typename T>
@@ -42,21 +41,43 @@
 
   // Pushes the element |value| to the end of the queue. It is an error to call
   // Push() when the queue is full.
-  void Push(T value) {
+  void Push(T&& value) {
     assert(size_ < capacity_);
-    elements_[back_++] = value;
-    if (back_ == capacity_) back_ = 0;
+    elements_[end_++] = std::move(value);
+    if (end_ == capacity_) end_ = 0;
     ++size_;
   }
 
-  // Returns the element at the front of the queue and removes it from the
-  // queue. It is an error to call Pop() when the queue is empty.
-  T Pop() {
+  // Removes the element at the front of the queue. It is an error to call Pop()
+  // when the queue is empty.
+  void Pop() {
     assert(size_ != 0);
-    const T front_element = elements_[front_++];
-    if (front_ == capacity_) front_ = 0;
+    const T element = std::move(elements_[begin_++]);
+    static_cast<void>(element);
+    if (begin_ == capacity_) begin_ = 0;
     --size_;
-    return front_element;
+  }
+
+  // Returns a reference to the element at the front of the queue. It is an
+  // error to call Front() when the queue is empty.
+  T& Front() {
+    assert(size_ != 0);
+    return elements_[begin_];
+  }
+
+  // Returns a reference to the element at the back of the queue. It is an error
+  // to call Back() when the queue is empty.
+  T& Back() {
+    assert(size_ != 0);
+    const size_t back = ((end_ == 0) ? capacity_ : end_) - 1;
+    return elements_[back];
+  }
+
+  // Clears the queue.
+  void Clear() {
+    while (!Empty()) {
+      Pop();
+    }
   }
 
   // Returns true if the queue is empty.
@@ -73,9 +94,9 @@
   std::unique_ptr<T[]> elements_;
   size_t capacity_ = 0;
   // The index of the element to be removed by Pop().
-  size_t front_ = 0;
+  size_t begin_ = 0;
   // The index where the new element is inserted by Push().
-  size_t back_ = 0;
+  size_t end_ = 0;
   size_t size_ = 0;
 };
 
diff --git a/libgav1/src/utils/raw_bit_reader.cc b/libgav1/src/utils/raw_bit_reader.cc
index 271d1b6..15e980d 100644
--- a/libgav1/src/utils/raw_bit_reader.cc
+++ b/libgav1/src/utils/raw_bit_reader.cc
@@ -64,13 +64,22 @@
 int64_t RawBitReader::ReadLiteral(int num_bits) {
   assert(num_bits <= 32);
   if (!CanReadLiteral(num_bits)) return -1;
-  uint32_t value = 0;
-  // We can now call ReadBitImpl() since we've made sure that there are enough
-  // bits to be read.
-  for (int i = num_bits - 1; i >= 0; --i) {
-    value |= static_cast<uint32_t>(ReadBitImpl()) << i;
-  }
-  return value;
+  assert(num_bits > 0);
+  uint32_t literal = 0;
+  int bit = num_bits - 1;
+  do {
+    // ARM can combine a shift operation with a constant number of bits with
+    // some other operations, such as the OR operation.
+    // Here is an ARM disassembly example:
+    // orr w1, w0, w1, lsl #1
+    // which left shifts register w1 by 1 bit and OR the shift result with
+    // register w0.
+    // The next 2 lines are equivalent to:
+    // literal |= static_cast<uint32_t>(ReadBitImpl()) << bit;
+    literal <<= 1;
+    literal |= static_cast<uint32_t>(ReadBitImpl());
+  } while (--bit >= 0);
+  return literal;
 }
 
 bool RawBitReader::ReadInverseSignedLiteral(int num_bits, int* const value) {
@@ -155,12 +164,18 @@
       return false;
     }
   }
-  const int literal = static_cast<int>(ReadLiteral(leading_zeros));
-  if (literal == -1) {
-    LIBGAV1_DLOG(ERROR, "Not enough bits to read uvlc value.");
-    return false;
+  int literal;
+  if (leading_zeros != 0) {
+    literal = static_cast<int>(ReadLiteral(leading_zeros));
+    if (literal == -1) {
+      LIBGAV1_DLOG(ERROR, "Not enough bits to read uvlc value.");
+      return false;
+    }
+    literal += (1U << leading_zeros) - 1;
+  } else {
+    literal = 0;
   }
-  *value = literal + (1U << leading_zeros) - 1;
+  *value = literal;
   return true;
 }
 
diff --git a/libgav1/src/utils/reference_info.h b/libgav1/src/utils/reference_info.h
new file mode 100644
index 0000000..a660791
--- /dev/null
+++ b/libgav1/src/utils/reference_info.h
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2020 The libgav1 Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_SRC_UTILS_REFERENCE_INFO_H_
+#define LIBGAV1_SRC_UTILS_REFERENCE_INFO_H_
+
+#include <array>
+#include <cstdint>
+
+#include "src/utils/array_2d.h"
+#include "src/utils/constants.h"
+#include "src/utils/types.h"
+
+namespace libgav1 {
+
+// This struct collects some members related to reference frames in one place to
+// make it easier to pass them as parameters to some dsp functions.
+struct ReferenceInfo {
+  // Initialize |motion_field_reference_frame| so that
+  // Tile::StoreMotionFieldMvsIntoCurrentFrame() can skip some updates when
+  // the updates are the same as the initialized value.
+  // Set to kReferenceFrameIntra instead of kReferenceFrameNone to simplify
+  // branch conditions in motion field projection.
+  // The following memory initialization of contiguous memory is very fast. It
+  // is not recommended to make the initialization multi-threaded, unless the
+  // memory which needs to be initialized in each thread is still contiguous.
+  LIBGAV1_MUST_USE_RESULT bool Reset(int rows, int columns) {
+    return motion_field_reference_frame.Reset(rows, columns,
+                                              /*zero_initialize=*/true) &&
+           motion_field_mv.Reset(
+               rows, columns,
+#if LIBGAV1_MSAN
+               // It is set in Tile::StoreMotionFieldMvsIntoCurrentFrame() only
+               // for qualified blocks. In MotionFieldProjectionKernel() dsp
+               // optimizations, it is read no matter it was set or not.
+               /*zero_initialize=*/true
+#else
+               /*zero_initialize=*/false
+#endif
+           );
+  }
+
+  // All members are used by inter frames only.
+  // For intra frames, they are not initialized.
+
+  std::array<uint8_t, kNumReferenceFrameTypes> order_hint;
+
+  // An example when |relative_distance_from| does not equal
+  // -|relative_distance_to|:
+  // |relative_distance_from| = GetRelativeDistance(7, 71, 25) = -64
+  // -|relative_distance_to| = -GetRelativeDistance(71, 7, 25) = 64
+  // This is why we need both |relative_distance_from| and
+  // |relative_distance_to|.
+  // |relative_distance_from|: Relative distances from reference frames to this
+  // frame.
+  std::array<int8_t, kNumReferenceFrameTypes> relative_distance_from;
+  // |relative_distance_to|: Relative distances to reference frames.
+  std::array<int8_t, kNumReferenceFrameTypes> relative_distance_to;
+
+  // Skip motion field projection of specific types of frames if their
+  // |relative_distance_to| is negative or too large.
+  std::array<bool, kNumReferenceFrameTypes> skip_references;
+  // Lookup table to get motion field projection division multiplier of specific
+  // types of frames. Derived from kProjectionMvDivisionLookup.
+  std::array<int16_t, kNumReferenceFrameTypes> projection_divisions;
+
+  // The current frame's |motion_field_reference_frame| and |motion_field_mv_|
+  // are guaranteed to be allocated only when refresh_frame_flags is not 0.
+  // Array of size (rows4x4 / 2) x (columns4x4 / 2). Entry at i, j corresponds
+  // to MfRefFrames[i * 2 + 1][j * 2 + 1] in the spec.
+  Array2D<ReferenceFrameType> motion_field_reference_frame;
+  // Array of size (rows4x4 / 2) x (columns4x4 / 2). Entry at i, j corresponds
+  // to MfMvs[i * 2 + 1][j * 2 + 1] in the spec.
+  Array2D<MotionVector> motion_field_mv;
+};
+
+}  // namespace libgav1
+
+#endif  // LIBGAV1_SRC_UTILS_REFERENCE_INFO_H_
diff --git a/libgav1/src/utils/segmentation.h b/libgav1/src/utils/segmentation.h
index 0467c85..67ff74c 100644
--- a/libgav1/src/utils/segmentation.h
+++ b/libgav1/src/utils/segmentation.h
@@ -20,60 +20,13 @@
 #include <cstdint>
 
 #include "src/utils/constants.h"
+#include "src/utils/types.h"
 
 namespace libgav1 {
 
-// The corresponding segment feature constants in the AV1 spec are named
-// SEG_LVL_xxx.
-enum SegmentFeature : uint8_t {
-  kSegmentFeatureQuantizer,
-  kSegmentFeatureLoopFilterYVertical,
-  kSegmentFeatureLoopFilterYHorizontal,
-  kSegmentFeatureLoopFilterU,
-  kSegmentFeatureLoopFilterV,
-  kSegmentFeatureReferenceFrame,
-  kSegmentFeatureSkip,
-  kSegmentFeatureGlobalMv,
-  kSegmentFeatureMax
-};
-
 extern const int8_t kSegmentationFeatureBits[kSegmentFeatureMax];
 extern const int kSegmentationFeatureMaxValues[kSegmentFeatureMax];
 
-struct Segmentation {
-  // 5.11.14.
-  // Returns true if the feature is enabled in the segment.
-  bool FeatureActive(int segment_id, SegmentFeature feature) const {
-    return enabled && segment_id < kMaxSegments &&
-           feature_enabled[segment_id][feature];
-  }
-
-  // Returns true if the feature is signed.
-  static bool FeatureSigned(SegmentFeature feature) {
-    // Only the first five segment features are signed, so this comparison
-    // suffices.
-    return feature <= kSegmentFeatureLoopFilterV;
-  }
-
-  bool enabled;
-  bool update_map;
-  bool update_data;
-  bool temporal_update;
-  // True if the segment id will be read before the skip syntax element. False
-  // if the skip syntax element will be read first.
-  bool segment_id_pre_skip;
-  // The highest numbered segment id that has some enabled feature. Used as
-  // the upper bound for decoding segment ids.
-  int8_t last_active_segment_id;
-
-  bool feature_enabled[kMaxSegments][kSegmentFeatureMax];
-  int16_t feature_data[kMaxSegments][kSegmentFeatureMax];
-  bool lossless[kMaxSegments];
-  // Cached values of get_qindex(1, segmentId), to be consumed by
-  // Tile::ReadTransformType(). The values are in the range [0, 255].
-  uint8_t qindex[kMaxSegments];
-};
-
 }  // namespace libgav1
 
 #endif  // LIBGAV1_SRC_UTILS_SEGMENTATION_H_
diff --git a/libgav1/src/utils/threadpool.cc b/libgav1/src/utils/threadpool.cc
index 6f9d2f5..8c8f4fe 100644
--- a/libgav1/src/utils/threadpool.cc
+++ b/libgav1/src/utils/threadpool.cc
@@ -24,6 +24,7 @@
 #include <sys/types.h>
 #include <unistd.h>
 #endif
+#include <algorithm>
 #include <cassert>
 #include <cinttypes>
 #include <cstddef>
@@ -132,8 +133,14 @@
   void Join();
 
  private:
+#if defined(_MSC_VER)
+  static unsigned int __stdcall ThreadBody(void* arg);
+#else
   static void* ThreadBody(void* arg);
+#endif
+
   void SetupName();
+  void Run();
 
   ThreadPool* pool_;
 #if defined(_MSC_VER)
@@ -153,13 +160,8 @@
   // created using CreateThread calls the CRT, the CRT may terminate the
   // process in low-memory conditions."
   uintptr_t handle = _beginthreadex(
-      /*security=*/nullptr, /*stack_size=*/0,
-      static_cast<unsigned int(__stdcall*)(void*)>(
-          [](void* arg) -> unsigned int {
-            ThreadBody(arg);
-            return 0;
-          }),
-      this, /*initflag=*/CREATE_SUSPENDED, /*thrdaddr=*/nullptr);
+      /*security=*/nullptr, /*stack_size=*/0, ThreadBody, this,
+      /*initflag=*/CREATE_SUSPENDED, /*thrdaddr=*/nullptr);
   if (handle == 0) return false;
   handle_ = reinterpret_cast<HANDLE>(handle);
   ResumeThread(handle_);
@@ -171,6 +173,12 @@
   CloseHandle(handle_);
 }
 
+unsigned int ThreadPool::WorkerThread::ThreadBody(void* arg) {
+  auto* thread = static_cast<WorkerThread*>(arg);
+  thread->Run();
+  return 0;
+}
+
 void ThreadPool::WorkerThread::SetupName() {
   // Not currently supported on Windows.
 }
@@ -183,6 +191,12 @@
 
 void ThreadPool::WorkerThread::Join() { pthread_join(thread_, nullptr); }
 
+void* ThreadPool::WorkerThread::ThreadBody(void* arg) {
+  auto* thread = static_cast<WorkerThread*>(arg);
+  thread->Run();
+  return nullptr;
+}
+
 void ThreadPool::WorkerThread::SetupName() {
   if (pool_->name_prefix_[0] != '\0') {
 #if defined(__APPLE__)
@@ -215,11 +229,9 @@
 
 #endif  // defined(_MSC_VER)
 
-void* ThreadPool::WorkerThread::ThreadBody(void* arg) {
-  auto* thread = static_cast<WorkerThread*>(arg);
-  thread->SetupName();
-  thread->pool_->WorkerFunction();
-  return nullptr;
+void ThreadPool::WorkerThread::Run() {
+  SetupName();
+  pool_->WorkerFunction();
 }
 
 bool ThreadPool::StartWorkers() {
diff --git a/libgav1/src/utils/threadpool.h b/libgav1/src/utils/threadpool.h
index 79e7d12..fac875e 100644
--- a/libgav1/src/utils/threadpool.h
+++ b/libgav1/src/utils/threadpool.h
@@ -20,17 +20,28 @@
 #include <functional>
 #include <memory>
 
-#if defined(__ANDROID__)
-#include <condition_variable>  // NOLINT (unapproved c++11 header)
-#include <mutex>               // NOLINT (unapproved c++11 header)
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#endif
+
+#if !defined(LIBGAV1_THREADPOOL_USE_STD_MUTEX)
+#if defined(__ANDROID__) || (defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE)
 #define LIBGAV1_THREADPOOL_USE_STD_MUTEX 1
 #else
-// absl::Mutex & absl::CondVar are significantly faster than the pthread
-// variants on platforms other than Android.
-#include "third_party/absl/base/thread_annotations.h"
-#include "third_party/absl/synchronization/mutex.h"
 #define LIBGAV1_THREADPOOL_USE_STD_MUTEX 0
 #endif
+#endif
+
+#if LIBGAV1_THREADPOOL_USE_STD_MUTEX
+#include <condition_variable>  // NOLINT (unapproved c++11 header)
+#include <mutex>               // NOLINT (unapproved c++11 header)
+#else
+// absl::Mutex & absl::CondVar are significantly faster than the pthread
+// variants on platforms other than Android. iOS may deadlock on Shutdown()
+// using absl, see b/142251739.
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#endif
 
 #include "src/utils/compiler_attributes.h"
 #include "src/utils/executor.h"
diff --git a/libgav1/src/utils/types.h b/libgav1/src/utils/types.h
index af22fd8..c0ac76c 100644
--- a/libgav1/src/utils/types.h
+++ b/libgav1/src/utils/types.h
@@ -17,6 +17,7 @@
 #ifndef LIBGAV1_SRC_UTILS_TYPES_H_
 #define LIBGAV1_SRC_UTILS_TYPES_H_
 
+#include <array>
 #include <cstdint>
 #include <memory>
 
@@ -27,25 +28,54 @@
 namespace libgav1 {
 
 struct MotionVector : public Allocable {
-  static const int kRow = 0;
-  static const int kColumn = 1;
+  static constexpr int kRow = 0;
+  static constexpr int kColumn = 1;
 
-  bool operator==(const MotionVector& rhs) {
-    return mv[0] == rhs.mv[0] && mv[1] == rhs.mv[1];
+  MotionVector() = default;
+  MotionVector(const MotionVector& mv) = default;
+
+  MotionVector& operator=(const MotionVector& rhs) {
+    mv32 = rhs.mv32;
+    return *this;
   }
 
-  int mv[2];
+  bool operator==(const MotionVector& rhs) const { return mv32 == rhs.mv32; }
+
+  union {
+    // Motion vectors will always fit in int16_t and using int16_t here instead
+    // of int saves significant memory since some of the frame sized structures
+    // store motion vectors.
+    int16_t mv[2];
+    // A uint32_t view into the |mv| array. Useful for cases where both the
+    // motion vectors have to be copied or compared with a single 32 bit
+    // instruction.
+    uint32_t mv32;
+  };
 };
 
-struct CandidateMotionVector {
+union CompoundMotionVector {
+  CompoundMotionVector() = default;
+  CompoundMotionVector(const CompoundMotionVector& mv) = default;
+
+  CompoundMotionVector& operator=(const CompoundMotionVector& rhs) {
+    mv64 = rhs.mv64;
+    return *this;
+  }
+
+  bool operator==(const CompoundMotionVector& rhs) const {
+    return mv64 == rhs.mv64;
+  }
+
   MotionVector mv[2];
-  int weight;
+  // A uint64_t view into the |mv| array. Useful for cases where all the motion
+  // vectors have to be copied or compared with a single 64 bit instruction.
+  uint64_t mv64;
 };
 
-// Stores the motion vector used for motion field estimation.
-struct TemporalMotionVector : public Allocable {
-  MotionVector mv;
-  int reference_offset;
+// Stores the motion information used for motion field estimation.
+struct TemporalMotionField : public Allocable {
+  Array2D<MotionVector> mv;
+  Array2D<int8_t> reference_offset;
 };
 
 // MvContexts contains the contexts used to decode portions of an inter block
@@ -73,6 +103,24 @@
 // the block itself (for example, some of the variables in BlockParameters are
 // used to compute the context for reading elements in the subsequent blocks).
 struct PredictionParameters : public Allocable {
+  // Restore the index in the unsorted mv stack from the least 3 bits of sorted
+  // |weight_index_stack|.
+  const MotionVector& reference_mv(int stack_index) const {
+    return ref_mv_stack[7 - (weight_index_stack[stack_index] & 7)];
+  }
+  const MotionVector& reference_mv(int stack_index, int mv_index) const {
+    return compound_ref_mv_stack[7 - (weight_index_stack[stack_index] & 7)]
+        .mv[mv_index];
+  }
+
+  void IncreaseWeight(ptrdiff_t index, int weight) {
+    weight_index_stack[index] += weight << 3;
+  }
+
+  void SetWeightIndexStackEntry(int index, int weight) {
+    weight_index_stack[index] = (weight << 3) + 7 - index;
+  }
+
   bool use_filter_intra;
   FilterIntraPredictor filter_intra_mode;
   int angle_delta[kNumPlaneTypes];
@@ -89,7 +137,26 @@
   bool mask_is_inverse;
   MotionMode motion_mode;
   CompoundPredictionType compound_prediction_type;
-  CandidateMotionVector ref_mv_stack[kMaxRefMvStackSize];
+  union {
+    // |ref_mv_stack| and |compound_ref_mv_stack| are not sorted after
+    // construction. reference_mv() must be called to get the correct element.
+    MotionVector ref_mv_stack[kMaxRefMvStackSize];
+    CompoundMotionVector compound_ref_mv_stack[kMaxRefMvStackSize];
+  };
+  // The least 3 bits of |weight_index_stack| store the index information, and
+  // the other bits store the weight. The index information is actually 7 -
+  // index to make the descending order sort stable (preserves the original
+  // order for elements with the same weight). Sorting an int16_t array is much
+  // faster than sorting a struct array with weight and index stored separately.
+  int16_t weight_index_stack[kMaxRefMvStackSize];
+  // In the spec, the weights of all the nearest mvs are incremented by a bonus
+  // weight which is larger than any natural weight, and later the weights of
+  // the mvs are compared with this bonus weight to determine their contexts. We
+  // replace this procedure by introducing |nearest_mv_count|, which records the
+  // count of the nearest mvs. Since all the nearest mvs are in the beginning of
+  // the mv stack, the index of a mv in the mv stack can be compared with
+  // |nearest_mv_count| to get that mv's context.
+  int nearest_mv_count;
   int ref_mv_count;
   int ref_mv_index;
   MotionVector global_mv[2];
@@ -102,31 +169,31 @@
 // their types are large enough.
 struct BlockParameters : public Allocable {
   BlockSize size;
-  // segment_id is in the range [0, 7].
-  int8_t segment_id;
-  bool use_predicted_segment_id;  // only valid with temporal update enabled.
   bool skip;
   // True means that this block will use some default settings (that
   // correspond to compound prediction) and so most of the mode info is
   // skipped. False means that the mode info is not skipped.
   bool skip_mode;
   bool is_inter;
+  bool is_explicit_compound_type;  // comp_group_idx in the spec.
+  bool is_compound_type_average;   // compound_idx in the spec.
+  bool is_global_mv_block;
+  bool use_predicted_segment_id;  // only valid with temporal update enabled.
+  int8_t segment_id;              // segment_id is in the range [0, 7].
   PredictionMode y_mode;
   PredictionMode uv_mode;
   TransformSize transform_size;
   TransformSize uv_transform_size;
-  PaletteModeInfo palette_mode_info;
-  ReferenceFrameType reference_frame[2];
-  MotionVector mv[2];
-  bool is_explicit_compound_type;  // comp_group_idx in the spec.
-  bool is_compound_type_average;   // compound_idx in the spec.
   InterpolationFilter interpolation_filter[2];
+  ReferenceFrameType reference_frame[2];
   // The index of this array is as follows:
   //  0 - Y plane vertical filtering.
   //  1 - Y plane horizontal filtering.
   //  2 - U plane (both directions).
   //  3 - V plane (both directions).
   uint8_t deblock_filter_level[kFrameLfCount];
+  CompoundMotionVector mv;
+  PaletteModeInfo palette_mode_info;
   // When |Tile::split_parse_and_decode_| is true, each block gets its own
   // instance of |prediction_parameters|. When it is false, all the blocks point
   // to |Tile::prediction_parameters_|. This field is valid only as long as the
@@ -135,5 +202,324 @@
   std::unique_ptr<PredictionParameters> prediction_parameters;
 };
 
+// A five dimensional array used to store the wedge masks. The dimensions are:
+//   - block_size_index (returned by GetWedgeBlockSizeIndex() in prediction.cc).
+//   - flip_sign (0 or 1).
+//   - wedge_index (0 to 15).
+//   - each of those three dimensions is a 2d array of block_width by
+//     block_height.
+using WedgeMaskArray =
+    std::array<std::array<std::array<Array2D<uint8_t>, 16>, 2>, 9>;
+
+enum GlobalMotionTransformationType : uint8_t {
+  kGlobalMotionTransformationTypeIdentity,
+  kGlobalMotionTransformationTypeTranslation,
+  kGlobalMotionTransformationTypeRotZoom,
+  kGlobalMotionTransformationTypeAffine,
+  kNumGlobalMotionTransformationTypes
+};
+
+// Global motion and warped motion parameters. See the paper for more info:
+// S. Parker, Y. Chen, D. Barker, P. de Rivaz, D. Mukherjee, "Global and locally
+// adaptive warped motion compensation in video compression", Proc. IEEE
+// International Conference on Image Processing (ICIP), pp. 275-279, Sep. 2017.
+struct GlobalMotion {
+  GlobalMotionTransformationType type;
+  int32_t params[6];
+
+  // Represent two shearing operations. Computed from |params| by SetupShear().
+  //
+  // The least significant six (= kWarpParamRoundingBits) bits are all zeros.
+  // (This means alpha, beta, gamma, and delta could be represented by a 10-bit
+  // signed integer.) The minimum value is INT16_MIN (= -32768) and the maximum
+  // value is 32704 = 0x7fc0, the largest int16_t value whose least significant
+  // six bits are all zeros.
+  //
+  // Valid warp parameters (as validated by SetupShear()) have smaller ranges.
+  // Their absolute values are less than 2^14 (= 16384). (This follows from
+  // the warpValid check at the end of Section 7.11.3.6.)
+  //
+  // NOTE: Section 7.11.3.6 of the spec allows a maximum value of 32768, which
+  // is outside the range of int16_t. When cast to int16_t, 32768 becomes
+  // -32768. This potential int16_t overflow does not matter because either
+  // 32768 or -32768 causes SetupShear() to return false,
+  int16_t alpha;
+  int16_t beta;
+  int16_t gamma;
+  int16_t delta;
+};
+
+// Loop filter parameters:
+//
+// If level[0] and level[1] are both equal to 0, the loop filter process is
+// not invoked.
+//
+// |sharpness| and |delta_enabled| are only used by the loop filter process.
+//
+// The |ref_deltas| and |mode_deltas| arrays are used not only by the loop
+// filter process but also by the reference frame update and loading
+// processes. The loop filter process uses |ref_deltas| and |mode_deltas| only
+// when |delta_enabled| is true.
+struct LoopFilter {
+  // Contains loop filter strength values in the range of [0, 63].
+  std::array<int8_t, kFrameLfCount> level;
+  // Indicates the sharpness level in the range of [0, 7].
+  int8_t sharpness;
+  // Whether the filter level depends on the mode and reference frame used to
+  // predict a block.
+  bool delta_enabled;
+  // Whether additional syntax elements were read that specify which mode and
+  // reference frame deltas are to be updated. loop_filter_delta_update field in
+  // Section 5.9.11 of the spec.
+  bool delta_update;
+  // Contains the adjustment needed for the filter level based on the chosen
+  // reference frame, in the range of [-64, 63].
+  std::array<int8_t, kNumReferenceFrameTypes> ref_deltas;
+  // Contains the adjustment needed for the filter level based on the chosen
+  // mode, in the range of [-64, 63].
+  std::array<int8_t, kLoopFilterMaxModeDeltas> mode_deltas;
+};
+
+struct Delta {
+  bool present;
+  uint8_t scale;
+  bool multi;
+};
+
+struct Cdef {
+  uint8_t damping;  // damping value from the spec + (bitdepth - 8).
+  uint8_t bits;
+  // All the strength values are the values from the spec and left shifted by
+  // (bitdepth - 8).
+  uint8_t y_primary_strength[kMaxCdefStrengths];
+  uint8_t y_secondary_strength[kMaxCdefStrengths];
+  uint8_t uv_primary_strength[kMaxCdefStrengths];
+  uint8_t uv_secondary_strength[kMaxCdefStrengths];
+};
+
+struct TileInfo {
+  bool uniform_spacing;
+  int sb_rows;
+  int sb_columns;
+  int tile_count;
+  int tile_columns_log2;
+  int tile_columns;
+  int tile_column_start[kMaxTileColumns + 1];
+  // This field is not used by libgav1, but is populated for use by some
+  // hardware decoders. So it must not be removed.
+  int tile_column_width_in_superblocks[kMaxTileColumns + 1];
+  int tile_rows_log2;
+  int tile_rows;
+  int tile_row_start[kMaxTileRows + 1];
+  // This field is not used by libgav1, but is populated for use by some
+  // hardware decoders. So it must not be removed.
+  int tile_row_height_in_superblocks[kMaxTileRows + 1];
+  int16_t context_update_id;
+  uint8_t tile_size_bytes;
+};
+
+struct LoopRestoration {
+  LoopRestorationType type[kMaxPlanes];
+  int unit_size[kMaxPlanes];
+};
+
+// Stores the quantization parameters of Section 5.9.12.
+struct QuantizerParameters {
+  // base_index is in the range [0, 255].
+  uint8_t base_index;
+  int8_t delta_dc[kMaxPlanes];
+  // delta_ac[kPlaneY] is always 0.
+  int8_t delta_ac[kMaxPlanes];
+  bool use_matrix;
+  // The |matrix_level| array is used only when |use_matrix| is true.
+  // matrix_level[plane] specifies the level in the quantizer matrix that
+  // should be used for decoding |plane|. The quantizer matrix has 15 levels,
+  // from 0 to 14. The range of matrix_level[plane] is [0, 15]. If
+  // matrix_level[plane] is 15, the quantizer matrix is not used.
+  int8_t matrix_level[kMaxPlanes];
+};
+
+// The corresponding segment feature constants in the AV1 spec are named
+// SEG_LVL_xxx.
+enum SegmentFeature : uint8_t {
+  kSegmentFeatureQuantizer,
+  kSegmentFeatureLoopFilterYVertical,
+  kSegmentFeatureLoopFilterYHorizontal,
+  kSegmentFeatureLoopFilterU,
+  kSegmentFeatureLoopFilterV,
+  kSegmentFeatureReferenceFrame,
+  kSegmentFeatureSkip,
+  kSegmentFeatureGlobalMv,
+  kSegmentFeatureMax
+};
+
+struct Segmentation {
+  // 5.11.14.
+  // Returns true if the feature is enabled in the segment.
+  bool FeatureActive(int segment_id, SegmentFeature feature) const {
+    return enabled && segment_id < kMaxSegments &&
+           feature_enabled[segment_id][feature];
+  }
+
+  // Returns true if the feature is signed.
+  static bool FeatureSigned(SegmentFeature feature) {
+    // Only the first five segment features are signed, so this comparison
+    // suffices.
+    return feature <= kSegmentFeatureLoopFilterV;
+  }
+
+  bool enabled;
+  bool update_map;
+  bool update_data;
+  bool temporal_update;
+  // True if the segment id will be read before the skip syntax element. False
+  // if the skip syntax element will be read first.
+  bool segment_id_pre_skip;
+  // The highest numbered segment id that has some enabled feature. Used as
+  // the upper bound for decoding segment ids.
+  int8_t last_active_segment_id;
+
+  bool feature_enabled[kMaxSegments][kSegmentFeatureMax];
+  int16_t feature_data[kMaxSegments][kSegmentFeatureMax];
+  bool lossless[kMaxSegments];
+  // Cached values of get_qindex(1, segmentId), to be consumed by
+  // Tile::ReadTransformType(). The values are in the range [0, 255].
+  uint8_t qindex[kMaxSegments];
+};
+
+// Section 6.8.20.
+// Note: In spec, film grain section uses YCbCr to denote variable names,
+// such as num_cb_points, num_cr_points. To keep it consistent with other
+// parts of code, we use YUV, i.e., num_u_points, num_v_points, etc.
+struct FilmGrainParams {
+  bool apply_grain;
+  bool update_grain;
+  bool chroma_scaling_from_luma;
+  bool overlap_flag;
+  bool clip_to_restricted_range;
+
+  uint8_t num_y_points;  // [0, 14].
+  uint8_t num_u_points;  // [0, 10].
+  uint8_t num_v_points;  // [0, 10].
+  // Must be [0, 255]. 10/12 bit /= 4 or 16. Must be in increasing order.
+  uint8_t point_y_value[14];
+  uint8_t point_y_scaling[14];
+  uint8_t point_u_value[10];
+  uint8_t point_u_scaling[10];
+  uint8_t point_v_value[10];
+  uint8_t point_v_scaling[10];
+
+  uint8_t chroma_scaling;              // [8, 11].
+  uint8_t auto_regression_coeff_lag;   // [0, 3].
+  int8_t auto_regression_coeff_y[24];  // [-128, 127]
+  int8_t auto_regression_coeff_u[25];  // [-128, 127]
+  int8_t auto_regression_coeff_v[25];  // [-128, 127]
+  // Shift value: auto regression coeffs range
+  // 6: [-2, 2)
+  // 7: [-1, 1)
+  // 8: [-0.5, 0.5)
+  // 9: [-0.25, 0.25)
+  uint8_t auto_regression_shift;
+
+  uint16_t grain_seed;
+  int reference_index;
+  int grain_scale_shift;
+  // These multipliers are encoded as nonnegative values by adding 128 first.
+  // The 128 is subtracted during parsing.
+  int8_t u_multiplier;       // [-128, 127]
+  int8_t u_luma_multiplier;  // [-128, 127]
+  // These offsets are encoded as nonnegative values by adding 256 first. The
+  // 256 is subtracted during parsing.
+  int16_t u_offset;          // [-256, 255]
+  int8_t v_multiplier;       // [-128, 127]
+  int8_t v_luma_multiplier;  // [-128, 127]
+  int16_t v_offset;          // [-256, 255]
+};
+
+struct ObuFrameHeader {
+  uint16_t display_frame_id;
+  uint16_t current_frame_id;
+  int64_t frame_offset;
+  uint16_t expected_frame_id[kNumInterReferenceFrameTypes];
+  int32_t width;
+  int32_t height;
+  int32_t columns4x4;
+  int32_t rows4x4;
+  // The render size (render_width and render_height) is a hint to the
+  // application about the desired display size. It has no effect on the
+  // decoding process.
+  int32_t render_width;
+  int32_t render_height;
+  int32_t upscaled_width;
+  LoopRestoration loop_restoration;
+  uint32_t buffer_removal_time[kMaxOperatingPoints];
+  uint32_t frame_presentation_time;
+  // Note: global_motion[0] (for kReferenceFrameIntra) is not used.
+  std::array<GlobalMotion, kNumReferenceFrameTypes> global_motion;
+  TileInfo tile_info;
+  QuantizerParameters quantizer;
+  Segmentation segmentation;
+  bool show_existing_frame;
+  // frame_to_show is in the range [0, 7]. Only used if show_existing_frame is
+  // true.
+  int8_t frame_to_show;
+  FrameType frame_type;
+  bool show_frame;
+  bool showable_frame;
+  bool error_resilient_mode;
+  bool enable_cdf_update;
+  bool frame_size_override_flag;
+  // The order_hint syntax element in the uncompressed header. If
+  // show_existing_frame is false, the OrderHint variable in the spec is equal
+  // to this field, and so this field can be used in place of OrderHint when
+  // show_existing_frame is known to be false, such as during tile decoding.
+  uint8_t order_hint;
+  int8_t primary_reference_frame;
+  bool render_and_frame_size_different;
+  bool use_superres;
+  uint8_t superres_scale_denominator;
+  bool allow_screen_content_tools;
+  bool allow_intrabc;
+  bool frame_refs_short_signaling;
+  // A bitmask that specifies which reference frame slots will be updated with
+  // the current frame after it is decoded.
+  uint8_t refresh_frame_flags;
+  static_assert(sizeof(ObuFrameHeader::refresh_frame_flags) * 8 ==
+                    kNumReferenceFrameTypes,
+                "");
+  bool found_reference;
+  int8_t force_integer_mv;
+  bool allow_high_precision_mv;
+  InterpolationFilter interpolation_filter;
+  bool is_motion_mode_switchable;
+  bool use_ref_frame_mvs;
+  bool enable_frame_end_update_cdf;
+  // True if all segments are losslessly encoded at the coded resolution.
+  bool coded_lossless;
+  // True if all segments are losslessly encoded at the upscaled resolution.
+  bool upscaled_lossless;
+  TxMode tx_mode;
+  // True means that the mode info for inter blocks contains the syntax
+  // element comp_mode that indicates whether to use single or compound
+  // prediction. False means that all inter blocks will use single prediction.
+  bool reference_mode_select;
+  // The frames to use for compound prediction when skip_mode is true.
+  ReferenceFrameType skip_mode_frame[2];
+  bool skip_mode_present;
+  bool reduced_tx_set;
+  bool allow_warped_motion;
+  Delta delta_q;
+  Delta delta_lf;
+  // A valid value of reference_frame_index[i] is in the range [0, 7]. -1
+  // indicates an invalid value.
+  int8_t reference_frame_index[kNumInterReferenceFrameTypes];
+  // The ref_order_hint[ i ] syntax element in the uncompressed header.
+  // Specifies the expected output order hint for each reference frame.
+  uint8_t reference_order_hint[kNumReferenceFrameTypes];
+  LoopFilter loop_filter;
+  Cdef cdef;
+  FilmGrainParams film_grain_params;
+};
+
 }  // namespace libgav1
 #endif  // LIBGAV1_SRC_UTILS_TYPES_H_
diff --git a/libgav1/src/utils/unbounded_queue.h b/libgav1/src/utils/unbounded_queue.h
index d8cc0fb..fa0d303 100644
--- a/libgav1/src/utils/unbounded_queue.h
+++ b/libgav1/src/utils/unbounded_queue.h
@@ -235,8 +235,10 @@
   size_t back_ = 0;
 };
 
+#if !LIBGAV1_CXX17
 template <typename T>
 constexpr size_t UnboundedQueue<T>::kBlockCapacity;
+#endif
 
 }  // namespace libgav1
 
diff --git a/libgav1/src/version.cc b/libgav1/src/version.cc
new file mode 100644
index 0000000..8d1e5a9
--- /dev/null
+++ b/libgav1/src/version.cc
@@ -0,0 +1,39 @@
+// Copyright 2019 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/gav1/version.h"
+
+#define LIBGAV1_TOSTRING(x) #x
+#define LIBGAV1_STRINGIFY(x) LIBGAV1_TOSTRING(x)
+#define LIBGAV1_DOT_SEPARATED(M, m, p) M##.##m##.##p
+#define LIBGAV1_DOT_SEPARATED_VERSION(M, m, p) LIBGAV1_DOT_SEPARATED(M, m, p)
+#define LIBGAV1_DOT_VERSION                                                   \
+  LIBGAV1_DOT_SEPARATED_VERSION(LIBGAV1_MAJOR_VERSION, LIBGAV1_MINOR_VERSION, \
+                                LIBGAV1_PATCH_VERSION)
+
+#define LIBGAV1_VERSION_STRING LIBGAV1_STRINGIFY(LIBGAV1_DOT_VERSION)
+
+extern "C" {
+
+int Libgav1GetVersion() { return LIBGAV1_VERSION; }
+const char* Libgav1GetVersionString() { return LIBGAV1_VERSION_STRING; }
+
+const char* Libgav1GetBuildConfiguration() {
+  // TODO(jzern): cmake can generate the detail or in other cases we could
+  // produce one based on the known defines along with the defaults based on
+  // the toolchain, e.g., LIBGAV1_ENABLE_NEON from cpu.h.
+  return "Not available.";
+}
+
+}  // extern "C"
diff --git a/libgav1/src/warp_prediction.cc b/libgav1/src/warp_prediction.cc
index 79556f0..dd06317 100644
--- a/libgav1/src/warp_prediction.cc
+++ b/libgav1/src/warp_prediction.cc
@@ -21,6 +21,7 @@
 #include "src/tile.h"
 #include "src/utils/block_parameters_holder.h"
 #include "src/utils/common.h"
+#include "src/utils/constants.h"
 #include "src/utils/logging.h"
 
 namespace libgav1 {
@@ -79,25 +80,21 @@
 int LeastSquareProduct(int a, int b) { return ((a * b) >> 2) + a + b; }
 
 // 7.11.3.8.
-int DiagonalClamp(int64_t value, int16_t division_factor,
-                  int16_t division_shift) {
-  return Clip3(
-      RightShiftWithRoundingSigned(value * division_factor, division_shift),
-      (1 << kWarpedModelPrecisionBits) - kWarpModelAffineClamp + 1,
-      (1 << kWarpedModelPrecisionBits) + kWarpModelAffineClamp - 1);
+int DiagonalClamp(int32_t value) {
+  return Clip3(value,
+               (1 << kWarpedModelPrecisionBits) - kWarpModelAffineClamp + 1,
+               (1 << kWarpedModelPrecisionBits) + kWarpModelAffineClamp - 1);
 }
 
 // 7.11.3.8.
-int NonDiagonalClamp(int64_t value, int16_t division_factor,
-                     int16_t division_shift) {
-  return Clip3(
-      RightShiftWithRoundingSigned(value * division_factor, division_shift),
-      -kWarpModelAffineClamp + 1, kWarpModelAffineClamp - 1);
+int NonDiagonalClamp(int32_t value) {
+  return Clip3(value, -kWarpModelAffineClamp + 1, kWarpModelAffineClamp - 1);
 }
 
 int16_t GetShearParameter(int value) {
   return static_cast<int16_t>(
-      LeftShift(RightShiftWithRoundingSigned(value, kWarpParamRoundingBits),
+      LeftShift(RightShiftWithRoundingSigned(Clip3(value, INT16_MIN, INT16_MAX),
+                                             kWarpParamRoundingBits),
                 kWarpParamRoundingBits));
 }
 
@@ -109,19 +106,16 @@
   const auto* const params = warp_params->params;
   GenerateApproximateDivisor<int32_t>(params[2], &division_factor,
                                       &division_shift);
-  const int alpha =
-      Clip3(params[2] - (1 << kWarpedModelPrecisionBits), INT16_MIN, INT16_MAX);
-  const int beta = Clip3(params[3], INT16_MIN, INT16_MAX);
+  const int alpha = params[2] - (1 << kWarpedModelPrecisionBits);
+  const int beta = params[3];
   const int64_t v = LeftShift(params[4], kWarpedModelPrecisionBits);
   const int gamma =
-      Clip3(RightShiftWithRoundingSigned(v * division_factor, division_shift),
-            INT16_MIN, INT16_MAX);
+      RightShiftWithRoundingSigned(v * division_factor, division_shift);
   const int64_t w = static_cast<int64_t>(params[3]) * params[4];
-  const int delta = Clip3(
+  const int delta =
       params[5] -
-          RightShiftWithRoundingSigned(w * division_factor, division_shift) -
-          (1 << kWarpedModelPrecisionBits),
-      INT16_MIN, INT16_MAX);
+      RightShiftWithRoundingSigned(w * division_factor, division_shift) -
+      (1 << kWarpedModelPrecisionBits);
 
   warp_params->alpha = GetShearParameter(alpha);
   warp_params->beta = GetShearParameter(beta);
@@ -200,20 +194,34 @@
                                       &division_shift);
 
   division_shift -= kWarpedModelPrecisionBits;
-  if (division_shift < 0) {
+
+  const int64_t params_2 = a[1][1] * bx[0] - a[0][1] * bx[1];
+  const int64_t params_3 = -a[0][1] * bx[0] + a[0][0] * bx[1];
+  const int64_t params_4 = a[1][1] * by[0] - a[0][1] * by[1];
+  const int64_t params_5 = -a[0][1] * by[0] + a[0][0] * by[1];
+  auto* const params = warp_params->params;
+
+  if (division_shift <= 0) {
     division_factor <<= -division_shift;
-    division_shift = 0;
+    params[2] = static_cast<int32_t>(params_2) * division_factor;
+    params[3] = static_cast<int32_t>(params_3) * division_factor;
+    params[4] = static_cast<int32_t>(params_4) * division_factor;
+    params[5] = static_cast<int32_t>(params_5) * division_factor;
+  } else {
+    params[2] = RightShiftWithRoundingSigned(params_2 * division_factor,
+                                             division_shift);
+    params[3] = RightShiftWithRoundingSigned(params_3 * division_factor,
+                                             division_shift);
+    params[4] = RightShiftWithRoundingSigned(params_4 * division_factor,
+                                             division_shift);
+    params[5] = RightShiftWithRoundingSigned(params_5 * division_factor,
+                                             division_shift);
   }
 
-  auto* const params = warp_params->params;
-  params[2] = DiagonalClamp(a[1][1] * bx[0] - a[0][1] * bx[1], division_factor,
-                            division_shift);
-  params[3] = NonDiagonalClamp(-a[0][1] * bx[0] + a[0][0] * bx[1],
-                               division_factor, division_shift);
-  params[4] = NonDiagonalClamp(a[1][1] * by[0] - a[0][1] * by[1],
-                               division_factor, division_shift);
-  params[5] = DiagonalClamp(-a[0][1] * by[0] + a[0][0] * by[1], division_factor,
-                            division_shift);
+  params[2] = DiagonalClamp(params[2]);
+  params[3] = NonDiagonalClamp(params[3]);
+  params[4] = NonDiagonalClamp(params[4]);
+  params[5] = DiagonalClamp(params[5]);
 
   const int vx =
       mv.mv[MotionVector::kColumn] * (1 << (kWarpedModelPrecisionBits - 3)) -
diff --git a/libgav1/src/yuv_buffer.cc b/libgav1/src/yuv_buffer.cc
index 2cb2346..c74e140 100644
--- a/libgav1/src/yuv_buffer.cc
+++ b/libgav1/src/yuv_buffer.cc
@@ -16,181 +16,184 @@
 
 #include <cassert>
 #include <cstddef>
+#include <new>
 
+#include "src/frame_buffer_utils.h"
 #include "src/utils/common.h"
 #include "src/utils/logging.h"
-#include "src/utils/memory.h"
 
 namespace libgav1 {
-namespace {
-
-// |align| must be a power of 2.
-uint8_t* AlignAddr(uint8_t* const addr, const size_t align) {
-  const auto value = reinterpret_cast<size_t>(addr);
-  return reinterpret_cast<uint8_t*>(Align(value, align));
-}
-
-}  // namespace
-
-YuvBuffer::~YuvBuffer() { AlignedFree(buffer_alloc_); }
 
 // Size conventions:
 // * Widths, heights, and border sizes are in pixels.
 // * Strides and plane sizes are in bytes.
+//
+// YuvBuffer objects may be reused through the BufferPool. Realloc() must
+// assume that data members (except buffer_alloc_ and buffer_alloc_size_) may
+// contain stale values from the previous use, and must set all data members
+// from scratch. In particular, Realloc() must not rely on the initial values
+// of data members set by the YuvBuffer constructor.
 bool YuvBuffer::Realloc(int bitdepth, bool is_monochrome, int width, int height,
-                        int8_t subsampling_x, int8_t subsampling_y, int border,
-                        int byte_alignment,
+                        int8_t subsampling_x, int8_t subsampling_y,
+                        int left_border, int right_border, int top_border,
+                        int bottom_border,
                         GetFrameBufferCallback get_frame_buffer,
-                        void* private_data, FrameBuffer* frame_buffer) {
-  // Only support allocating buffers that have a border that's a multiple of
-  // 32. The border restriction is required to get 16-byte alignment of the
-  // start of the chroma rows.
-  if ((border & 31) != 0) return false;
+                        void* callback_private_data,
+                        void** buffer_private_data) {
+  // Only support allocating buffers that have borders that are a multiple of
+  // 2. The border restriction is required because we may subsample the
+  // borders in the chroma planes.
+  if (((left_border | right_border | top_border | bottom_border) & 1) != 0) {
+    LIBGAV1_DLOG(ERROR,
+                 "Borders must be a multiple of 2: left_border = %d, "
+                 "right_border = %d, top_border = %d, bottom_border = %d.",
+                 left_border, right_border, top_border, bottom_border);
+    return false;
+  }
 
-  assert(byte_alignment == 0 || byte_alignment >= 16);
-  const int byte_align = (byte_alignment == 0) ? 1 : byte_alignment;
-  // byte_align must be a power of 2.
-  assert((byte_align & (byte_align - 1)) == 0);
-
-  // aligned_width and aligned_height are width and height padded to a
-  // multiple of 8 pixels.
-  const int aligned_width = Align(width, 8);
-  const int aligned_height = Align(height, 8);
-
-  // Calculate y_stride (in bytes). It is padded to a multiple of 16 bytes.
-  int y_stride = aligned_width + 2 * border;
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  if (bitdepth > 8) y_stride *= sizeof(uint16_t);
-#endif
-  y_stride = Align(y_stride, 16);
-  // Size of the Y plane in bytes.
-  const uint64_t y_plane_size =
-      (aligned_height + 2 * border) * static_cast<uint64_t>(y_stride) +
-      byte_alignment;
-  assert((y_plane_size & 15) == 0);
-
-  const int uv_width = aligned_width >> subsampling_x;
-  const int uv_height = aligned_height >> subsampling_y;
-  const int uv_border_width = border >> subsampling_x;
-  const int uv_border_height = border >> subsampling_y;
-
-  // Calculate uv_stride (in bytes). It is padded to a multiple of 16 bytes.
-  int uv_stride = uv_width + 2 * uv_border_width;
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  if (bitdepth > 8) uv_stride *= sizeof(uint16_t);
-#endif
-  uv_stride = Align(uv_stride, 16);
-  // Size of the U or V plane in bytes.
-  const uint64_t uv_plane_size =
-      (uv_height + 2 * uv_border_height) * static_cast<uint64_t>(uv_stride) +
-      byte_alignment;
-  assert((uv_plane_size & 15) == 0);
-
-  const uint64_t frame_size = y_plane_size + 2 * uv_plane_size;
-
-  // Allocate y_buffer, u_buffer, and v_buffer with 16-byte alignment.
-  uint8_t* y_buffer = nullptr;
-  uint8_t* u_buffer = nullptr;
-  uint8_t* v_buffer = nullptr;
+  // Every row in the plane buffers needs to be kFrameBufferRowAlignment-byte
+  // aligned. Since the strides are multiples of kFrameBufferRowAlignment bytes,
+  // it suffices to just make the plane buffers kFrameBufferRowAlignment-byte
+  // aligned.
+  const int plane_align = kFrameBufferRowAlignment;
+  const int uv_width =
+      is_monochrome ? 0 : SubsampledValue(width, subsampling_x);
+  const int uv_height =
+      is_monochrome ? 0 : SubsampledValue(height, subsampling_y);
+  const int uv_left_border = is_monochrome ? 0 : left_border >> subsampling_x;
+  const int uv_right_border = is_monochrome ? 0 : right_border >> subsampling_x;
+  const int uv_top_border = is_monochrome ? 0 : top_border >> subsampling_y;
+  const int uv_bottom_border =
+      is_monochrome ? 0 : bottom_border >> subsampling_y;
 
   if (get_frame_buffer != nullptr) {
-    // |get_frame_buffer| allocates unaligned memory. Ask it to allocate 15
-    // extra bytes so we can align the buffers to 16-byte boundaries.
-    const int align_addr_extra_size = 15;
-    const uint64_t external_y_plane_size = y_plane_size + align_addr_extra_size;
-    const uint64_t external_uv_plane_size =
-        uv_plane_size + align_addr_extra_size;
+    assert(buffer_private_data != nullptr);
 
-    assert(frame_buffer != nullptr);
-
-    if (external_y_plane_size != static_cast<size_t>(external_y_plane_size) ||
-        external_uv_plane_size != static_cast<size_t>(external_uv_plane_size)) {
+    const Libgav1ImageFormat image_format =
+        ComposeImageFormat(is_monochrome, subsampling_x, subsampling_y);
+    FrameBuffer frame_buffer;
+    if (get_frame_buffer(callback_private_data, bitdepth, image_format, width,
+                         height, left_border, right_border, top_border,
+                         bottom_border, kFrameBufferRowAlignment,
+                         &frame_buffer) != kStatusOk) {
       return false;
     }
 
-    // Allocation to hold larger frame, or first allocation.
-    if (get_frame_buffer(
-            private_data, static_cast<size_t>(external_y_plane_size),
-            static_cast<size_t>(external_uv_plane_size), frame_buffer) < 0) {
-      return false;
-    }
-
-    if (frame_buffer->data[0] == nullptr ||
-        frame_buffer->size[0] < external_y_plane_size ||
-        frame_buffer->data[1] == nullptr ||
-        frame_buffer->size[1] < external_uv_plane_size ||
-        frame_buffer->data[2] == nullptr ||
-        frame_buffer->size[2] < external_uv_plane_size) {
-      assert(0 && "The get_frame_buffer callback malfunctioned.");
+    if (frame_buffer.plane[0] == nullptr ||
+        (!is_monochrome && frame_buffer.plane[1] == nullptr) ||
+        (!is_monochrome && frame_buffer.plane[2] == nullptr)) {
+      assert(false && "The get_frame_buffer callback malfunctioned.");
       LIBGAV1_DLOG(ERROR, "The get_frame_buffer callback malfunctioned.");
       return false;
     }
 
-    y_buffer = AlignAddr(frame_buffer->data[0], 16);
-    u_buffer = AlignAddr(frame_buffer->data[1], 16);
-    v_buffer = AlignAddr(frame_buffer->data[2], 16);
+    stride_[kPlaneY] = frame_buffer.stride[0];
+    stride_[kPlaneU] = frame_buffer.stride[1];
+    stride_[kPlaneV] = frame_buffer.stride[2];
+    buffer_[kPlaneY] = frame_buffer.plane[0];
+    buffer_[kPlaneU] = frame_buffer.plane[1];
+    buffer_[kPlaneV] = frame_buffer.plane[2];
+    *buffer_private_data = frame_buffer.private_data;
   } else {
-    assert(private_data == nullptr);
-    assert(frame_buffer == nullptr);
+    assert(callback_private_data == nullptr);
+    assert(buffer_private_data == nullptr);
 
+    // Calculate y_stride (in bytes). It is padded to a multiple of
+    // kFrameBufferRowAlignment bytes.
+    int y_stride = width + left_border + right_border;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    if (bitdepth > 8) y_stride *= sizeof(uint16_t);
+#endif
+    y_stride = Align(y_stride, kFrameBufferRowAlignment);
+    // Size of the Y plane in bytes.
+    const uint64_t y_plane_size = (height + top_border + bottom_border) *
+                                      static_cast<uint64_t>(y_stride) +
+                                  (plane_align - 1);
+
+    // Calculate uv_stride (in bytes). It is padded to a multiple of
+    // kFrameBufferRowAlignment bytes.
+    int uv_stride = uv_width + uv_left_border + uv_right_border;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    if (bitdepth > 8) uv_stride *= sizeof(uint16_t);
+#endif
+    uv_stride = Align(uv_stride, kFrameBufferRowAlignment);
+    // Size of the U or V plane in bytes.
+    const uint64_t uv_plane_size =
+        is_monochrome ? 0
+                      : (uv_height + uv_top_border + uv_bottom_border) *
+                                static_cast<uint64_t>(uv_stride) +
+                            (plane_align - 1);
+
+    // Allocate unaligned y_buffer, u_buffer, and v_buffer.
+    uint8_t* y_buffer = nullptr;
+    uint8_t* u_buffer = nullptr;
+    uint8_t* v_buffer = nullptr;
+
+    const uint64_t frame_size = y_plane_size + 2 * uv_plane_size;
     if (frame_size > buffer_alloc_size_) {
       // Allocation to hold larger frame, or first allocation.
-      AlignedFree(buffer_alloc_);
-      buffer_alloc_ = nullptr;
-
       if (frame_size != static_cast<size_t>(frame_size)) return false;
 
-      buffer_alloc_ = static_cast<uint8_t*>(
-          AlignedAlloc(16, static_cast<size_t>(frame_size)));
-      if (buffer_alloc_ == nullptr) return false;
+      buffer_alloc_.reset(new (std::nothrow)
+                              uint8_t[static_cast<size_t>(frame_size)]);
+      if (buffer_alloc_ == nullptr) {
+        buffer_alloc_size_ = 0;
+        return false;
+      }
 
       buffer_alloc_size_ = static_cast<size_t>(frame_size);
     }
 
-    y_buffer = buffer_alloc_;
-    u_buffer = buffer_alloc_ + y_plane_size;
-    v_buffer = buffer_alloc_ + y_plane_size + uv_plane_size;
+    y_buffer = buffer_alloc_.get();
+    if (!is_monochrome) {
+      u_buffer = y_buffer + y_plane_size;
+      v_buffer = u_buffer + uv_plane_size;
+    }
+
+    stride_[kPlaneY] = y_stride;
+    stride_[kPlaneU] = stride_[kPlaneV] = uv_stride;
+
+    int left_border_bytes = left_border;
+    int uv_left_border_bytes = uv_left_border;
+#if LIBGAV1_MAX_BITDEPTH >= 10
+    if (bitdepth > 8) {
+      left_border_bytes *= sizeof(uint16_t);
+      uv_left_border_bytes *= sizeof(uint16_t);
+    }
+#endif
+    buffer_[kPlaneY] = AlignAddr(
+        y_buffer + (top_border * y_stride) + left_border_bytes, plane_align);
+    buffer_[kPlaneU] =
+        AlignAddr(u_buffer + (uv_top_border * uv_stride) + uv_left_border_bytes,
+                  plane_align);
+    buffer_[kPlaneV] =
+        AlignAddr(v_buffer + (uv_top_border * uv_stride) + uv_left_border_bytes,
+                  plane_align);
   }
 
-  y_crop_width_ = width;
-  y_crop_height_ = height;
-  y_width_ = aligned_width;
-  y_height_ = aligned_height;
-  stride_[kPlaneY] = y_stride;
-  left_border_[kPlaneY] = right_border_[kPlaneY] = top_border_[kPlaneY] =
-      bottom_border_[kPlaneY] = border;
+  y_width_ = width;
+  y_height_ = height;
+  left_border_[kPlaneY] = left_border;
+  right_border_[kPlaneY] = right_border;
+  top_border_[kPlaneY] = top_border;
+  bottom_border_[kPlaneY] = bottom_border;
 
-  uv_crop_width_ = (width + subsampling_x) >> subsampling_x;
-  uv_crop_height_ = (height + subsampling_y) >> subsampling_y;
   uv_width_ = uv_width;
   uv_height_ = uv_height;
-  stride_[kPlaneU] = stride_[kPlaneV] = uv_stride;
-  left_border_[kPlaneU] = right_border_[kPlaneU] = uv_border_width;
-  top_border_[kPlaneU] = bottom_border_[kPlaneU] = uv_border_height;
-  left_border_[kPlaneV] = right_border_[kPlaneV] = uv_border_width;
-  top_border_[kPlaneV] = bottom_border_[kPlaneV] = uv_border_height;
+  left_border_[kPlaneU] = left_border_[kPlaneV] = uv_left_border;
+  right_border_[kPlaneU] = right_border_[kPlaneV] = uv_right_border;
+  top_border_[kPlaneU] = top_border_[kPlaneV] = uv_top_border;
+  bottom_border_[kPlaneU] = bottom_border_[kPlaneV] = uv_bottom_border;
 
   subsampling_x_ = subsampling_x;
   subsampling_y_ = subsampling_y;
 
   bitdepth_ = bitdepth;
   is_monochrome_ = is_monochrome;
-  int border_bytes = border;
-  int uv_border_width_bytes = uv_border_width;
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  if (bitdepth > 8) {
-    border_bytes *= sizeof(uint16_t);
-    uv_border_width_bytes *= sizeof(uint16_t);
-  }
-#endif
-  buffer_[kPlaneY] =
-      AlignAddr(y_buffer + (border * y_stride) + border_bytes, byte_align);
-  buffer_[kPlaneU] = AlignAddr(
-      u_buffer + (uv_border_height * uv_stride) + uv_border_width_bytes,
-      byte_align);
-  buffer_[kPlaneV] = AlignAddr(
-      v_buffer + (uv_border_height * uv_stride) + uv_border_width_bytes,
-      byte_align);
+  assert(!is_monochrome || stride_[kPlaneU] == 0);
+  assert(!is_monochrome || stride_[kPlaneV] == 0);
+  assert(!is_monochrome || buffer_[kPlaneU] == nullptr);
+  assert(!is_monochrome || buffer_[kPlaneV] == nullptr);
 
   return true;
 }
diff --git a/libgav1/src/yuv_buffer.h b/libgav1/src/yuv_buffer.h
index e0ea980..b9e8cd3 100644
--- a/libgav1/src/yuv_buffer.h
+++ b/libgav1/src/yuv_buffer.h
@@ -17,90 +17,78 @@
 #ifndef LIBGAV1_SRC_YUV_BUFFER_H_
 #define LIBGAV1_SRC_YUV_BUFFER_H_
 
+#include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <memory>
+#include <type_traits>
 
-#include "src/frame_buffer.h"
+#include "src/gav1/frame_buffer.h"
 #include "src/utils/constants.h"
 
 namespace libgav1 {
 
 class YuvBuffer {
  public:
-  // If the memory was allocated by YuvBuffer directly, the memory is freed.
-  ~YuvBuffer();
-
   // Allocates the buffer. Returns true on success. Returns false on failure.
   //
   // * |width| and |height| are the image dimensions in pixels.
   // * |subsampling_x| and |subsampling_y| (either 0 or 1) specify the
   //   subsampling of the width and height of the chroma planes, respectively.
-  // * |border| is the size of the borders (on all four sides) in pixels.
-  // * |byte_alignment| specifies the additional alignment requirement of the
-  //   data buffers of the Y, U, and V planes. If |byte_alignment| is 0, there
-  //   is no additional alignment requirement. Otherwise, |byte_alignment|
-  //   must be a power of 2 and greater than or equal to 16.
-  //   NOTE: The strides are a multiple of 16. Therefore only the first row in
-  //   each plane is aligned to |byte_alignment|. Subsequent rows are only
-  //   16-byte aligned.
+  // * |left_border|, |right_border|, |top_border|, and |bottom_border| are
+  //   the sizes (in pixels) of the borders on the left, right, top, and
+  //   bottom sides, respectively. The four border sizes must all be a
+  //   multiple of 2.
   // * If |get_frame_buffer| is not null, it is invoked to allocate the memory.
   //   If |get_frame_buffer| is null, YuvBuffer allocates the memory directly
-  //   and ignores the |private_data| and |frame_buffer| parameters, which
-  //   should be null.
+  //   and ignores the |callback_private_data| and |buffer_private_data|
+  //   parameters, which should be null.
   //
-  // Example: bitdepth=8 width=20 height=6 border=2. The diagram below shows
-  // how Realloc() allocates the data buffer for the Y plane.
+  // NOTE: The strides are a multiple of 16. Since the first row in each plane
+  // is 16-byte aligned, subsequent rows are also 16-byte aligned.
+  //
+  // Example: bitdepth=8 width=20 height=6 left/right/top/bottom_border=2. The
+  // diagram below shows how Realloc() allocates the data buffer for the Y
+  // plane.
   //
   //   16-byte aligned
   //          |
   //          v
-  //        BBBBBBBBBBBBBBBBBBBBBBBBBBBBpppp
-  //        BBBBBBBBBBBBBBBBBBBBBBBBBBBBpppp
-  //        BB01234567890123456789....BBpppp
-  //        BB11234567890123456789....BBpppp
-  //        BB21234567890123456789....BBpppp
-  //        BB31234567890123456789....BBpppp
-  //        BB41234567890123456789....BBpppp
-  //        BB51234567890123456789....BBpppp
-  //        BB........................BBpppp
-  //        BB........................BBpppp
-  //        BBBBBBBBBBBBBBBBBBBBBBBBBBBBpppp
-  //        BBBBBBBBBBBBBBBBBBBBBBBBBBBBpppp
+  //        ++++++++++++++++++++++++pppppppp
+  //        ++++++++++++++++++++++++pppppppp
+  //        ++01234567890123456789++pppppppp
+  //        ++11234567890123456789++pppppppp
+  //        ++21234567890123456789++pppppppp
+  //        ++31234567890123456789++pppppppp
+  //        ++41234567890123456789++pppppppp
+  //        ++51234567890123456789++pppppppp
+  //        ++++++++++++++++++++++++pppppppp
+  //        ++++++++++++++++++++++++pppppppp
   //        |                              |
   //        |<-- stride (multiple of 16) ->|
   //
   // The video frame has 6 rows of 20 pixels each. Each row is shown as the
   // pattern r1234567890123456789, where |r| is 0, 1, 2, 3, 4, 5.
   //
-  // Realloc() first aligns |width| and |height| to multiples of 8 pixels. The
-  // pixels added in this step are shown as dots ('.'). In this example, the
-  // aligned width is 24 pixels and the aligned height is 8 pixels. NOTE: The
-  // purpose of this step is unknown. We should be able to remove this step.
+  // Realloc() first adds a border of 2 pixels around the video frame. The
+  // border pixels are shown as '+'.
   //
-  // Realloc() then adds a border of 2 pixels around this region. The border
-  // pixels are shown as capital 'B'. NOTE: This example uses a tiny border of
-  // 2 pixels to keep the diagram small. The current implementation of
-  // Realloc() actually requires that |border| be a multiple of 32. We should
-  // be able to only require that |border| be a multiple of 2.
-  //
-  // Each row is now padded to a multiple of the default alignment in bytes,
+  // Each row is then padded to a multiple of the default alignment in bytes,
   // which is 16. The padding bytes are shown as lowercase 'p'. (Since
   // |bitdepth| is 8 in this example, each pixel is one byte.) The padded size
   // in bytes is the stride. In this example, the stride is 32 bytes.
   //
   // Finally, Realloc() aligns the first byte of frame data, which is the '0'
   // pixel/byte in the upper left corner of the frame, to the default (16-byte)
-  // alignemnt boundary and also the |byte_alignment| boundary, if
-  // |byte_alignment| is nonzero.
+  // alignment boundary.
   //
-  // TODO(wtc): We don't need to allocate the U and V plane buffers if
-  // |monochrome| is true.
   // TODO(wtc): Add a check for width and height limits to defend against
   // invalid bitstreams.
   bool Realloc(int bitdepth, bool is_monochrome, int width, int height,
-               int8_t subsampling_x, int8_t subsampling_y, int border,
-               int byte_alignment, GetFrameBufferCallback get_frame_buffer,
-               void* private_data, FrameBuffer* frame_buffer);
+               int8_t subsampling_x, int8_t subsampling_y, int left_border,
+               int right_border, int top_border, int bottom_border,
+               GetFrameBufferCallback get_frame_buffer,
+               void* callback_private_data, void** buffer_private_data);
 
   int bitdepth() const { return bitdepth_; }
 
@@ -116,13 +104,6 @@
     return (plane == kPlaneY) ? y_height_ : uv_height_;
   }
 
-  int displayed_width(int plane) const {
-    return (plane == kPlaneY) ? y_crop_width_ : uv_crop_width_;
-  }
-  int displayed_height(int plane) const {
-    return (plane == kPlaneY) ? y_crop_height_ : uv_crop_height_;
-  }
-
   // Returns border sizes in pixels.
   int left_border(int plane) const { return left_border_[plane]; }
   int right_border(int plane) const { return right_border_[plane]; }
@@ -130,79 +111,55 @@
   int bottom_border(int plane) const { return bottom_border_[plane]; }
 
   // Returns the alignment of frame buffer row in bytes.
-  int alignment() const { return 16; }
+  int alignment() const { return kFrameBufferRowAlignment; }
 
-  // Returns whether shifing frame buffer is successful.
-  // |vertical_shift| and |horizontal_shift| are in pixels.
-  // TODO(chengchen):
-  // Warning: this implementation doesn't handle the byte_alignment requirement.
-  // For example, if the frame is required to be 4K-byte aligned, this method
-  // fails. Figure out alternative solutions if the feature of
-  // byte_alignment is required in practice.
-  bool ShiftBuffer(int plane, int horizontal_shift, int vertical_shift) {
-    if (!ValidHorizontalShift(plane, horizontal_shift) ||
-        !ValidVerticalShift(plane, vertical_shift)) {
-      return false;
-    }
-    left_border_[plane] += horizontal_shift;
-    right_border_[plane] -= horizontal_shift;
-    top_border_[plane] += vertical_shift;
-    bottom_border_[plane] -= vertical_shift;
-    const int pixel_size =
-        static_cast<int>((bitdepth_ == 8) ? sizeof(uint8_t) : sizeof(uint16_t));
-    buffer_[plane] +=
-        vertical_shift * stride_[plane] + horizontal_shift * pixel_size;
-    return true;
-  }
-
+  // Backup the current set of warnings and disable -Warray-bounds for the
+  // following three functions as the compiler cannot, in all cases, determine
+  // whether |plane| is within [0, kMaxPlanes), e.g., with a variable based for
+  // loop.
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Warray-bounds"
+#endif
   // Returns the data buffer for |plane|.
-  uint8_t* data(int plane) { return buffer_[plane]; }
-  const uint8_t* data(int plane) const { return buffer_[plane]; }
+  uint8_t* data(int plane) {
+    assert(plane >= 0);
+    assert(static_cast<size_t>(plane) < std::extent<decltype(buffer_)>::value);
+    return buffer_[plane];
+  }
+  const uint8_t* data(int plane) const {
+    assert(plane >= 0);
+    assert(static_cast<size_t>(plane) < std::extent<decltype(buffer_)>::value);
+    return buffer_[plane];
+  }
 
   // Returns the stride in bytes for |plane|.
-  int stride(int plane) const { return stride_[plane]; }
+  int stride(int plane) const {
+    assert(plane >= 0);
+    assert(static_cast<size_t>(plane) < std::extent<decltype(stride_)>::value);
+    return stride_[plane];
+  }
+  // Restore the previous set of compiler warnings.
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
 
  private:
-  // Frame buffer pointer, i.e., |buffer_[plane]| can only be shifted in loop
-  // restoration. If loop restoration is applied on plane, |buffer_[plane]|
-  // will be shifted kRestorationBorder rows above, and
-  // kFrameBufferRowAlignment columns left.
-  // |shift| is in pixels.
-  // Positive vertical shift is a down shift. Negative vertical shift is an
-  // up shift.
-  bool ValidVerticalShift(int plane, int shift) const {
-    return (shift >= 0) ? bottom_border_[plane] >= shift
-                        : top_border_[plane] + shift >= 0;
-  }
-  // Positive horizontal shift is a right shift. Negative horizontal shift is
-  // a left shift.
-  bool ValidHorizontalShift(int plane, int shift) const {
-    return (shift >= 0) ? right_border_[plane] >= shift
-                        : left_border_[plane] + shift >= 0;
-  }
-
+  static constexpr int kFrameBufferRowAlignment = 16;
   int bitdepth_ = 0;
   bool is_monochrome_ = false;
 
-  // y_crop_width_ and y_crop_height_ are the original width and height (the
-  // |width| and |height| arguments passed to the Realloc() method). y_width_
-  // and y_height_ are the original width and height padded to a multiple of
-  // 8.
+  // y_width_ and y_height_ are the |width| and |height| arguments passed to the
+  // Realloc() method.
   //
-  // The UV widths and heights are computed from Y widths and heights as
+  // uv_width_ and uv_height_ are computed from y_width_ and y_height_ as
   // follows:
-  //   uv_crop_width_ = (y_crop_width_ + subsampling_x_) >> subsampling_x_
-  //   uv_crop_height_ = (y_crop_height_ + subsampling_y_) >> subsampling_y_
-  //   uv_width_ = y_width_ >> subsampling_x_
-  //   uv_height_ = y_height_ >> subsampling_y_
+  //   uv_width_ = (y_width_ + subsampling_x_) >> subsampling_x_
+  //   uv_height_ = (y_height_ + subsampling_y_) >> subsampling_y_
   int y_width_ = 0;
   int uv_width_ = 0;
   int y_height_ = 0;
   int uv_height_ = 0;
-  int y_crop_width_ = 0;
-  int uv_crop_width_ = 0;
-  int y_crop_height_ = 0;
-  int uv_crop_height_ = 0;
 
   int left_border_[kMaxPlanes] = {};
   int right_border_[kMaxPlanes] = {};
@@ -214,7 +171,7 @@
 
   // buffer_alloc_ and buffer_alloc_size_ are only used if the
   // get_frame_buffer callback is null and we allocate the buffer ourselves.
-  uint8_t* buffer_alloc_ = nullptr;
+  std::unique_ptr<uint8_t[]> buffer_alloc_;
   size_t buffer_alloc_size_ = 0;
 
   int8_t subsampling_x_ = 0;  // 0 or 1.
diff --git a/libgav1/tests/fuzzer/decoder_fuzzer.cc b/libgav1/tests/fuzzer/decoder_fuzzer.cc
new file mode 100644
index 0000000..236fd3c
--- /dev/null
+++ b/libgav1/tests/fuzzer/decoder_fuzzer.cc
@@ -0,0 +1,87 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "examples/file_reader.h"
+#include "examples/file_reader_constants.h"
+#include "examples/file_reader_interface.h"
+#include "src/gav1/decoder.h"
+#include "tests/fuzzer/fuzzer_temp_file.h"
+
+namespace {
+
+#if defined(LIBGAV1_EXHAUSTIVE_FUZZING)
+// Set a large upper bound to give more coverage of a single input; this value
+// should be larger than most of the frame counts in the corpus.
+constexpr int kMaxFrames = 100;
+constexpr size_t kMaxDataSize = 400 * 1024;
+#else
+// Restrict the number of frames to improve fuzzer throughput.
+constexpr int kMaxFrames = 5;
+constexpr size_t kMaxDataSize = 200 * 1024;
+#endif
+
+void Decode(const uint8_t* const data, const size_t size,
+            libgav1::Decoder* const decoder) {
+  decoder->EnqueueFrame(data, size, /*user_private_data=*/0,
+                        /*buffer_private_data=*/nullptr);
+  const libgav1::DecoderBuffer* buffer;
+  decoder->DequeueFrame(&buffer);
+}
+
+}  // namespace
+
+// Always returns 0. Nonzero return values are reserved by libFuzzer for future
+// use.
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+  // Reject large chunks of data to improve fuzzer throughput.
+  if (size > kMaxDataSize) return 0;
+
+  libgav1::Decoder decoder;
+  libgav1::DecoderSettings settings = {};
+  // Use the low byte of the width to seed the number of threads.
+  // We use both nibbles of the lower byte as this results in values != 1 much
+  // more quickly than using the lower nibble alone.
+  settings.threads = (size >= 13) ? ((data[12] >> 4 | data[12]) & 0xF) + 1 : 1;
+  if (decoder.Init(&settings) != libgav1::kStatusOk) return 0;
+
+  // Treat the input as a raw OBU stream.
+  Decode(data, size, &decoder);
+
+  // Use the first frame from an IVF to bypass any read errors from the parser.
+  static constexpr size_t kIvfHeaderSize =
+      libgav1::kIvfFileHeaderSize + libgav1::kIvfFrameHeaderSize;
+  if (size >= kIvfHeaderSize) {
+    Decode(data + kIvfHeaderSize, size - kIvfHeaderSize, &decoder);
+  }
+
+  FuzzerTemporaryFile tempfile(data, size);
+  auto file_reader =
+      libgav1::FileReader::Open(tempfile.filename(), /*error_tolerant=*/true);
+  if (file_reader == nullptr) return 0;
+
+  std::vector<uint8_t> buffer;
+  int decoded_frames = 0;
+  do {
+    if (!file_reader->ReadTemporalUnit(&buffer, nullptr)) break;
+    Decode(buffer.data(), buffer.size(), &decoder);
+    if (++decoded_frames >= kMaxFrames) break;
+  } while (!file_reader->IsEndOfFile());
+
+  return 0;
+}
diff --git a/libgav1/tests/fuzzer/decoder_fuzzer_frame_parallel.cc b/libgav1/tests/fuzzer/decoder_fuzzer_frame_parallel.cc
new file mode 100644
index 0000000..6e8b6a0
--- /dev/null
+++ b/libgav1/tests/fuzzer/decoder_fuzzer_frame_parallel.cc
@@ -0,0 +1,141 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstddef>
+#include <cstdint>
+#include <deque>
+#include <memory>
+#include <vector>
+
+#include "examples/file_reader.h"
+#include "examples/file_reader_constants.h"
+#include "examples/file_reader_interface.h"
+#include "src/gav1/decoder.h"
+#include "src/gav1/status_code.h"
+#include "tests/fuzzer/fuzzer_temp_file.h"
+
+namespace {
+
+#if defined(LIBGAV1_EXHAUSTIVE_FUZZING)
+// Set a large upper bound to give more coverage of a single input; this value
+// should be larger than most of the frame counts in the corpus.
+constexpr size_t kMaxDataSize = 400 * 1024;
+#else
+constexpr size_t kMaxDataSize = 200 * 1024;
+#endif
+
+using InputBuffer = std::vector<uint8_t>;
+
+struct InputBuffers {
+  ~InputBuffers() {
+    for (auto& buffer : free_buffers) {
+      delete buffer;
+    }
+  }
+  std::deque<InputBuffer*> free_buffers;
+};
+
+void ReleaseInputBuffer(void* callback_private_data,
+                        void* buffer_private_data) {
+  auto* const test = static_cast<InputBuffers*>(callback_private_data);
+  test->free_buffers.push_back(static_cast<InputBuffer*>(buffer_private_data));
+}
+
+}  // namespace
+
+// Always returns 0. Nonzero return values are reserved by libFuzzer for future
+// use.
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+  // Reject large chunks of data to improve fuzzer throughput.
+  if (size > kMaxDataSize) return 0;
+
+  // Note that |input_buffers| has to outlive the |decoder| object since the
+  // |release_input_buffer| callback could be called on the |decoder|'s
+  // destructor.
+  InputBuffers input_buffers;
+
+  libgav1::Decoder decoder;
+  libgav1::DecoderSettings settings = {};
+  // Use the 33 + low byte of the width to seed the number of threads. This
+  // ensures that we will trigger the frame parallel path in most cases.
+  // We use both nibbles of the lower byte as this results in values != 1 much
+  // more quickly than using the lower nibble alone.
+  settings.threads =
+      33 + ((size >= 13) ? ((data[12] >> 4 | data[12]) & 0xF) + 1 : 1);
+
+  settings.frame_parallel = true;
+  settings.blocking_dequeue = true;
+  settings.callback_private_data = &input_buffers;
+  settings.release_input_buffer = ReleaseInputBuffer;
+  if (decoder.Init(&settings) != libgav1::kStatusOk) return 0;
+
+  FuzzerTemporaryFile tempfile(data, size);
+  auto file_reader =
+      libgav1::FileReader::Open(tempfile.filename(), /*error_tolerant=*/true);
+  if (file_reader == nullptr) return 0;
+
+  InputBuffer* input_buffer = nullptr;
+  bool dequeue_finished = false;
+
+  do {
+    if (input_buffer == nullptr && !file_reader->IsEndOfFile()) {
+      if (input_buffers.free_buffers.empty()) {
+        auto* const buffer = new (std::nothrow) InputBuffer();
+        if (buffer == nullptr) {
+          break;
+        }
+        input_buffers.free_buffers.push_back(buffer);
+      }
+      input_buffer = input_buffers.free_buffers.front();
+      input_buffers.free_buffers.pop_front();
+      if (!file_reader->ReadTemporalUnit(input_buffer, nullptr)) {
+        break;
+      }
+    }
+
+    if (input_buffer != nullptr) {
+      libgav1::StatusCode status =
+          decoder.EnqueueFrame(input_buffer->data(), input_buffer->size(),
+                               /*user_private_data=*/0,
+                               /*buffer_private_data=*/input_buffer);
+      if (status == libgav1::kStatusOk) {
+        input_buffer = nullptr;
+        // Continue to enqueue frames until we get a kStatusTryAgain status.
+        continue;
+      }
+      if (status != libgav1::kStatusTryAgain) {
+        break;
+      }
+    }
+
+    const libgav1::DecoderBuffer* buffer;
+    libgav1::StatusCode status = decoder.DequeueFrame(&buffer);
+    if (status != libgav1::kStatusOk &&
+        status != libgav1::kStatusNothingToDequeue) {
+      break;
+    }
+    if (buffer == nullptr) {
+      dequeue_finished = status == libgav1::kStatusNothingToDequeue;
+    } else {
+      dequeue_finished = false;
+    }
+  } while (input_buffer != nullptr || !file_reader->IsEndOfFile() ||
+           !dequeue_finished);
+
+  if (input_buffer != nullptr) {
+    input_buffers.free_buffers.push_back(input_buffer);
+  }
+
+  return 0;
+}
diff --git a/libgav1/tests/fuzzer/fuzzer_temp_file.h b/libgav1/tests/fuzzer/fuzzer_temp_file.h
new file mode 100644
index 0000000..5d12bbe
--- /dev/null
+++ b/libgav1/tests/fuzzer/fuzzer_temp_file.h
@@ -0,0 +1,148 @@
+/*
+ * Copyright 2020 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBGAV1_TESTS_FUZZER_FUZZER_TEMP_FILE_H_
+#define LIBGAV1_TESTS_FUZZER_FUZZER_TEMP_FILE_H_
+
+// Adapter utility from fuzzer input to a temporary file, for fuzzing APIs that
+// require a file instead of an input buffer.
+
+#include <limits.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+
+// Pure-C interface for creating and cleaning up temporary files.
+
+static char* fuzzer_get_tmpfile_with_suffix(const uint8_t* data, size_t size,
+                                            const char* suffix) {
+  if (suffix == NULL) {  // NOLINT (this could be a C compilation unit)
+    suffix = "";
+  }
+  const size_t suffix_len = strlen(suffix);
+  if (suffix_len > INT_MAX) {  // mkstemps takes int for suffixlen param
+    perror("Suffix too long");
+    abort();
+  }
+
+#ifdef __ANDROID__
+  const char* leading_temp_path =
+      "/data/local/tmp/generate_temporary_file.XXXXXX";
+#else
+  const char* leading_temp_path = "/tmp/generate_temporary_file.XXXXXX";
+#endif
+  const size_t buffer_sz = strlen(leading_temp_path) + suffix_len + 1;
+  char* filename_buffer =
+      (char*)malloc(buffer_sz);  // NOLINT (this could be a C compilation unit)
+  if (!filename_buffer) {
+    perror("Failed to allocate file name buffer.");
+    abort();
+  }
+
+  if (snprintf(filename_buffer, buffer_sz, "%s%s", leading_temp_path, suffix) >=
+      buffer_sz) {
+    perror("File name buffer too short.");
+    abort();
+  }
+
+  const int file_descriptor = mkstemps(filename_buffer, suffix_len);
+  if (file_descriptor < 0) {
+    perror("Failed to make temporary file.");
+    abort();
+  }
+  FILE* file = fdopen(file_descriptor, "wb");
+  if (!file) {
+    perror("Failed to open file descriptor.");
+    close(file_descriptor);
+    abort();
+  }
+  const size_t bytes_written = fwrite(data, sizeof(uint8_t), size, file);
+  if (bytes_written < size) {
+    close(file_descriptor);
+    fprintf(stderr, "Failed to write all bytes to file (%zu out of %zu)",
+            bytes_written, size);
+    abort();
+  }
+  fclose(file);
+  return filename_buffer;
+}
+
+static char* fuzzer_get_tmpfile(
+    const uint8_t* data,
+    size_t size) {  // NOLINT (people include this .inc file directly)
+  return fuzzer_get_tmpfile_with_suffix(data, size, NULL);  // NOLINT
+}
+
+static void fuzzer_release_tmpfile(char* filename) {
+  if (unlink(filename) != 0) {
+    perror("WARNING: Failed to delete temporary file.");
+  }
+  free(filename);
+}
+
+// C++ RAII object for creating temporary files.
+
+#ifdef __cplusplus
+class FuzzerTemporaryFile {
+ public:
+  FuzzerTemporaryFile(const uint8_t* data, size_t size)
+      : original_filename_(fuzzer_get_tmpfile(data, size)) {
+    filename_ = strdup(original_filename_);
+    if (!filename_) {
+      perror("Failed to allocate file name copy.");
+      abort();
+    }
+  }
+
+  FuzzerTemporaryFile(const uint8_t* data, size_t size, const char* suffix)
+      : original_filename_(fuzzer_get_tmpfile_with_suffix(data, size, suffix)) {
+    filename_ = strdup(original_filename_);
+    if (!filename_) {
+      perror("Failed to allocate file name copy.");
+      abort();
+    }
+  }
+
+  ~FuzzerTemporaryFile() {
+    free(filename_);
+    fuzzer_release_tmpfile(original_filename_);
+  }
+
+  FuzzerTemporaryFile(const FuzzerTemporaryFile& other) = delete;
+  FuzzerTemporaryFile operator=(const FuzzerTemporaryFile& other) = delete;
+
+  FuzzerTemporaryFile(const FuzzerTemporaryFile&& other) = delete;
+  FuzzerTemporaryFile operator=(const FuzzerTemporaryFile&& other) = delete;
+
+  const char* filename() const { return filename_; }
+
+  // Returns a mutable pointer to the file name. Should be used sparingly, only
+  // in case the fuzzed API demands it or when making a mutable copy is
+  // inconvenient (e.g., in auto-generated code).
+  char* mutable_filename() const { return filename_; }
+
+ private:
+  char* original_filename_;
+
+  // A mutable copy of the original filename, returned by the accessor. This
+  // guarantees that the original filename can always be used to release the
+  // temporary path.
+  char* filename_;
+};
+#endif  // __cplusplus
+#endif  // LIBGAV1_TESTS_FUZZER_FUZZER_TEMP_FILE_H_
diff --git a/libgav1/tests/fuzzer/obu_parser_fuzzer.cc b/libgav1/tests/fuzzer/obu_parser_fuzzer.cc
new file mode 100644
index 0000000..634a802
--- /dev/null
+++ b/libgav1/tests/fuzzer/obu_parser_fuzzer.cc
@@ -0,0 +1,89 @@
+// Copyright 2020 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "examples/file_reader.h"
+#include "examples/file_reader_constants.h"
+#include "examples/file_reader_interface.h"
+#include "src/buffer_pool.h"
+#include "src/decoder_impl.h"
+#include "src/decoder_state.h"
+#include "src/internal_frame_buffer_list.h"
+#include "src/obu_parser.h"
+#include "tests/fuzzer/fuzzer_temp_file.h"
+
+namespace {
+
+#if defined(LIBGAV1_EXHAUSTIVE_FUZZING)
+// Set a large upper bound to give more coverage of a single input; this value
+// should be larger than most of the frame counts in the corpus.
+constexpr int kMaxFrames = 100;
+constexpr size_t kMaxDataSize = 400 * 1024;
+#else
+// Restrict the number of frames and obus to improve fuzzer throughput.
+constexpr int kMaxFrames = 5;
+constexpr size_t kMaxDataSize = 200 * 1024;
+#endif
+
+inline void ParseObu(const uint8_t* const data, size_t size) {
+  libgav1::InternalFrameBufferList buffer_list;
+  libgav1::BufferPool buffer_pool(libgav1::OnInternalFrameBufferSizeChanged,
+                                  libgav1::GetInternalFrameBuffer,
+                                  libgav1::ReleaseInternalFrameBuffer,
+                                  &buffer_list);
+  libgav1::DecoderState decoder_state;
+  libgav1::ObuParser parser(data, size, 0, &buffer_pool, &decoder_state);
+  libgav1::RefCountedBufferPtr current_frame;
+  int parsed_frames = 0;
+  while (parser.HasData()) {
+    if (parser.ParseOneFrame(&current_frame) != libgav1::kStatusOk) break;
+    if (++parsed_frames >= kMaxFrames) break;
+  }
+}
+
+}  // namespace
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+  // Reject large chunks of data to improve fuzzer throughput.
+  if (size > kMaxDataSize) return 0;
+
+  // Treat the input as a raw OBU stream.
+  ParseObu(data, size);
+
+  // Use the first frame from an IVF to bypass any read errors from the parser.
+  static constexpr size_t kIvfHeaderSize =
+      libgav1::kIvfFileHeaderSize + libgav1::kIvfFrameHeaderSize;
+  if (size >= kIvfHeaderSize) {
+    ParseObu(data + kIvfHeaderSize, size - kIvfHeaderSize);
+  }
+
+  FuzzerTemporaryFile tempfile(data, size);
+  auto file_reader =
+      libgav1::FileReader::Open(tempfile.filename(), /*error_tolerant=*/true);
+  if (file_reader == nullptr) return 0;
+
+  std::vector<uint8_t> buffer;
+  int parsed_frames = 0;
+  do {
+    if (!file_reader->ReadTemporalUnit(&buffer, nullptr)) break;
+    ParseObu(buffer.data(), buffer.size());
+    if (++parsed_frames >= kMaxFrames) break;
+  } while (!file_reader->IsEndOfFile());
+
+  return 0;
+}