WindowInsetsAnimationControllerTests: Test for synchronous progress invocation

Also verify callback order and consistency.

Bug: 152617481
Test: atest WindowInsetsAnimationControllerTests
Change-Id: I697d346c01dc0aeb07dc92220fb32664eff5861d
(cherry picked from commit b0fbdd87fb0ffe15d7cd5a77800b001f9b742ab9)
diff --git a/tests/framework/base/windowmanager/src/android/server/wm/WindowInsetsAnimationControllerTests.java b/tests/framework/base/windowmanager/src/android/server/wm/WindowInsetsAnimationControllerTests.java
index 62c81e5..85f31b4 100644
--- a/tests/framework/base/windowmanager/src/android/server/wm/WindowInsetsAnimationControllerTests.java
+++ b/tests/framework/base/windowmanager/src/android/server/wm/WindowInsetsAnimationControllerTests.java
@@ -25,7 +25,11 @@
 
 import static androidx.test.internal.runner.junit4.statement.UiThreadStatement.runOnUiThread;
 
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.notNullValue;
@@ -65,8 +69,10 @@
 import org.junit.runners.Parameterized.Parameter;
 import org.junit.runners.Parameterized.Parameters;
 
+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
+import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
@@ -78,7 +84,6 @@
  *     atest CtsWindowManagerDeviceTestCases:WindowInsetsAnimationControllerTests
  */
 @Presubmit
