SkJpegCodec: Detect multi-picture format gainmaps

Add support for Multi-Picture format gainmaps produced by iPhone
cameras.

To determine if the gainmap is present, first check to see if the
image has MPF metadata. Instruct the decoder to retain MPF tags (as
it does for ICC and EXIF tags already), and early-out if it does not
find MPF tags or if they cannot be parsed.

Once it is know that there are MPF tags, extract SkStreams for the
MP images. Then, for each MP image, use SkJpegSegmentScan to search
for XMP metadata. Search that XMP metadata for a node that indicates
that the MP image is a recognized gainmap image.

Bug: skia: 14031
Change-Id: I3c2aa2c95e02bf4e4539692bc052fd05d2b92537
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/625925
Reviewed-by: Brian Osman <brianosman@google.com>
Commit-Queue: Christopher Cameron <ccameron@google.com>
diff --git a/BUILD.gn b/BUILD.gn
index 981bcc3..53beb07 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -1208,10 +1208,17 @@
   sources = [
     "src/codec/SkJpegCodec.cpp",
     "src/codec/SkJpegDecoderMgr.cpp",
-    "src/codec/SkJpegMultiPicture.cpp",
-    "src/codec/SkJpegSegmentScan.cpp",
     "src/codec/SkJpegUtility.cpp",
   ]
+  if (skia_use_jpeg_gainmaps) {
+    defines = [ "SK_CODEC_DECODES_JPEG_GAINMAPS" ]
+    deps += [ ":xml" ]
+    sources += [
+      "src/codec/SkJpegGainmap.cpp",
+      "src/codec/SkJpegMultiPicture.cpp",
+      "src/codec/SkJpegSegmentScan.cpp",
+    ]
+  }
 }
 
 optional("jpeg_encode") {
@@ -1388,7 +1395,7 @@
 }
 
 optional("xml") {
-  enabled = skia_use_expat
+  enabled = skia_use_expat || skia_use_jpeg_gainmaps
   public_defines = [ "SK_XML" ]
 
   deps = [ "//third_party/expat" ]
@@ -2242,6 +2249,9 @@
       cflags_objcc = [ "-fobjc-arc" ]
       frameworks += [ "MetalKit.framework" ]
     }
+    if (skia_use_jpeg_gainmaps) {
+      sources += jpeg_gainmap_tests_sources
+    }
     if (skia_use_gl) {
       sources += gl_tests_sources
     }
diff --git a/gn/skia.gni b/gn/skia.gni
index 9d7f394..29cb98b 100644
--- a/gn/skia.gni
+++ b/gn/skia.gni
@@ -53,6 +53,7 @@
   skia_use_libavif = false
   skia_use_libheif = is_skia_dev_build
   skia_use_jpegr = false
+  skia_use_jpeg_gainmaps = false
   skia_use_libjpeg_turbo_decode = true
   skia_use_libjpeg_turbo_encode = true
   skia_use_libjxl_decode = false
diff --git a/gn/tests.gni b/gn/tests.gni
index c8164b5..cff0fa1 100644
--- a/gn/tests.gni
+++ b/gn/tests.gni
@@ -462,3 +462,5 @@
 ]
 
 tests_sources += ganesh_tests_sources
+
+jpeg_gainmap_tests_sources = [ "$_tests/JpegGainmapTest.cpp" ]
diff --git a/include/codec/SkCodec.h b/include/codec/SkCodec.h
index 6d924f8..29231c7 100644
--- a/include/codec/SkCodec.h
+++ b/include/codec/SkCodec.h
@@ -771,6 +771,8 @@
         return fSrcXformFormat;
     }
 
