Alternate between two SkBitmaps in SkAnimatedImage

Bug: 78866720

The client in Android calls newPictureSnapshot, which results in copying
the mutable SkBitmap into a newly allocated one in each frame. Avoid
this by calling SkMakeImageFromRasterBitmap with
kNever_SkCopyPixelsMode. Make SkAnimatedImage copy on write, by copying
before decoding if the bitmap's pixel ref is not unique.

Android's AnimatedImageDrawable's current architecture only decodes one
frame in advance, so it will never need to perform the copy on write.
This will save one bitmap allocation per GIF frame.

Add a test to verify that copy on write works as expected.

Change-Id: I87eb6e84089096cd2d618b91fb627fc58677e66a
Reviewed-on: https://skia-review.googlesource.com/129841
Reviewed-by: Leon Scroggins <scroggo@google.com>
Commit-Queue: Leon Scroggins <scroggo@google.com>
Auto-Submit: Leon Scroggins <scroggo@google.com>
(cherry picked from commit 4aafb3a8d12015067fae1301c2f5951f398eb25b)
Reviewed-on: https://skia-review.googlesource.com/129942
diff --git a/include/android/SkAnimatedImage.h b/include/android/SkAnimatedImage.h
index 51f0e5b..983a57b 100644
--- a/include/android/SkAnimatedImage.h
+++ b/include/android/SkAnimatedImage.h
@@ -107,7 +107,20 @@
         int      fIndex;
         SkCodecAnimation::DisposalMethod fDisposalMethod;
 
+        // init() may have to create a new SkPixelRef, if the
+        // current one is already in use by another owner (e.g.
+        // an SkPicture). This determines whether to copy the
+        // existing one to the new one.
+        enum class OnInit {
+            // Restore the image from the old SkPixelRef to the
+            // new one.
+            kRestoreIfNecessary,
+            // No need to restore.
+            kNoRestore,
+        };
+
         Frame();
+        bool init(const SkImageInfo& info, OnInit);
         bool copyTo(Frame*) const;
     };
 
@@ -122,7 +135,8 @@
 
     bool                            fFinished;
     int                             fCurrentFrameDuration;
-    Frame                           fActiveFrame;
+    Frame                           fDisplayFrame;
+    Frame                           fDecodingFrame;
     Frame                           fRestoreFrame;
     int                             fRepetitionCount;
     int                             fRepetitionsCompleted;
diff --git a/src/android/SkAnimatedImage.cpp b/src/android/SkAnimatedImage.cpp
index de25f30..daeec81 100644
--- a/src/android/SkAnimatedImage.cpp
+++ b/src/android/SkAnimatedImage.cpp
@@ -10,8 +10,10 @@
 #include "SkCanvas.h"
 #include "SkCodec.h"
 #include "SkCodecPriv.h"
+#include "SkImagePriv.h"
 #include "SkPicture.h"
 #include "SkPictureRecorder.h"