-@FlakyTest(detail = "Promote once confirmed non-flaky")
 @RunWith(Parameterized.class)
 public class WindowInsetsAnimationControllerTests extends WindowManagerTestBase {
 
@@ -87,9 +92,11 @@
     ControlListener mListener;
     CancellationSignal mCancellationSignal = new CancellationSignal();
     Interpolator mInterpolator;
+    boolean mOnProgressCalled;
+    private ValueAnimator mAnimator;
 
     @Rule
-    public ErrorCollector mErrorCollector = new ErrorCollector();
+    public LimitedErrorCollector mErrorCollector = new LimitedErrorCollector();
 
     @Parameter(0)
     public int mType;
@@ -117,6 +124,7 @@
     @Test
     public void testControl_andCancel() throws Throwable {
         runOnUiThread(() -> {
+            setupAnimationListener();
             mRootView.getWindowInsetsController().controlWindowInsetsAnimation(mType, 0,
                     null, mCancellationSignal, mListener);
         });
@@ -134,6 +142,7 @@
     @Test
     public void testControl_andImmediatelyCancel() throws Throwable {
         runOnUiThread(() -> {
+            setupAnimationListener();
             mRootView.getWindowInsetsController().controlWindowInsetsAnimation(mType, 0,
                     null, mCancellationSignal, mListener);
             mCancellationSignal.cancel();
@@ -149,6 +158,7 @@
         setVisibilityAndWait(mType, false);
 
         runOnUiThread(() -> {
+            setupAnimationListener();
             mRootView.getWindowInsetsController().controlWindowInsetsAnimation(mType, 0,
                     null, null, mListener);
         });
@@ -168,6 +178,7 @@
         setVisibilityAndWait(mType, true);
 
         runOnUiThread(() -> {
+            setupAnimationListener();
             mRootView.getWindowInsetsController().controlWindowInsetsAnimation(mType, 0,
                     null, null, mListener);
         });
@@ -187,6 +198,7 @@
         setVisibilityAndWait(mType, false);
 
         runOnUiThread(() -> {
+            setupAnimationListener();
             mRootView.getWindowInsetsController().controlWindowInsetsAnimation(mType, 0,
                     null, null, mListener);
         });
@@ -204,6 +216,7 @@
         setVisibilityAndWait(mType, true);
 
         runOnUiThread(() -> {
+            setupAnimationListener();
             mRootView.getWindowInsetsController().controlWindowInsetsAnimation(mType, 0,
                     null, null, mListener);
         });
@@ -222,6 +235,7 @@
         setVisibilityAndWait(mType, false);
 
         runOnUiThread(() -> {
+            setupAnimationListener();
             mRootView.getWindowInsetsController().controlWindowInsetsAnimation(mType, 0,
                     mInterpolator, null, mListener);
         });
@@ -240,6 +254,7 @@
         setVisibilityAndWait(mType, true);
 
         runOnUiThread(() -> {
+            setupAnimationListener();
             mRootView.getWindowInsetsController().controlWindowInsetsAnimation(mType, 0,
                     mInterpolator, null, mListener);
         });
@@ -252,24 +267,79 @@
         mListener.assertWasNotCalled(CANCELLED);
     }
 
-    public void runTransition(boolean show) throws Throwable {
-        runOnUiThread(() -> {
-            WindowInsets initialInsets = mActivity.mLastWindowInsets;
+    private void setupAnimationListener() {
+        WindowInsets initialInsets = mActivity.mLastWindowInsets;
+        mRootView.setWindowInsetsAnimationCallback(new VerifyingCallback(
+                new Callback(Callback.DISPATCH_MODE_STOP) {
+            @Override
+            public void onPrepare(@NonNull WindowInsetsAnimation animation) {
+                mErrorCollector.checkThat("onPrepare",
+                        mActivity.mLastWindowInsets.getInsets(mType),
+                        equalTo(initialInsets.getInsets(mType)));
+            }
 
-            final ValueAnimator animator = ValueAnimator.ofObject(
+            @NonNull
+            @Override
+            public WindowInsetsAnimation.Bounds onStart(
+                    @NonNull WindowInsetsAnimation animation,
+                    @NonNull WindowInsetsAnimation.Bounds bounds) {
+                mErrorCollector.checkThat("onStart",
+                        mActivity.mLastWindowInsets, not(equalTo(initialInsets)));
+                mErrorCollector.checkThat("onStart",
+                        animation.getInterpolator(), sameInstance(mInterpolator));
+                return bounds;
+            }
+
+            @NonNull
+            @Override
+            public WindowInsets onProgress(@NonNull WindowInsets insets,
+                    @NonNull List<WindowInsetsAnimation> runningAnimations) {
+                mOnProgressCalled = true;
+                if (mAnimator != null) {
+                    float fraction = runningAnimations.get(0).getFraction();
+                    mErrorCollector.checkThat(
+                            String.format(Locale.US, "onProgress(%.2f)", fraction),
+                            insets.getInsets(mType), equalTo(mAnimator.getAnimatedValue()));
+                    mErrorCollector.checkThat("onProgress",
+                            fraction, equalTo(mAnimator.getAnimatedFraction()));
+
+                    Interpolator interpolator =
+                            mInterpolator != null ? mInterpolator : new LinearInterpolator();
+                    mErrorCollector.checkThat("onProgress",
+                            runningAnimations.get(0).getInterpolatedFraction(),
+                            equalTo(interpolator.getInterpolation(
+                                    mAnimator.getAnimatedFraction())));
+                }
+                return insets;
+            }
+
+            @Override
+            public void onEnd(@NonNull WindowInsetsAnimation animation) {
+                mRootView.setOnApplyWindowInsetsListener(null);
+            }
+        }));
+    }
+
+    private void runTransition(boolean show) throws Throwable {
+        runOnUiThread(() -> {
+            mAnimator = ValueAnimator.ofObject(
                     sInsetsEvaluator,
                     show ? mListener.mController.getHiddenStateInsets()
                             : mListener.mController.getShownStateInsets(),
                     show ? mListener.mController.getShownStateInsets()
                             : mListener.mController.getHiddenStateInsets()
             );
-            animator.setDuration(1000);
-            animator.addUpdateListener((animator1) -> {
-                Insets insets = (Insets) animator.getAnimatedValue();
+            mAnimator.setDuration(1000);
+            mAnimator.addUpdateListener((animator1) -> {
+                Insets insets = (Insets) mAnimator.getAnimatedValue();
+                mOnProgressCalled = false;
                 mListener.mController.setInsetsAndAlpha(insets, 1.0f,
-                        animator.getAnimatedFraction());
+                        mAnimator.getAnimatedFraction());
+                mErrorCollector.checkThat(
+                        "setInsetsAndAlpha() must synchronously call onProgress() but didn't",
+                        mOnProgressCalled, is(true));
             });
-            animator.addListener(new AnimatorListenerAdapter() {
+            mAnimator.addListener(new AnimatorListenerAdapter() {
                 @Override
                 public void onAnimationEnd(Animator animation) {
                     if (!mListener.mController.isCancelled()) {
@@ -277,59 +347,12 @@
                     }
                 }
             });
-            mRootView.setWindowInsetsAnimationCallback(new Callback(Callback.DISPATCH_MODE_STOP) {
-                @Override
-                public void onPrepare(@NonNull WindowInsetsAnimation animation) {
-                    mErrorCollector.checkThat("onPrepare",
-                            mActivity.mLastWindowInsets, equalTo(initialInsets));
-                    mErrorCollector.checkThat("onPrepare",
-                            mActivity.mLastWindowInsets.getInsets(mType),
-                            equalTo(mListener.mController.getHiddenStateInsets()));
-                }
 
-                @NonNull
-                @Override
-                public WindowInsetsAnimation.Bounds onStart(
-                        @NonNull WindowInsetsAnimation animation,
-                        @NonNull WindowInsetsAnimation.Bounds bounds) {
-                    mErrorCollector.checkThat("onStart",
-                            mActivity.mLastWindowInsets.getInsets(mType),
-                            equalTo(mListener.mController.getShownStateInsets()));
-                    mErrorCollector.checkThat("onStart",
-                            animation.getInterpolator(), sameInstance(mInterpolator));
-                    return bounds;
-                }
-
-                @NonNull
-                @Override
-                public WindowInsets onProgress(@NonNull WindowInsets insets,
-                        @NonNull List<WindowInsetsAnimation> runningAnimations) {
-                    float fraction = runningAnimations.get(0).getFraction();
-                    mErrorCollector.checkThat(
-                            String.format(Locale.US, "onProgress(%.2f)", fraction),
-                            insets.getInsets(mType), equalTo(animator.getAnimatedValue()));
-                    mErrorCollector.checkThat("onProgress",
-                            fraction, equalTo(animator.getAnimatedFraction()));
-
-                    Interpolator interpolator =
-                            mInterpolator != null ? mInterpolator : new LinearInterpolator();
-                    mErrorCollector.checkThat("onProgress",
-                            runningAnimations.get(0).getInterpolatedFraction(),
-                            equalTo(interpolator.getInterpolation(animator.getAnimatedFraction())));
-                    return insets;
-                }
-
-                @Override
-                public void onEnd(@NonNull WindowInsetsAnimation animation) {
-                    mRootView.setOnApplyWindowInsetsListener(null);
-                }
-            });
-
-            animator.start();
+            mAnimator.start();
         });
     }
 
-    public void setVisibilityAndWait(int type, boolean visible) throws Throwable {
+    private void setVisibilityAndWait(int type, boolean visible) throws Throwable {
         runOnUiThread(() -> {
             if (visible) {
                 mRootView.getWindowInsetsController().show(type);
@@ -421,4 +444,74 @@
                     (int) (startValue.top + fraction * (endValue.top - startValue.top)),
                     (int) (startValue.right + fraction * (endValue.right - startValue.right)),
                     (int) (startValue.bottom + fraction * (endValue.bottom - startValue.bottom)));
+
+
+    private class VerifyingCallback extends Callback {
+        private final Callback mInner;
+        private final Set<WindowInsetsAnimation> mPreparedAnimations = new HashSet<>();
+        private final Set<WindowInsetsAnimation> mRunningAnimations = new HashSet<>();
+
+        public VerifyingCallback(Callback callback) {
+            super(callback.getDispatchMode());
+            mInner = callback;
+        }
+
+        @Override
+        public void onPrepare(@NonNull WindowInsetsAnimation animation) {
+            mErrorCollector.checkThat("onPrepare", mPreparedAnimations, not(hasItem(animation)));
+            mPreparedAnimations.add(animation);
+            mInner.onPrepare(animation);
+        }
+
+        @NonNull
+        @Override
+        public WindowInsetsAnimation.Bounds onStart(@NonNull WindowInsetsAnimation animation,
+                @NonNull WindowInsetsAnimation.Bounds bounds) {
+            mErrorCollector.checkThat("onStart: mPreparedAnimations",
+                    mPreparedAnimations, hasItem(animation));
+            mErrorCollector.checkThat("onStart: mRunningAnimations",
+                    mRunningAnimations, not(hasItem(animation)));
+            mRunningAnimations.add(animation);
+            mPreparedAnimations.remove(animation);
+            return mInner.onStart(animation, bounds);
+        }
+
+        @NonNull
+        @Override
+        public WindowInsets onProgress(@NonNull WindowInsets insets,
+                @NonNull List<WindowInsetsAnimation> runningAnimations) {
+            mErrorCollector.checkThat("onProgress", new HashSet<>(runningAnimations),
+                    is(equalTo(mRunningAnimations)));
+            return mInner.onProgress(insets, runningAnimations);
+        }
+
+        @Override
+        public void onEnd(@NonNull WindowInsetsAnimation animation) {
+            mErrorCollector.checkThat("onEnd: mRunningAnimations",
+                    mRunningAnimations, hasItem(animation));
+            mRunningAnimations.remove(animation);
+            mPreparedAnimations.remove(animation);
+            mInner.onEnd(animation);
+        }
+    }
+
+    public static final class LimitedErrorCollector extends ErrorCollector {
+        private static final int LIMIT = 1;
+        private int mCount = 0;
+
+        @Override
+        public void addError(Throwable error) {
+            if (mCount++ < LIMIT) {
+                super.addError(error);
+            }
+        }
+
+        @Override
+        protected void verify() throws Throwable {
+            if (mCount >= LIMIT) {
+                super.addError(new AssertionError(mCount + " errors skipped."));
+            }
+            super.verify();
+        }
+    }
 }