+    sk_sp<const SkData> getXmpMetadata() const { return fXmpMetadata; }
+
     virtual bool onGetGainmapInfo(SkGainmapInfo*, std::unique_ptr<SkStream>*) { return false; }
 
     virtual SkISize onGetScaledDimensions(float /*desiredScale*/) const {
diff --git a/src/codec/BUILD.bazel b/src/codec/BUILD.bazel
index d4fe17c..bf8cac9 100644
--- a/src/codec/BUILD.bazel
+++ b/src/codec/BUILD.bazel
@@ -86,6 +86,8 @@
     "SkJpegCodec.h",
     "SkJpegDecoderMgr.cpp",
     "SkJpegDecoderMgr.h",
+    "SkJpegGainmap.cpp",
+    "SkJpegGainmap.h",
     "SkJpegUtility.cpp",
     "SkJpegUtility.h",
     "SkJpegMultiPicture.cpp",
diff --git a/src/codec/SkJpegCodec.cpp b/src/codec/SkJpegCodec.cpp
index 53481ed..e254bed 100644
--- a/src/codec/SkJpegCodec.cpp
+++ b/src/codec/SkJpegCodec.cpp
@@ -18,7 +18,6 @@
 #include "include/core/SkStream.h"
 #include "include/core/SkTypes.h"
 #include "include/core/SkYUVAInfo.h"
-#include "include/private/SkGainmapInfo.h"
 #include "include/private/SkTemplates.h"
 #include "include/private/base/SkAlign.h"
 #include "include/private/base/SkMalloc.h"
@@ -30,12 +29,17 @@
 #include "src/codec/SkParseEncodedOrigin.h"
 #include "src/codec/SkSwizzler.h"
 
+#ifdef SK_CODEC_DECODES_JPEG_GAINMAPS
+#include "src/codec/SkJpegGainmap.h"
+#endif  // SK_CODEC_DECODES_JPEG_GAINMAPS
+
 #include <array>
 #include <csetjmp>
 #include <cstring>
 #include <utility>
 
 class SkSampler;
+struct SkGainmapInfo;
 
 // This warning triggers false postives way too often in here.
 #if defined(__GNUC__) && !defined(__clang__)
@@ -202,6 +206,7 @@
     if (codecOut) {
         jpeg_save_markers(dinfo, kExifMarker, 0xFFFF);
         jpeg_save_markers(dinfo, kICCMarker, 0xFFFF);
+        jpeg_save_markers(dinfo, kMpfMarker, 0xFFFF);
     }
 
     // Read the jpeg header
@@ -1020,6 +1025,7 @@
     jpeg_decompress_struct* dinfo = decoderMgr.dinfo();
     jpeg_save_markers(dinfo, kExifMarker, 0xFFFF);
     jpeg_save_markers(dinfo, kICCMarker, 0xFFFF);
+    jpeg_save_markers(dinfo, kMpfMarker, 0xFFFF);
     if (JPEG_HEADER_OK != jpeg_read_header(dinfo, true)) {
         return false;
     }
@@ -1041,8 +1047,14 @@
 
 bool SkJpegCodec::onGetGainmapInfo(SkGainmapInfo* info,
                                    std::unique_ptr<SkStream>* gainmapImageStream) {
-    // TODO(ccameron): Parse gainmap here.
-    *info = SkGainmapInfo();
+#ifdef SK_CODEC_DECODES_JPEG_GAINMAPS
+    // Attempt to extract Multi-Picture Format gainmap formats.
+    auto mpfMetadata =
+            read_metadata_marker(fDecoderMgr->dinfo(), kMpfMarker, kMpfSig, sizeof(kMpfSig));
+    if (SkJpegGetMultiPictureGainmap(mpfMetadata, stream(), info, gainmapImageStream)) {
+        return true;
+    }
+#endif  // SK_CODEC_DECODES_JPEG_GAINMAPS
     return false;
 }
 
diff --git a/src/codec/SkJpegGainmap.cpp b/src/codec/SkJpegGainmap.cpp
new file mode 100644
index 0000000..131ce4b
--- /dev/null
+++ b/src/codec/SkJpegGainmap.cpp
@@ -0,0 +1,304 @@
+/*
+ * Copyright 2023 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/codec/SkJpegGainmap.h"
+
+#include "include/core/SkColor.h"
+#include "include/core/SkData.h"
+#include "include/core/SkStream.h"
+#include "include/private/SkFloatingPoint.h"
+#include "include/private/SkGainmapInfo.h"
+#include "include/utils/SkParse.h"
+#include "src/codec/SkCodecPriv.h"
+#include "src/codec/SkJpegMultiPicture.h"
+#include "src/codec/SkJpegPriv.h"
+#include "src/codec/SkJpegSegmentScan.h"
+#include "src/xml/SkDOM.h"
+
+#include <cstdint>
+#include <cstring>
+#include <utility>
+#include <vector>
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// SkStream helpers.
+
+/*
+ * Class that will will rewind an SkStream, and then restore it to its original position when it
+ * goes out of scope. If the SkStream is not seekable, then the stream will not be altered at all,
+ * and will return false from canRestore.
+ */
+
+class ScopedSkStreamRestorer {
+public:
+    ScopedSkStreamRestorer(SkStream* stream)
+            : fStream(stream), fPosition(stream->hasPosition() ? stream->getPosition() : 0) {
+        if (canRestore()) {
+            if (!fStream->rewind()) {
+                SkCodecPrintf("Failed to rewind decoder stream.\n");
+            }
+        }
+    }
+    ~ScopedSkStreamRestorer() {
+        if (canRestore()) {
+            if (!fStream->seek(fPosition)) {
+                SkCodecPrintf("Failed to restore decoder stream.\n");
+            }
+        }
+    }
+    bool canRestore() const { return fStream->hasPosition(); }
+
+private:
+    SkStream* const fStream;
+    const size_t fPosition;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// SkDOM and XMP helpers.
+
+/*
+ * Build an SkDOM from an SkData. Return true on success and false on failure (including the input
+ * data being nullptr).
+ */
+bool SkDataToSkDOM(sk_sp<const SkData> data, SkDOM* dom) {
+    if (!data) {
+        return false;
+    }
+    auto stream = SkMemoryStream::MakeDirect(data->data(), data->size());
+    if (!stream) {
+        return false;
+    }
+    return dom->build(*stream) != nullptr;
+}
+
+/*
+ * Given an SkDOM, verify that the dom is XMP, and find the first rdf:Description node that matches
+ * the specified namespaces to the specified URIs. The XML structure that this function matches is
+ * as follows (with NAMESPACEi and URIi being the parameters specified to this function):
+ *
+ *   <x:xmpmeta ...>
+ *     <rdf:RDF ...>
+ *       <rdf:Description NAMESPACE0="URI0" NAMESPACE1="URI1" .../>
+ *     </rdf:RDF>
+ *   </x:xmpmeta>
+ */
+const SkDOM::Node* FindXmpNamespaceUriMatch(const SkDOM& dom,
+                                            const char* namespaces[],
+                                            const char* uris[],
+                                            size_t count) {
+    const SkDOM::Node* root = dom.getRootNode();
+    if (!root) {
+        return nullptr;
+    }
+    const char* rootName = dom.getName(root);
+    if (!rootName || strcmp(rootName, "x:xmpmeta") != 0) {
+        return nullptr;
+    }
+
+    const char* kRdf = "rdf:RDF";
+    for (const auto* rdf = dom.getFirstChild(root, kRdf); rdf;
+         rdf = dom.getNextSibling(rdf, kRdf)) {
+        const char* kDesc = "rdf:Description";
+        for (const auto* desc = dom.getFirstChild(rdf, kDesc); desc;
+             desc = dom.getNextSibling(desc, kDesc)) {
+            bool allNamespaceURIsMatch = true;
+            for (size_t i = 0; i < count; ++i) {
+                if (!dom.hasAttr(desc, namespaces[i], uris[i])) {
+                    allNamespaceURIsMatch = false;
+                    break;
+                }
+            }
+            if (allNamespaceURIsMatch) {
+                return desc;
+            }
+        }
+    }
+    return nullptr;
+}
+
+/*
+ * Given a node, see if that node has only one child with the indicated name. If so, see if that
+ * child has only a single child of its own, and that child is text. If all of that is the case
+ * then return the text, otherwise return nullptr.
+ *
+ * In the following example, innerText will be returned.
+ *    <node><childName>innerText</childName></node>
+ *
+ * In the following examples, nullptr will be returned (because there are multiple children with
+ * childName in the first case, and because the child has children of its own in the second).
+ *    <node><childName>innerTextA</childName><childName>innerTextB</childName></node>
+ *    <node><childName>innerText<otherGrandChild/></childName></node>
+ */
+static const char* GetUniqueChildText(const SkDOM& dom,
+                                      const SkDOM::Node* node,
+                                      const char* childName) {
+    // Fail if there are multiple children with childName.
+    if (dom.countChildren(node, childName) != 1) {
+        return nullptr;
+    }
+    const auto* child = dom.getFirstChild(node, childName);
+    if (!child) {
+        return nullptr;
+    }
+    // Fail if the child has any children besides text.
+    if (dom.countChildren(child) != 1) {
+        return nullptr;
+    }
+    const auto* grandChild = dom.getFirstChild(child);
+    if (dom.getType(grandChild) != SkDOM::kText_Type) {
+        return nullptr;
+    }
+    // Return the text.
+    return dom.getName(grandChild);
+}
+
+// Helper function that builds on GetUniqueChildText, returning true if the unique child with
+// childName has inner text that matches an expected text.
+static bool UniqueChildTextMatches(const SkDOM& dom,
+                                   const SkDOM::Node* node,
+                                   const char* childName,
+                                   const char* expectedText) {
+    const char* text = GetUniqueChildText(dom, node, childName);
+    if (text && !strcmp(text, expectedText)) {
+        return true;
+    }
+    return false;
+}
+
+// Helper function that builds on GetUniqueChildText, returning true if the unique child with
+// childName has inner text that matches an expected integer.
+static bool UniqueChildTextMatches(const SkDOM& dom,
+                                   const SkDOM::Node* node,
+                                   const char* childName,
+                                   int32_t expectedValue) {
+    const char* text = GetUniqueChildText(dom, node, childName);
+    int32_t actualValue = 0;
+    if (text && SkParse::FindS32(text, &actualValue)) {
+        return actualValue == expectedValue;
+    }
+    return false;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Multi-PictureFormat Gainmap Functions
+
+// Return true if the specified XMP metadata identifies this image as an HDR gainmap.
+static bool XmpIsHDRGainMap(const sk_sp<const SkData>& xmpMetadata) {
+    // Parse the XMP.
+    SkDOM dom;
+    if (!SkDataToSkDOM(xmpMetadata, &dom)) {
+        return false;
+    }
+
+    // Find a node that matches the requested namespaces and URIs.
+    const char* namespaces[2] = {"xmlns:apdi", "xmlns:HDRGainMap"};
+    const char* uris[2] = {"http://ns.apple.com/pixeldatainfo/1.0/",
+                           "http://ns.apple.com/HDRGainMap/1.0/"};
+    const SkDOM::Node* node = FindXmpNamespaceUriMatch(dom, namespaces, uris, 2);
+    if (!node) {
+        return false;
+    }
+    if (!UniqueChildTextMatches(
+                dom, node, "apdi:AuxiliaryImageType", "urn:com:apple:photo:2020:aux:hdrgainmap")) {
+        SkCodecPrintf("Did not find auxiliary image type.\n");
+        return false;
+    }
+    if (!UniqueChildTextMatches(dom, node, "HDRGainMap:HDRGainMapVersion", 65536)) {
+        SkCodecPrintf("HDRGainMapVersion absent or not 65536.\n");
+        return false;
+    }
+
+    // This node will often have StoredFormat and NativeFormat children that have inner text that
+    // specifies the integer 'L008' (also known as kCVPixelFormatType_OneComponent8).
+    return true;
+}
+
+bool SkJpegGetMultiPictureGainmap(sk_sp<const SkData> decoderMpfMetadata,
+                                  SkStream* decoderStream,
+                                  SkGainmapInfo* outInfo,
+                                  std::unique_ptr<SkStream>* outGainmapImageStream) {
+    // The decoder has already scanned for MPF metadata. If it doesn't exist, or it doesn't parse,
+    // then early-out.
+    if (!decoderMpfMetadata || !SkJpegParseMultiPicture(decoderMpfMetadata)) {
+        return false;
+    }
+
+    // The implementation of Multi-Picture images requires a seekable stream. Save the position so
+    // that it can be restored before returning.
+    ScopedSkStreamRestorer streamRestorer(decoderStream);
+    if (!streamRestorer.canRestore()) {
+        SkCodecPrintf("Multi-Picture gainmap extraction requires a seekable stream.\n");
+        return false;
+    }
+
+    // Scan the original decoder stream.
+    auto scan = SkJpegSegmentScan::Create(decoderStream, SkJpegSegmentScan::Options());
+    if (!scan) {
+        SkCodecPrintf("Failed to scan decoder stream.\n");
+        return false;
+    }
+
+    // Extract the Multi-Picture image streams in the original decoder stream (we needed the scan to
+    // find the offsets of the MP images within the original decoder stream).
+    auto mpStreams = SkJpegExtractMultiPictureStreams(scan.get());
+    if (!mpStreams) {
+        SkCodecPrintf("Failed to extract MP image streams.\n");
+        return false;
+    }
+
+    // Iterate over the MP image streams.
+    for (auto& mpImage : mpStreams->images) {
+        if (!mpImage.stream) {
+            continue;
+        }
+
+        // Create a scan of this MP image.
+        auto mpImageScan =
+                SkJpegSegmentScan::Create(mpImage.stream.get(), SkJpegSegmentScan::Options());
+        if (!mpImageScan) {
+            SkCodecPrintf("Failed to can MP image.\n");
+            continue;
+        }
+
+        // Search for the XMP metadata in the MP image's scan.
+        for (const auto& segment : mpImageScan->segments()) {
+            if (segment.marker != kXMPMarker) {
+                continue;
+            }
+            auto xmpMetadata = mpImageScan->copyParameters(segment, kXMPSig, sizeof(kXMPSig));
+            if (!xmpMetadata) {
+                continue;
+            }
+
+            // If this XMP does not indicate that the image is an HDR gainmap, then continue.
+            if (!XmpIsHDRGainMap(xmpMetadata)) {
+                continue;
+            }
+
+            // This MP image is the gainmap image. Populate its stream and the rendering parameters
+            // for its format.
+            if (outGainmapImageStream) {
+                if (!mpImage.stream->rewind()) {
+                    SkCodecPrintf("Failed to rewind gainmap image stream.\n");
+                    return false;
+                }
+                *outGainmapImageStream = std::move(mpImage.stream);
+            }
+            constexpr float kLogRatioMin = 0.f;
+            constexpr float kLogRatioMax = 1.f;
+            outInfo->fLogRatioMin = {kLogRatioMin, kLogRatioMin, kLogRatioMin, 1.f};
+            outInfo->fLogRatioMax = {kLogRatioMax, kLogRatioMax, kLogRatioMax, 1.f};
+            outInfo->fGainmapGamma = {1.f, 1.f, 1.f, 1.f};
+            outInfo->fEpsilonSdr = 1 / 128.f;
+            outInfo->fEpsilonHdr = 1 / 128.f;
+            outInfo->fHdrRatioMin = 1.f;
+            outInfo->fHdrRatioMax = sk_float_exp(kLogRatioMax);
+            return true;
+        }
+    }
+    return false;
+}
diff --git a/src/codec/SkJpegGainmap.h b/src/codec/SkJpegGainmap.h
new file mode 100644
index 0000000..c72b8fa
--- /dev/null
+++ b/src/codec/SkJpegGainmap.h
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2023 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SkJpegGainmap_codec_DEFINED
+#define SkJpegGainmap_codec_DEFINED
+
+#include "include/core/SkRefCnt.h"  // IWYU pragma: keep
+
+class SkData;
+class SkStream;
+struct SkGainmapInfo;
+
+#include <memory>
+
+/*
+ * Implementation of onGetGainmap that detects Multi-Picture Format based gainmaps.
+ */
+bool SkJpegGetMultiPictureGainmap(sk_sp<const SkData> decoderMpfMetadata,
+                                  SkStream* decoderStream,
+                                  SkGainmapInfo* outInfo,
+                                  std::unique_ptr<SkStream>* outGainmapImageStream);
+
+#endif
diff --git a/tests/AndroidCodecTest.cpp b/tests/AndroidCodecTest.cpp
index 79ecdad..708f89b 100644
--- a/tests/AndroidCodecTest.cpp
+++ b/tests/AndroidCodecTest.cpp
@@ -15,6 +15,7 @@
 #include "include/core/SkSize.h"
 #include "include/core/SkString.h"
 #include "include/core/SkTypes.h"
+#include "include/private/SkGainmapInfo.h"  // IWYU pragma: keep
 #include "modules/skcms/skcms.h"
 #include "src/core/SkMD5.h"
 #include "tests/Test.h"
diff --git a/tests/BUILD.bazel b/tests/BUILD.bazel
index 9c8408a..357535a 100644
--- a/tests/BUILD.bazel
+++ b/tests/BUILD.bazel
@@ -15,6 +15,7 @@
     "ExifTest.cpp",
     "GifTest.cpp",
     "IndexedPngOverflowTest.cpp",
+    "JpegGainmapTest.cpp",
     "WebpTest.cpp",
     "YUVTest.cpp",
 ]
diff --git a/tests/CodecTest.cpp b/tests/CodecTest.cpp
index a7979ff..6ab0fdc 100644
--- a/tests/CodecTest.cpp
+++ b/tests/CodecTest.cpp
@@ -45,11 +45,6 @@
 #include "tools/Resources.h"
 #include "tools/ToolUtils.h"
 
-#ifdef SK_CODEC_DECODES_JPEG
-#include "src/codec/SkJpegMultiPicture.h"
-#include "src/codec/SkJpegSegmentScan.h"
-#endif
-
 #ifdef SK_ENABLE_ANDROID_UTILS
 #include "client_utils/android/FrontBufferedStream.h"
 #endif
@@ -1942,104 +1937,3 @@
         REPORTER_ASSERT(r, bm.getColor(0, 0) == rec.color);
     }
 }
