diff --git a/include/android/SkAnimatedImage.h b/include/android/SkAnimatedImage.h
index 51f0e5b..983a57b 100644
--- a/include/android/SkAnimatedImage.h
+++ b/include/android/SkAnimatedImage.h
@@ -107,7 +107,20 @@
         int      fIndex;
         SkCodecAnimation::DisposalMethod fDisposalMethod;
 
+        // init() may have to create a new SkPixelRef, if the
+        // current one is already in use by another owner (e.g.
+        // an SkPicture). This determines whether to copy the
+        // existing one to the new one.
+        enum class OnInit {
+            // Restore the image from the old SkPixelRef to the
+            // new one.
+            kRestoreIfNecessary,
+            // No need to restore.
+            kNoRestore,
+        };
+
         Frame();
+        bool init(const SkImageInfo& info, OnInit);
         bool copyTo(Frame*) const;
     };
 
@@ -122,7 +135,8 @@
 
     bool                            fFinished;
     int                             fCurrentFrameDuration;
-    Frame                           fActiveFrame;
+    Frame                           fDisplayFrame;
+    Frame                           fDecodingFrame;
     Frame                           fRestoreFrame;
     int                             fRepetitionCount;
     int                             fRepetitionsCompleted;
diff --git a/src/android/SkAnimatedImage.cpp b/src/android/SkAnimatedImage.cpp
index de25f30..daeec81 100644
--- a/src/android/SkAnimatedImage.cpp
+++ b/src/android/SkAnimatedImage.cpp
@@ -10,8 +10,10 @@
 #include "SkCanvas.h"
 #include "SkCodec.h"
 #include "SkCodecPriv.h"
+#include "SkImagePriv.h"
 #include "SkPicture.h"
 #include "SkPictureRecorder.h"