+#include "SkPixelRef.h"
 
 sk_sp<SkAnimatedImage> SkAnimatedImage::Make(std::unique_ptr<SkAndroidCodec> codec,
         SkISize scaledSize, SkIRect cropRect, sk_sp<SkPicture> postProcess) {
@@ -30,7 +32,7 @@
 
     auto image = sk_sp<SkAnimatedImage>(new SkAnimatedImage(std::move(codec), scaledSize,
                 decodeInfo, cropRect, std::move(postProcess)));
-    if (!image->fActiveFrame.fBitmap.getPixels()) {
+    if (!image->fDisplayFrame.fBitmap.getPixels()) {
         // tryAllocPixels failed.
         return nullptr;
     }
@@ -49,7 +51,7 @@
     auto image = sk_sp<SkAnimatedImage>(new SkAnimatedImage(std::move(codec), scaledSize,
                 decodeInfo, cropRect, nullptr));
 
-    if (!image->fActiveFrame.fBitmap.getPixels()) {
+    if (!image->fDisplayFrame.fBitmap.getPixels()) {
         // tryAllocPixels failed.
         return nullptr;
     }
@@ -72,7 +74,7 @@
     , fRepetitionCount(fCodec->codec()->getRepetitionCount())
     , fRepetitionsCompleted(0)
 {
-    if (!fActiveFrame.fBitmap.tryAllocPixels(fDecodeInfo)) {
+    if (!fDecodingFrame.fBitmap.tryAllocPixels(fDecodeInfo)) {
         return;
     }
 
@@ -95,10 +97,33 @@
     : fIndex(SkCodec::kNone)
 {}
 
+bool SkAnimatedImage::Frame::init(const SkImageInfo& info, OnInit onInit) {
+    if (fBitmap.getPixels()) {
+        if (fBitmap.pixelRef()->unique()) {
+            SkAssertResult(fBitmap.setAlphaType(info.alphaType()));
+            return true;
+        }
+
+        // An SkCanvas provided to onDraw is still holding a reference.
+        // Copy before we decode to ensure that we don't overwrite the
+        // expected contents of the image.
+        if (OnInit::kRestoreIfNecessary == onInit) {
+            SkBitmap tmp;
+            if (!tmp.tryAllocPixels(info)) {
+                return false;
+            }
+
+            memcpy(tmp.getPixels(), fBitmap.getPixels(), fBitmap.computeByteSize());
+            SkTSwap(tmp, fBitmap);
+            return true;
+        }
+    }
+
+    return fBitmap.tryAllocPixels(info);
+}
+
 bool SkAnimatedImage::Frame::copyTo(Frame* dst) const {
-    if (dst->fBitmap.getPixels()) {
-        dst->fBitmap.setAlphaType(fBitmap.alphaType());
-    } else if (!dst->fBitmap.tryAllocPixels(fBitmap.info())) {
+    if (!dst->init(fBitmap.info(), OnInit::kNoRestore)) {
         return false;
     }
 
@@ -111,19 +136,10 @@
 void SkAnimatedImage::reset() {
     fFinished = false;
     fRepetitionsCompleted = 0;
-    if (fActiveFrame.fIndex == 0) {
-        // Already showing the first frame.
-        return;
+    if (fDisplayFrame.fIndex != 0) {
+        fDisplayFrame.fIndex = SkCodec::kNone;
+        this->decodeNextFrame();
     }
-
-    if (fRestoreFrame.fIndex == 0) {
-        SkTSwap(fActiveFrame, fRestoreFrame);
-        // Now we're showing the first frame.
-        return;
-    }
-
-    fActiveFrame.fIndex = SkCodec::kNone;
-    this->decodeNextFrame();
 }
 
 static bool is_restore_previous(SkCodecAnimation::DisposalMethod dispose) {
@@ -160,7 +176,7 @@
     }
 
     bool animationEnded = false;
-    int frameToDecode = this->computeNextFrame(fActiveFrame.fIndex, &animationEnded);
+    int frameToDecode = this->computeNextFrame(fDisplayFrame.fIndex, &animationEnded);
 
     SkCodec::FrameInfo frameInfo;
     if (fCodec->codec()->getFrameInfo(frameToDecode, &frameInfo)) {
@@ -188,19 +204,21 @@
         }
     }
 
-    if (frameToDecode == fActiveFrame.fIndex) {
+    if (frameToDecode == fDisplayFrame.fIndex) {
         if (animationEnded) {
             return this->finish();
         }
         return fCurrentFrameDuration;
     }
 
-    if (frameToDecode == fRestoreFrame.fIndex) {
-        SkTSwap(fActiveFrame, fRestoreFrame);
-        if (animationEnded) {
-            return this->finish();
+    for (Frame* frame : { &fRestoreFrame, &fDecodingFrame }) {
+        if (frameToDecode == frame->fIndex) {
+            SkTSwap(fDisplayFrame, *frame);
+            if (animationEnded) {
+                return this->finish();
+            }
+            return fCurrentFrameDuration;
         }
-        return fCurrentFrameDuration;
     }
 
     // The following code makes an effort to avoid overwriting a frame that will
@@ -216,9 +234,9 @@
             // frameToDecode will be discarded immediately after drawing, so
             // do not overwrite a frame which could possibly be used in the
             // future.
-            if (fActiveFrame.fIndex != SkCodec::kNone &&
-                    !is_restore_previous(fActiveFrame.fDisposalMethod)) {
-                SkTSwap(fActiveFrame, fRestoreFrame);
+            if (fDecodingFrame.fIndex != SkCodec::kNone &&
+                    !is_restore_previous(fDecodingFrame.fDisposalMethod)) {
+                SkTSwap(fDecodingFrame, fRestoreFrame);
             }
         }
     } else {
@@ -229,34 +247,36 @@
 
             return frame.fIndex >= frameInfo.fRequiredFrame && frame.fIndex < frameToDecode;
         };
-        if (validPriorFrame(fActiveFrame)) {
+        if (validPriorFrame(fDecodingFrame)) {
             if (is_restore_previous(frameInfo.fDisposalMethod)) {
-                // fActiveFrame is a good frame to use for this one, but we
+                // fDecodingFrame is a good frame to use for this one, but we
                 // don't want to overwrite it.
-                fActiveFrame.copyTo(&fRestoreFrame);
+                fDecodingFrame.copyTo(&fRestoreFrame);
             }
-            options.fPriorFrame = fActiveFrame.fIndex;
+            options.fPriorFrame = fDecodingFrame.fIndex;
+        } else if (validPriorFrame(fDisplayFrame)) {
+            if (!fDisplayFrame.copyTo(&fDecodingFrame)) {
+                SkCodecPrintf("Failed to allocate pixels for frame\n");
+                return this->finish();
+            }
+            options.fPriorFrame = fDecodingFrame.fIndex;
         } else if (validPriorFrame(fRestoreFrame)) {
             if (!is_restore_previous(frameInfo.fDisposalMethod)) {
-                SkTSwap(fActiveFrame, fRestoreFrame);
-            } else if (!fRestoreFrame.copyTo(&fActiveFrame)) {
+                SkTSwap(fDecodingFrame, fRestoreFrame);
+            } else if (!fRestoreFrame.copyTo(&fDecodingFrame)) {
                 SkCodecPrintf("Failed to restore frame\n");
                 return this->finish();
             }
-            options.fPriorFrame = fActiveFrame.fIndex;
+            options.fPriorFrame = fDecodingFrame.fIndex;
         }
     }
 
     auto alphaType = kOpaque_SkAlphaType == frameInfo.fAlphaType ?
                      kOpaque_SkAlphaType : kPremul_SkAlphaType;
-    SkBitmap* dst = &fActiveFrame.fBitmap;
-    if (dst->getPixels()) {
-        SkAssertResult(dst->setAlphaType(alphaType));
-    } else {
-        auto info = fDecodeInfo.makeAlphaType(alphaType);
-        if (!dst->tryAllocPixels(info)) {
-            return this->finish();
-        }
+    auto info = fDecodeInfo.makeAlphaType(alphaType);
+    SkBitmap* dst = &fDecodingFrame.fBitmap;
+    if (!fDecodingFrame.init(info, Frame::OnInit::kRestoreIfNecessary)) {
+        return this->finish();
     }
 
     auto result = fCodec->codec()->getPixels(dst->info(), dst->getPixels(), dst->rowBytes(),
@@ -266,8 +286,11 @@
         return this->finish();
     }
 
-    fActiveFrame.fIndex = frameToDecode;
-    fActiveFrame.fDisposalMethod = frameInfo.fDisposalMethod;
+    fDecodingFrame.fIndex = frameToDecode;
+    fDecodingFrame.fDisposalMethod = frameInfo.fDisposalMethod;
+
+    SkTSwap(fDecodingFrame, fDisplayFrame);
+    fDisplayFrame.fBitmap.notifyPixelsChanged();
 
     if (animationEnded) {
         return this->finish();
@@ -276,8 +299,11 @@
 }
 
 void SkAnimatedImage::onDraw(SkCanvas* canvas) {
+    auto image = SkMakeImageFromRasterBitmap(fDisplayFrame.fBitmap,
+                                             kNever_SkCopyPixelsMode);
+
     if (fSimple) {
-        canvas->drawBitmap(fActiveFrame.fBitmap, 0, 0);
+        canvas->drawImage(image, 0, 0);
         return;
     }
 
@@ -290,7 +316,7 @@
         canvas->concat(fMatrix);
         SkPaint paint;
         paint.setFilterQuality(kLow_SkFilterQuality);
-        canvas->drawBitmap(fActiveFrame.fBitmap, 0, 0, &paint);
+        canvas->drawImage(image, 0, 0, &paint);
     }
     if (fPostProcess) {
         canvas->drawPicture(fPostProcess);
diff --git a/tests/AnimatedImageTest.cpp b/tests/AnimatedImageTest.cpp
index 6c7b0e5..13ea808 100644
--- a/tests/AnimatedImageTest.cpp
+++ b/tests/AnimatedImageTest.cpp
@@ -72,6 +72,102 @@
     }
 }
 
+static bool compare_bitmaps(skiatest::Reporter* r,
+                            const char* file,
+                            int expectedFrame,
+                            const SkBitmap& expectedBm,
+                            const SkBitmap& actualBm) {
+    REPORTER_ASSERT(r, expectedBm.colorType() == actualBm.colorType());
+    REPORTER_ASSERT(r, expectedBm.dimensions() == actualBm.dimensions());
+    for (int i = 0; i < actualBm.width();  ++i)
+    for (int j = 0; j < actualBm.height(); ++j) {
+        SkColor expected = SkUnPreMultiply::PMColorToColor(*expectedBm.getAddr32(i, j));
+        SkColor actual   = SkUnPreMultiply::PMColorToColor(*actualBm  .getAddr32(i, j));
+        if (expected != actual) {
+            ERRORF(r, "frame %i of %s does not match at pixel %i, %i!"
+                            " expected %x\tactual: %x",
+                            expectedFrame, file, i, j, expected, actual);
+            SkString expected_name = SkStringPrintf("expected_%c", '0' + expectedFrame);
+            SkString actual_name   = SkStringPrintf("actual_%c",   '0' + expectedFrame);
+            write_bm(expected_name.c_str(), expectedBm);
+            write_bm(actual_name.c_str(),   actualBm);
+            return false;
+        }
+    }
+    return true;
+}
+
+DEF_TEST(AnimatedImage_copyOnWrite, r) {
+    if (GetResourcePath().isEmpty()) {
+        return;
+    }
+    for (const char* file : { "images/alphabetAnim.gif",
+                              "images/colorTables.gif",
+                              "images/webp-animated.webp",
+                              "images/required.webp",
+                              }) {
+        auto data = GetResourceAsData(file);
+        if (!data) {
+            ERRORF(r, "Could not get %s", file);
+            continue;
+        }
+
+        auto codec = SkCodec::MakeFromData(data);
+        if (!codec) {
+            ERRORF(r, "Could not create codec for %s", file);
+            continue;
+        }
+
+        const auto imageInfo = codec->getInfo().makeAlphaType(kPremul_SkAlphaType);
+        const int frameCount = codec->getFrameCount();
+        auto androidCodec = SkAndroidCodec::MakeFromCodec(std::move(codec));
+        if (!androidCodec) {
+            ERRORF(r, "Could not create androidCodec for %s", file);
+            continue;
+        }
+
+        auto animatedImage = SkAnimatedImage::Make(std::move(androidCodec));
+        if (!animatedImage) {
+            ERRORF(r, "Could not create animated image for %s", file);
+            continue;
+        }
+        animatedImage->setRepetitionCount(0);
+
+        std::vector<SkBitmap> expected(frameCount);
+        std::vector<sk_sp<SkPicture>> pictures(frameCount);
+        for (int i = 0; i < frameCount; i++) {
+            SkBitmap& bm = expected[i];
+            bm.allocPixels(imageInfo);
+            bm.eraseColor(SK_ColorTRANSPARENT);
+            SkCanvas canvas(bm);
+
+            pictures[i].reset(animatedImage->newPictureSnapshot());
+            canvas.drawPicture(pictures[i]);
+
+            const auto duration = animatedImage->decodeNextFrame();
+            // We're attempting to decode i + 1, so decodeNextFrame will return
+            // kFinished if that is the last frame (or we attempt to decode one
+            // more).
+            if (i >= frameCount - 2) {
+                REPORTER_ASSERT(r, duration == SkAnimatedImage::kFinished);
+            } else {
+                REPORTER_ASSERT(r, duration != SkAnimatedImage::kFinished);
+            }
+        }
+
+        for (int i = 0; i < frameCount; i++) {
+            SkBitmap test;
+            test.allocPixels(imageInfo);
+            test.eraseColor(SK_ColorTRANSPARENT);
+            SkCanvas canvas(test);
+
+            canvas.drawPicture(pictures[i]);
+
+            compare_bitmaps(r, file, i, expected[i], test);
+        }
+    }
+}
+
 DEF_TEST(AnimatedImage, r) {
     if (GetResourcePath().isEmpty()) {
         return;
@@ -147,24 +243,7 @@
             animatedImage->draw(&c);
 
             const SkBitmap& frame = frames[expectedFrame];
-            REPORTER_ASSERT(r, frame.colorType() == test.colorType());
-            REPORTER_ASSERT(r, frame.dimensions() == test.dimensions());
-            for (int i = 0; i < test.width();  ++i)
-            for (int j = 0; j < test.height(); ++j) {
-                SkColor expected = SkUnPreMultiply::PMColorToColor(*frame.getAddr32(i, j));
-                SkColor actual   = SkUnPreMultiply::PMColorToColor(*test .getAddr32(i, j));
-                if (expected != actual) {
-                    ERRORF(r, "frame %i of %s does not match at pixel %i, %i!"
-                            " expected %x\tactual: %x",
-                            expectedFrame, file, i, j, expected, actual);
-                    SkString expected_name = SkStringPrintf("expected_%c", '0' + expectedFrame);
-                    SkString actual_name   = SkStringPrintf("actual_%c",   '0' + expectedFrame);
-                    write_bm(expected_name.c_str(), frame);
-                    write_bm(actual_name.c_str(),   test);
-                    return false;
-                }
-            }
-            return true;
+            return compare_bitmaps(r, file, expectedFrame, frame, test);
         };
 
         REPORTER_ASSERT(r, animatedImage->currentFrameDuration() == frameInfos[0].fDuration);