-
-#ifdef SK_CODEC_DECODES_JPEG
-DEF_TEST(Codec_jpegSegmentScan, r) {
-    const struct Rec {
-        const char* path;
-        size_t sosSegmentCount;
-        size_t eoiSegmentCount;
-        size_t testSegmentIndex;
-        uint8_t testSegmentMarker;
-        size_t testSegmentOffset;
-        uint16_t testSegmentParameterLength;
-    } recs[] = {
-            {"images/wide_gamut_yellow_224_224_64.jpeg", 11, 15, 10, 0xda, 9768, 12},
-            {"images/CMYK.jpg", 7, 8, 1, 0xee, 2, 14},
-            {"images/b78329453.jpeg", 10, 23, 3, 0xe2, 154, 540},
-            {"images/brickwork-texture.jpg", 8, 28, 12, 0xc4, 34183, 42},
-            {"images/brickwork_normal-map.jpg", 8, 28, 27, 0xd9, 180612, 0},
-            {"images/cmyk_yellow_224_224_32.jpg", 19, 23, 2, 0xed, 854, 2828},
-            {"images/color_wheel.jpg", 10, 11, 2, 0xdb, 20, 67},
-            {"images/cropped_mandrill.jpg", 10, 11, 4, 0xc0, 158, 17},
-            {"images/dog.jpg", 10, 11, 5, 0xc4, 177, 28},
-            {"images/ducky.jpg", 12, 13, 10, 0xc4, 3718, 181},
-            {"images/exif-orientation-2-ur.jpg", 11, 12, 2, 0xe1, 20, 130},
-            {"images/flutter_logo.jpg", 9, 27, 21, 0xda, 5731, 8},
-            {"images/grayscale.jpg", 6, 16, 9, 0xda, 327, 8},
-            {"images/icc-v2-gbr.jpg", 12, 25, 24, 0xd9, 43832, 0},
-            {"images/mandrill_512_q075.jpg", 10, 11, 7, 0xc4, 393, 31},
-            {"images/mandrill_cmyk.jpg", 19, 35, 16, 0xdd, 574336, 4},
-            {"images/mandrill_h1v1.jpg", 10, 11, 1, 0xe0, 2, 16},
-            {"images/mandrill_h2v1.jpg", 10, 11, 0, 0xd8, 0, 0},
-            {"images/randPixels.jpg", 10, 11, 6, 0xc4, 200, 30},
-            {"images/wide_gamut_yellow_224_224_64.jpeg", 11, 15, 10, 0xda, 9768, 12},
-    };
-
-    for (const auto& rec : recs) {
-        auto stream = GetResourceAsStream(rec.path);
-        if (!stream) {
-            continue;
-        }
-
-        // Ensure that we get the expected number of segments for a scan that stops at StartOfScan.
-        SkJpegSegmentScan::Options options;
-        auto sosSegmentScan = SkJpegSegmentScan::Create(stream.get(), options);
-        REPORTER_ASSERT(r, rec.sosSegmentCount == sosSegmentScan->segments().size());
-
-        // Rewind and now go all the way to EndOfImage.
-        stream->rewind();
-        options.stopOnStartOfScan = false;
-        auto eoiSegmentScan = SkJpegSegmentScan::Create(stream.get(), options);
-        REPORTER_ASSERT(r, rec.eoiSegmentCount == eoiSegmentScan->segments().size());
-
-        // Verify the values for a randomly pre-selected segment index.
-        const auto& segment = eoiSegmentScan->segments()[rec.testSegmentIndex];
-        REPORTER_ASSERT(r, rec.testSegmentMarker == segment.marker);
-        REPORTER_ASSERT(r, rec.testSegmentOffset == segment.offset);
-        REPORTER_ASSERT(r, rec.testSegmentParameterLength == segment.parameterLength);
-    }
-}
-
-DEF_TEST(Codec_jpegMultiPicture, r) {
-    const char* path = "images/iphone_13_pro.jpeg";
-    auto stream = GetResourceAsStream(path);
-    REPORTER_ASSERT(r, stream);
-
-    auto segmentScan = SkJpegSegmentScan::Create(stream.get(), SkJpegSegmentScan::Options());
-    REPORTER_ASSERT(r, segmentScan);
-
-    // Extract the streams for the MultiPicture images.
-    auto mpStreams = SkJpegExtractMultiPictureStreams(segmentScan.get());
-    REPORTER_ASSERT(r, mpStreams);
-    size_t numberOfImages = mpStreams->images.size();
-
-    // Decode them into bitmaps.
-    std::vector<SkBitmap> bitmaps(numberOfImages);
-    for (size_t i = 0; i < numberOfImages; ++i) {
-        auto imageStream = std::move(mpStreams->images[i].stream);
-        if (i == 0) {
-            REPORTER_ASSERT(r, !imageStream);
-            continue;
-        }
-        REPORTER_ASSERT(r, imageStream);
-
-        std::unique_ptr<SkCodec> codec = SkCodec::MakeFromStream(std::move(imageStream));
-        REPORTER_ASSERT(r, codec);
-
-        SkBitmap bm;
-        bm.allocPixels(codec->getInfo());
-        REPORTER_ASSERT(
-                r, SkCodec::kSuccess == codec->getPixels(bm.info(), bm.getPixels(), bm.rowBytes()));
-        bitmaps[i] = bm;
-    }
-
-    // Spot-check the image size and pixels.
-    REPORTER_ASSERT(r, bitmaps[1].dimensions() == SkISize::Make(1512, 2016));
-    REPORTER_ASSERT(r, bitmaps[1].getColor(0, 0) == 0xFF3B3B3B);
-    REPORTER_ASSERT(r, bitmaps[1].getColor(1511, 2015) == 0xFF101010);
-    REPORTER_ASSERT(r, bitmaps[2].dimensions() == SkISize::Make(576, 768));
-    REPORTER_ASSERT(r, bitmaps[2].getColor(0, 0) == 0xFF010101);
-    REPORTER_ASSERT(r, bitmaps[2].getColor(575, 767) == 0xFFB5B5B5);
-}
-#endif
diff --git a/tests/JpegGainmapTest.cpp b/tests/JpegGainmapTest.cpp
new file mode 100644
index 0000000..1c53dd3
--- /dev/null
+++ b/tests/JpegGainmapTest.cpp
@@ -0,0 +1,168 @@
+/*
+ * Copyright 2023 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "include/codec/SkAndroidCodec.h"
+#include "include/codec/SkCodec.h"
+#include "include/core/SkBitmap.h"
+#include "include/core/SkColor.h"
+#include "include/core/SkSize.h"
+#include "include/core/SkStream.h"
+#include "include/core/SkTypes.h"
+#include "include/private/SkGainmapInfo.h"
+#include "src/codec/SkJpegMultiPicture.h"
+#include "src/codec/SkJpegSegmentScan.h"
+#include "tests/Test.h"
+#include "tools/Resources.h"
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <utility>
+#include <vector>
+
+DEF_TEST(Codec_jpegSegmentScan, r) {
+    const struct Rec {
+        const char* path;
+        size_t sosSegmentCount;
+        size_t eoiSegmentCount;
+        size_t testSegmentIndex;
+        uint8_t testSegmentMarker;
+        size_t testSegmentOffset;
+        uint16_t testSegmentParameterLength;
+    } recs[] = {
+            {"images/wide_gamut_yellow_224_224_64.jpeg", 11, 15, 10, 0xda, 9768, 12},
+            {"images/CMYK.jpg", 7, 8, 1, 0xee, 2, 14},
+            {"images/b78329453.jpeg", 10, 23, 3, 0xe2, 154, 540},
+            {"images/brickwork-texture.jpg", 8, 28, 12, 0xc4, 34183, 42},
+            {"images/brickwork_normal-map.jpg", 8, 28, 27, 0xd9, 180612, 0},
+            {"images/cmyk_yellow_224_224_32.jpg", 19, 23, 2, 0xed, 854, 2828},
+            {"images/color_wheel.jpg", 10, 11, 2, 0xdb, 20, 67},
+            {"images/cropped_mandrill.jpg", 10, 11, 4, 0xc0, 158, 17},
+            {"images/dog.jpg", 10, 11, 5, 0xc4, 177, 28},
+            {"images/ducky.jpg", 12, 13, 10, 0xc4, 3718, 181},
+            {"images/exif-orientation-2-ur.jpg", 11, 12, 2, 0xe1, 20, 130},
+            {"images/flutter_logo.jpg", 9, 27, 21, 0xda, 5731, 8},
+            {"images/grayscale.jpg", 6, 16, 9, 0xda, 327, 8},
+            {"images/icc-v2-gbr.jpg", 12, 25, 24, 0xd9, 43832, 0},
+            {"images/mandrill_512_q075.jpg", 10, 11, 7, 0xc4, 393, 31},
+            {"images/mandrill_cmyk.jpg", 19, 35, 16, 0xdd, 574336, 4},
+            {"images/mandrill_h1v1.jpg", 10, 11, 1, 0xe0, 2, 16},
+            {"images/mandrill_h2v1.jpg", 10, 11, 0, 0xd8, 0, 0},
+            {"images/randPixels.jpg", 10, 11, 6, 0xc4, 200, 30},
+            {"images/wide_gamut_yellow_224_224_64.jpeg", 11, 15, 10, 0xda, 9768, 12},
+    };
+
+    for (const auto& rec : recs) {
+        auto stream = GetResourceAsStream(rec.path);
+        if (!stream) {
+            continue;
+        }
+
+        // Ensure that we get the expected number of segments for a scan that stops at StartOfScan.
+        SkJpegSegmentScan::Options options;
+        auto sosSegmentScan = SkJpegSegmentScan::Create(stream.get(), options);
+        REPORTER_ASSERT(r, rec.sosSegmentCount == sosSegmentScan->segments().size());
+
+        // Rewind and now go all the way to EndOfImage.
+        stream->rewind();
+        options.stopOnStartOfScan = false;
+        auto eoiSegmentScan = SkJpegSegmentScan::Create(stream.get(), options);
+        REPORTER_ASSERT(r, rec.eoiSegmentCount == eoiSegmentScan->segments().size());
+
+        // Verify the values for a randomly pre-selected segment index.
+        const auto& segment = eoiSegmentScan->segments()[rec.testSegmentIndex];
+        REPORTER_ASSERT(r, rec.testSegmentMarker == segment.marker);
+        REPORTER_ASSERT(r, rec.testSegmentOffset == segment.offset);
+        REPORTER_ASSERT(r, rec.testSegmentParameterLength == segment.parameterLength);
+    }
+}
+
+DEF_TEST(Codec_jpegMultiPicture, r) {
+    const char* path = "images/iphone_13_pro.jpeg";
+    auto stream = GetResourceAsStream(path);
+    REPORTER_ASSERT(r, stream);
+
+    auto segmentScan = SkJpegSegmentScan::Create(stream.get(), SkJpegSegmentScan::Options());
+    REPORTER_ASSERT(r, segmentScan);
+
+    // Extract the streams for the MultiPicture images.
+    auto mpStreams = SkJpegExtractMultiPictureStreams(segmentScan.get());
+    REPORTER_ASSERT(r, mpStreams);
+    size_t numberOfImages = mpStreams->images.size();
+
+    // Decode them into bitmaps.
+    std::vector<SkBitmap> bitmaps(numberOfImages);
+    for (size_t i = 0; i < numberOfImages; ++i) {
+        auto imageStream = std::move(mpStreams->images[i].stream);
+        if (i == 0) {
+            REPORTER_ASSERT(r, !imageStream);
+            continue;
+        }
+        REPORTER_ASSERT(r, imageStream);
+
+        std::unique_ptr<SkCodec> codec = SkCodec::MakeFromStream(std::move(imageStream));
+        REPORTER_ASSERT(r, codec);
+
+        SkBitmap bm;
+        bm.allocPixels(codec->getInfo());
+        REPORTER_ASSERT(
+                r, SkCodec::kSuccess == codec->getPixels(bm.info(), bm.getPixels(), bm.rowBytes()));
+        bitmaps[i] = bm;
+    }
+
+    // Spot-check the image size and pixels.
+    REPORTER_ASSERT(r, bitmaps[1].dimensions() == SkISize::Make(1512, 2016));
+    REPORTER_ASSERT(r, bitmaps[1].getColor(0, 0) == 0xFF3B3B3B);
+    REPORTER_ASSERT(r, bitmaps[1].getColor(1511, 2015) == 0xFF101010);
+    REPORTER_ASSERT(r, bitmaps[2].dimensions() == SkISize::Make(576, 768));
+    REPORTER_ASSERT(r, bitmaps[2].getColor(0, 0) == 0xFF010101);
+    REPORTER_ASSERT(r, bitmaps[2].getColor(575, 767) == 0xFFB5B5B5);
+}
+
+DEF_TEST(AndroidCodec_jpegGainmap, r) {
+    const struct Rec {
+        const char* path;
+        SkISize dimensions;
+        SkColor originColor;
+        SkColor farCornerColor;
+    } recs[] = {
+            {"images/iphone_13_pro.jpeg", SkISize::Make(1512, 2016), 0xFF3B3B3B, 0xFF101010},
+    };
+
+    for (const auto& rec : recs) {
+        auto stream = GetResourceAsStream(rec.path);
+        REPORTER_ASSERT(r, stream);
+
+        std::unique_ptr<SkCodec> codec = SkCodec::MakeFromStream(std::move(stream));
+        REPORTER_ASSERT(r, codec);
+
+        std::unique_ptr<SkAndroidCodec> androidCodec =
+                SkAndroidCodec::MakeFromCodec(std::move(codec));
+        REPORTER_ASSERT(r, androidCodec);
+
+        SkGainmapInfo gainmapInfo;
+        std::unique_ptr<SkStream> gainmapStream;
+        REPORTER_ASSERT(r, androidCodec->getAndroidGainmap(&gainmapInfo, &gainmapStream));
+        REPORTER_ASSERT(r, gainmapStream);
+
+        std::unique_ptr<SkCodec> gainmapCodec = SkCodec::MakeFromStream(std::move(gainmapStream));
+        REPORTER_ASSERT(r, gainmapCodec);
+
+        SkBitmap bm;
+        bm.allocPixels(gainmapCodec->getInfo());
+        REPORTER_ASSERT(r,
+                        SkCodec::kSuccess ==
+                                gainmapCodec->getPixels(bm.info(), bm.getPixels(), bm.rowBytes()));
+
+        // Spot-check the image size and pixels.
+        REPORTER_ASSERT(r, bm.dimensions() == rec.dimensions);
+        REPORTER_ASSERT(r, bm.getColor(0, 0) == rec.originColor);
+        REPORTER_ASSERT(r,
+                        bm.getColor(rec.dimensions.fWidth - 1, rec.dimensions.fHeight - 1) ==
+                                rec.farCornerColor);
+    }
+}