+#include "SkPixelRef.h"
 
 sk_sp<SkAnimatedImage> SkAnimatedImage::Make(std::unique_ptr<SkAndroidCodec> codec,
         SkISize scaledSize, SkIRect cropRect, sk_sp<SkPicture> postProcess) {
@@ -30,7 +32,7 @@
 
     auto image = sk_sp<SkAnimatedImage>(new SkAnimatedImage(std::move(codec), scaledSize,
                 decodeInfo, cropRect, std::move(postProcess)));
-    if (!image->fActiveFrame.fBitmap.getPixels()) {
+    if (!image->fDisplayFrame.fBitmap.getPixels()) {
         // tryAllocPixels failed.
         return nullptr;
     }
@@ -49,7 +51,7 @@
     auto image = sk_sp<SkAnimatedImage>(new SkAnimatedImage(std::move(codec), scaledSize,
                 decodeInfo, cropRect, nullptr));
 
-    if (!image->fActiveFrame.fBitmap.getPixels()) {
+    if (!image->fDisplayFrame.fBitmap.getPixels()) {
         // tryAllocPixels failed.
         return nullptr;
     }
@@ -72,7 +74,7 @@
     , fRepetitionCount(fCodec->codec()->getRepetitionCount())
     , fRepetitionsCompleted(0)
 {
-    if (!fActiveFrame.fBitmap.tryAllocPixels(fDecodeInfo)) {
+    if (!fDecodingFrame.fBitmap.tryAllocPixels(fDecodeInfo)) {
         return;
     }
 
@@ -95,10 +97,33 @@
     : fIndex(SkCodec::kNone)
 {}
 
+bool SkAnimatedImage::Frame::init(const SkImageInfo& info, OnInit onInit) {
+    if (fBitmap.getPixels()) {
+        if (fBitmap.pixelRef()->unique()) {
+            SkAssertResult(fBitmap.setAlphaType(info.alphaType()));
+            return true;
+        }
+
+        // An SkCanvas provided to onDraw is still holding a reference.
+        // Copy before we decode to ensure that we don't overwrite the
+        // expected contents of the image.
+        if (OnInit::kRestoreIfNecessary == onInit) {
+            SkBitmap tmp;
+            if (!tmp.tryAllocPixels(info)) {
+                return false;
+            }
+
+            memcpy(tmp.getPixels(), fBitmap.getPixels(), fBitmap.computeByteSize());
+            SkTSwap(tmp, fBitmap);
+            return true;
+        }
+    }
+
+    return fBitmap.tryAllocPixels(info);
+}
+
 bool SkAnimatedImage::Frame::copyTo(Frame* dst) const {
-    if (dst->fBitmap.getPixels()) {
-        dst->fBitmap.setAlphaType(fBitmap.alphaType());
-    } else if (!dst->fBitmap.tryAllocPixels(fBitmap.info())) {
+    if (!dst->init(fBitmap.info(), OnInit::kNoRestore)) {
         return false;
     }
 
@@ -111,19 +136,10 @@
 void SkAnimatedImage::reset() {
     fFinished = false;
     fRepetitionsCompleted = 0;
-    if (fActiveFrame.fIndex == 0) {
-        // Already showing the first frame.
-        return;
+    if (fDisplayFrame.fIndex != 0) {
+        fDisplayFrame.fIndex = SkCodec::kNone;
+        this->decodeNextFrame();
     }
-
-    if (fRestoreFrame.fIndex == 0) {
-        SkTSwap(fActiveFrame, fRestoreFrame);
-        // Now we're showing the first frame.
-        return;
-    }
-
-    fActiveFrame.fIndex = SkCodec::kNone;
-    this->decodeNextFrame();
 }
 
 static bool is_restore_previous(SkCodecAnimation::DisposalMethod dispose) {
@@ -160,7 +176,7 @@
     }
 
     bool animationEnded = false;
-    int frameToDecode = this->computeNextFrame(fActiveFrame.fIndex, &animationEnded);
+    int frameToDecode = this->computeNextFrame(fDisplayFrame.fIndex, &animationEnded);
 
     SkCodec::FrameInfo frameInfo;
     if (fCodec->codec()->getFrameInfo(frameToDecode, &frameInfo)) {
@@ -188,19 +204,21 @@
         }
     }
 
-    if (frameToDecode == fActiveFrame.fIndex) {
+    if (frameToDecode == fDisplayFrame.fIndex) {
         if (animationEnded) {
             return this->finish();
         }
         return fCurrentFrameDuration;
     }
 
-    if (frameToDecode == fRestoreFrame.fIndex) {
-        SkTSwap(fActiveFrame, fRestoreFrame);
-        if (animationEnded) {
-            return this->finish();
+    for (Frame* frame : { &fRestoreFrame, &fDecodingFrame }) {
+        if (frameToDecode == frame->fIndex) {
+            SkTSwap(fDisplayFrame, *frame);
+            if (animationEnded) {
+                return this->finish();
+            }
+            return fCurrentFrameDuration;
         }
-        return fCurrentFrameDuration;
     }
 
     // The following code makes an effort to avoid overwriting a frame that will
@@ -216,9 +234,9 @@
             // frameToDecode will be discarded immediately after drawing, so
             // do not overwrite a frame which could possibly be used in the
             // future.
-            if (fActiveFrame.fIndex != SkCodec::kNone &&
-                    !is_restore_previous(fActiveFrame.fDisposalMethod)) {
-                SkTSwap(fActiveFrame, fRestoreFrame);
+            if (fDecodingFrame.fIndex != SkCodec::kNone &&
+                    !is_restore_previous(fDecodingFrame.fDisposalMethod)) {
+                SkTSwap(fDecodingFrame, fRestoreFrame);
             }
         }
     } else {
@@ -229,34 +247,36 @@
 
             return frame.fIndex >= frameInfo.fRequiredFrame && frame.fIndex < frameToDecode;
         };
-        if (validPriorFrame(fActiveFrame)) {
+        if (validPriorFrame(fDecodingFrame)) {
             if (is_restore_previous(frameInfo.fDisposalMethod)) {
-                // fActiveFrame is a good frame to use for this one, but we
+                // fDecodingFrame is a good frame to use for this one, but we
                 // don't want to overwrite it.
-                fActiveFrame.copyTo(&fRestoreFrame);
+                fDecodingFrame.copyTo(&fRestoreFrame);
             }
-            options.fPriorFrame = fActiveFrame.fIndex;
+            options.fPriorFrame = fDecodingFrame.fIndex;
+        } else if (validPriorFrame(fDisplayFrame)) {
+            if (!fDisplayFrame.copyTo(&fDecodingFrame)) {
+                SkCodecPrintf("Failed to allocate pixels for frame\n");
+                return this->finish();
+            }
+            options.fPriorFrame = fDecodingFrame.fIndex;
         } else if (validPriorFrame(fRestoreFrame)) {
             if (!is_restore_previous(frameInfo.fDisposalMethod)) {
-                SkTSwap(fActiveFrame, fRestoreFrame);
-            } else if (!fRestoreFrame.copyTo(&fActiveFrame)) {
+                SkTSwap(fDecodingFrame, fRestoreFrame);
+            } else if (!fRestoreFrame.copyTo(&fDecodingFrame)) {
                 SkCodecPrintf("Failed to restore frame\n");
                 return this->finish();
             }
-            options.fPriorFrame = fActiveFrame.fIndex;
+            options.fPriorFrame = fDecodingFrame.fIndex;
         }
     }
 
     auto alphaType = kOpaque_SkAlphaType == frameInfo.fAlphaType ?
                      kOpaque_SkAlphaType : kPremul_SkAlphaType;
-    SkBitmap* dst = &fActiveFrame.fBitmap;
-    if (dst->getPixels()) {
-        SkAssertResult(dst->setAlphaType(alphaType));
-    } else {
-        auto info = fDecodeInfo.makeAlphaType(alphaType);
-        if (!dst->tryAllocPixels(info)) {
-            return this->finish();
-        }
+    auto info = fDecodeInfo.makeAlphaType(alphaType);
+    SkBitmap* dst = &fDecodingFrame.fBitmap;
+    if (!fDecodingFrame.init(info, Frame::OnInit::kRestoreIfNecessary)) {
+        return this->finish();
     }
 
     auto result = fCodec->codec()->getPixels(dst->info(), dst->getPixels(), dst->rowBytes(),
@@ -266,8 +286,11 @@
         return this->finish();
     }
 
-    fActiveFrame.fIndex = frameToDecode;
-    fActiveFrame.fDisposalMethod = frameInfo.fDisposalMethod;
+    fDecodingFrame.fIndex = frameToDecode;
+    fDecodingFrame.fDisposalMethod = frameInfo.fDisposalMethod;
+
+    SkTSwap(fDecodingFrame, fDisplayFrame);
+    fDisplayFrame.fBitmap.notifyPixelsChanged();
 
     if (animationEnded) {
         return this->finish();
@@ -276,8 +299,11 @@
 }
 
 void SkAnimatedImage::onDraw(SkCanvas* canvas) {
+    auto image = SkMakeImageFromRasterBitmap(fDisplayFrame.fBitmap,
+                                             kNever_SkCopyPixelsMode);
+
     if (fSimple) {
-        canvas->drawBitmap(fActiveFrame.fBitmap, 0, 0);
+        canvas->drawImage(image, 0, 0);
         return;
     }
 
@@ -290,7 +316,7 @@
         canvas->concat(fMatrix);
         SkPaint paint;
         paint.setFilterQuality(kLow_SkFilterQuality);
-        canvas->drawBitmap(fActiveFrame.fBitmap, 0, 0, &paint);
+        canvas->drawImage(image, 0, 0, &paint);
     }
     if (fPostProcess) {
         canvas->drawPicture(fPostProcess);
diff --git a/tests/AnimatedImageTest.cpp b/tests/AnimatedImageTest.cpp
index 6c7b0e5..13ea808 100644
--- a/tests/AnimatedImageTest.cpp
+++ b/tests/AnimatedImageTest.cpp
@@ -72,6 +72,102 @@
     }
 }
 
+static bool compare_bitmaps(skiatest::Reporter* r,
+                            const char* file,
+                            int expectedFrame,
+                            const SkBitmap& expectedBm,
+                            const SkBitmap& actualBm) {
+    REPORTER_ASSERT(r, expectedBm.colorType() == actualBm.colorType());
+    REPORTER_ASSERT(r, expectedBm.dimensions() == actualBm.dimensions());
+    for (int i = 0; i < actualBm.width();  ++i)
+    for (int j = 0; j < actualBm.height(); ++j) {
+        SkColor expected = SkUnPreMultiply::PMColorToColor(*expectedBm.getAddr32(i, j));
+        SkColor actual   = SkUnPreMultiply::PMColorToColor(*actualBm  .getAddr32(i, j));
+        if (expected != actual) {
+            ERRORF(r, "frame %i of %s does not match at pixel %i, %i!"
+                            " expected %x\tactual: %x",
+                            expectedFrame, file, i, j, expected, actual);
+            SkString expected_name = SkStringPrintf("expected_%c", '0' + expectedFrame);
+            SkString actual_name   = SkStringPrintf("actual_%c",   '0' + expectedFrame);
+            write_bm(expected_name.c_str(), expectedBm);
+            write_bm(actual_name.c_str(),   actualBm);
+            return false;
+        }
+    }
+    return true;
+}
+
+DEF_TEST(AnimatedImage_copyOnWrite, r) {
+    if (GetResourcePath().isEmpty()) {
+        return;
+    }
+    for (const char* file : { "images/alphabetAnim.gif",
+                              "images/colorTables.gif",
+                              "images/webp-animated.webp",
+                              "images/required.webp",
+                              }) {
+        auto data = GetResourceAsData(file);
+        if (!data) {
+            ERRORF(r, "Could not get %s", file);
+            continue;
+        }
+
+        auto codec = SkCodec::MakeFromData(data);
+        if (!codec) {
+            ERRORF(r, "Could not create codec for %s", file);
+            continue;
+        }
+
+        const auto imageInfo = codec->getInfo().makeAlphaType(kPremul_SkAlphaType);
+        const int frameCount = codec->getFrameCount();
+        auto androidCodec = SkAndroidCodec::MakeFromCodec(std::move(codec));
+        if (!androidCodec) {
+            ERRORF(r, "Could not create androidCodec for %s", file);
+            continue;
+        }
+
+        auto animatedImage = SkAnimatedImage::Make(std::move(androidCodec));
+        if (!animatedImage) {
+            ERRORF(r, "Could not create animated image for %s", file);
+            continue;
+        }
+        animatedImage->setRepetitionCount(0);
+
+        std::vector<SkBitmap> expected(frameCount);
+        std::vector<sk_sp<SkPicture>> pictures(frameCount);
+        for (int i = 0; i < frameCount; i++) {
+            SkBitmap& bm = expected[i];
+            bm.allocPixels(imageInfo);
+            bm.eraseColor(SK_ColorTRANSPARENT);
+            SkCanvas canvas(bm);
+
+            pictures[i].reset(animatedImage->newPictureSnapshot());
+            canvas.drawPicture(pictures[i]);
+
+            const auto duration = animatedImage->decodeNextFrame();
+            // We're attempting to decode i + 1, so decodeNextFrame will return
+            // kFinished if that is the last frame (or we attempt to decode one
+            // more).
+            if (i >= frameCount - 2) {
+                REPORTER_ASSERT(r, duration == SkAnimatedImage::kFinished);
+            } else {
+                REPORTER_ASSERT(r, duration != SkAnimatedImage::kFinished);
+            }
+        }
+
+        for (int i = 0; i < frameCount; i++) {
+            SkBitmap test;
+            test.allocPixels(imageInfo);
+            test.eraseColor(SK_ColorTRANSPARENT);
+            SkCanvas canvas(test);
+
+            canvas.drawPicture(pictures[i]);
+
+            compare_bitmaps(r, file, i, expected[i], test);
+        }
+    }
+}
+
 DEF_TEST(AnimatedImage, r) {
     if (GetResourcePath().isEmpty()) {
         return;
@@ -147,24 +243,7 @@
             animatedImage->draw(&c);
 
             const SkBitmap& frame = frames[expectedFrame];
-            REPORTER_ASSERT(r, frame.colorType() == test.colorType());
-            REPORTER_ASSERT(r, frame.dimensions() == test.dimensions());
-            for (int i = 0; i < test.width();  ++i)
-            for (int j = 0; j < test.height(); ++j) {
-                SkColor expected = SkUnPreMultiply::PMColorToColor(*frame.getAddr32(i, j));
-                SkColor actual   = SkUnPreMultiply::PMColorToColor(*test .getAddr32(i, j));
-                if (expected != actual) {
-                    ERRORF(r, "frame %i of %s does not match at pixel %i, %i!"
-                            " expected %x\tactual: %x",
-                            expectedFrame, file, i, j, expected, actual);
-                    SkString expected_name = SkStringPrintf("expected_%c", '0' + expectedFrame);
-                    SkString actual_name   = SkStringPrintf("actual_%c",   '0' + expectedFrame);
-                    write_bm(expected_name.c_str(), frame);
-                    write_bm(actual_name.c_str(),   test);
-                    return false;
-                }
-            }
-            return true;
+            return compare_bitmaps(r, file, expectedFrame, frame, test);
         };
 
         REPORTER_ASSERT(r, animatedImage->currentFrameDuration() == frameInfos[0].fDuration);
