SetFlagsRule Annotations Update

Update AnnotationsRetriever
* Output is effectively Map<String, Boolean> which is clearer to test
* Retriever will assert if the annotations on a method conflict
* Retriever will assert if the annotations on a class conflict
* Retriever will assert if the annotations on a method conflict with the class
* Retriever no longer ignores the class's enabled requirements when the method defines its own enabled requirements (previously the class's disabled requirements were still applied in that case); and vice versa
note: at the time of authoring, there are no examples outside of AnnotationsRetrieverTest where these annotations are defined on both the class and method for the same test.

Update SetFlagsRule
* Will now set flags based on @EnableFlags / @DisableFlags annotations
* Supports FlagsParameterization, including skipping test parameterizations that are inconsistent with @EnableFlags / @DisableFlags annotations
* Fails if any inconsistencies are detected among FlagsParameterization / @RequiresFlagsEnabled / @RequiresFlagsDisabled / @EnableFlags / @DisableFlags
* Throws FlagSetException when calling enableFlags or disableFlags on any flag that is set by annotation or parameterization, or required by annotation.

Update CheckFlagsRule
* Fails if any inconsistencies are detected among @RequiresFlagsEnabled / @RequiresFlagsDisabled / @EnableFlags / @DisableFlags

Fixes: 309625167
Fixes: 309522666
Test: atest AnnotationsRetrieverTest
Flag: NA
Change-Id: I2e03868b614cf19f4b09b333d523944934080cfa
diff --git a/libraries/annotations/src/android/platform/test/annotations/DisableFlags.java b/libraries/annotations/src/android/platform/test/annotations/DisableFlags.java
new file mode 100644
index 0000000..08b50dd
--- /dev/null
+++ b/libraries/annotations/src/android/platform/test/annotations/DisableFlags.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.platform.test.annotations;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * Indicates that {@code SetFlagsRule} should disable all the given feature flags before running the
+ * annotated test or class.
+ *
+ * <p>This annotation works together with {@link EnableFlags} to define the value of the flag that
+ * needs to be set for the test to run. It is an error for either a method or class to declare that
+ * a flag is set to be both enabled and disabled.
+ *
+ * <p>If the value for a particular flag is defined (by either {@code EnableFlags} or {@code
+ * DisableFlags}) by both the class and test method, then the values must be consistent.
+ *
+ * <p>If the value of one flag is set by an annotation on the class, and the value of a different
+ * flag is set by an annotation on the method, then both values are applied.
+ *
+ * <p>With {@code SetFlagsRule}, the flag will be disabled within the test process for the duration
+ * of the test(s). When run with a {@code FlagsParameterization} that enables the flag, the test
+ * will be skipped with 'assumption failed'.
+ *
+ * <p>Both {@code SetFlagsRule} and {@code CheckFlagsRule} will fail the test if a particular flag
+ * is both set (with {@code EnableFlags} or {@code DisableFlags}) and required (with {@code
+ * RequiresFlagsEnabled} or {@code RequiresFlagsDisabled}).
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.METHOD, ElementType.TYPE})
+public @interface DisableFlags {
+    /**
+     * The list of the feature flags to be disabled. Each item is the full flag name with the format
+     * {package_name}.{flag_name}.
+     */
+    String[] value();
+}
diff --git a/libraries/annotations/src/android/platform/test/annotations/EnableFlags.java b/libraries/annotations/src/android/platform/test/annotations/EnableFlags.java
new file mode 100644
index 0000000..320eb19
--- /dev/null
+++ b/libraries/annotations/src/android/platform/test/annotations/EnableFlags.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.platform.test.annotations;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * Indicates that {@code SetFlagsRule} should enable all the given feature flags before running the
+ * annotated test or class.
+ *
+ * <p>This annotation works together with {@link DisableFlags} to define the value of the flag that
+ * needs to be set for the test to run. It is an error for either a method or class to declare that
+ * a flag is set to be both enabled and disabled.
+ *
+ * <p>If the value for a particular flag is defined (by either {@code EnableFlags} or {@code
+ * DisableFlags}) by both the class and test method, then the values must be consistent.
+ *
+ * <p>If the value of one flag is set by an annotation on the class, and the value of a different
+ * flag is set by an annotation on the method, then both values are applied.
+ *
+ * <p>With {@code SetFlagsRule}, the flag will be enabled within the test process for the duration
+ * of the test(s). When run with a {@code FlagsParameterization} that disables the flag, the test
+ * will be skipped with 'assumption failed'.
+ *
+ * <p>Both {@code SetFlagsRule} and {@code CheckFlagsRule} will fail the test if a particular flag
+ * is both set (with {@code EnableFlags} or {@code DisableFlags}) and required (with {@code
+ * RequiresFlagsEnabled} or {@code RequiresFlagsDisabled}).
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.METHOD, ElementType.TYPE})
+public @interface EnableFlags {
+    /**
+     * The list of the feature flags to be enabled. Each item is the full flag name with the format
+     * {package_name}.{flag_name}.
+     */
+    String[] value();
+}
diff --git a/libraries/annotations/src/android/platform/test/annotations/RequiresFlagsDisabled.java b/libraries/annotations/src/android/platform/test/annotations/RequiresFlagsDisabled.java
index b67c454..5d2b9d3 100644
--- a/libraries/annotations/src/android/platform/test/annotations/RequiresFlagsDisabled.java
+++ b/libraries/annotations/src/android/platform/test/annotations/RequiresFlagsDisabled.java
@@ -22,13 +22,26 @@
 import java.lang.annotation.Target;
 
 /**
- * Indicates that a specific test or class should be run on certain feature flag disabled.
+ * Indicates that a specific test or class should be run only if all of the given feature flags are
+ * disabled in the device's current state. Enforced by the {@code CheckFlagsRule}.
+ *
+ * <p>This annotation works together with {@link RequiresFlagsEnabled} to define the value that is
+ * required of the flag by the test for the test to run. It is an error for either a method or class
+ * to require that a particular flag be both enabled and disabled.
+ *
+ * <p>If the value of a particular flag is required (by either {@code RequiresFlagsEnabled} or
+ * {@code RequiresFlagsDisabled}) by both the class and test method, then the values must be
+ * consistent.
+ *
+ * <p>If the value of one flag is required by an annotation on the class, and the value of a
+ * different flag is required by an annotation on the method, then both requirements apply.
  *
  * <p>With {@code CheckFlagsRule}, test(s) will be skipped with 'assumption failed' when any of the
  * required flag on the target Android platform is enabled.
  *
- * <p>If {@code RequiresFlagsDisabled} is applied at both the class and test method, the test method
- * annotation takes precedence, and the class level {@code RequiresFlagsDisabled} is ignored.
+ * <p>Both {@code SetFlagsRule} and {@code CheckFlagsRule} will fail the test if a particular flag
+ * is both set (with {@code EnableFlags} or {@code DisableFlags}) and required (with {@code
+ * RequiresFlagsEnabled} or {@code RequiresFlagsDisabled}).
  */
 @Retention(RetentionPolicy.RUNTIME)
 @Target({ElementType.METHOD, ElementType.TYPE})
diff --git a/libraries/annotations/src/android/platform/test/annotations/RequiresFlagsEnabled.java b/libraries/annotations/src/android/platform/test/annotations/RequiresFlagsEnabled.java
index e25243b..5b5271d 100644
--- a/libraries/annotations/src/android/platform/test/annotations/RequiresFlagsEnabled.java
+++ b/libraries/annotations/src/android/platform/test/annotations/RequiresFlagsEnabled.java
@@ -22,13 +22,26 @@
 import java.lang.annotation.Target;
 
 /**
- * Indicates that a specific test or class should be run on certain feature flag enabled.
+ * Indicates that a specific test or class should be run only if all of the given feature flags are
+ * enabled in the device's current state. Enforced by the {@code CheckFlagsRule}.
+ *
+ * <p>This annotation works together with {@link RequiresFlagsDisabled} to define the value that is
+ * required of the flag by the test for the test to run. It is an error for either a method or class
+ * to require that a particular flag be both enabled and disabled.
+ *
+ * <p>If the value of a particular flag is required (by either {@code RequiresFlagsEnabled} or
+ * {@code RequiresFlagsDisabled}) by both the class and test method, then the values must be
+ * consistent.
+ *
+ * <p>If the value of one flag is required by an annotation on the class, and the value of a
+ * different flag is required by an annotation on the method, then both requirements apply.
  *
  * <p>With {@code CheckFlagsRule}, test(s) will be skipped with 'assumption failed' when any of the
  * required flag on the target Android platform is disabled.
  *
- * <p>If {@code RequiresFlagsEnabled} is applied at both the class and test method, the test method
- * annotation takes precedence, and the class level {@code RequiresFlagsEnabled} is ignored.
+ * <p>Both {@code SetFlagsRule} and {@code CheckFlagsRule} will fail the test if a particular flag
+ * is both set (with {@code EnableFlags} or {@code DisableFlags}) and required (with {@code
+ * RequiresFlagsEnabled} or {@code RequiresFlagsDisabled}).
  */
 @Retention(RetentionPolicy.RUNTIME)
 @Target({ElementType.METHOD, ElementType.TYPE})
diff --git a/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/AnnotationsRetriever.java b/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/AnnotationsRetriever.java
index a88fccc..84caa49 100644
--- a/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/AnnotationsRetriever.java
+++ b/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/AnnotationsRetriever.java
@@ -16,9 +16,16 @@
 
 package android.platform.test.flag.junit;
 
+import static org.junit.Assume.assumeFalse;
+import static org.junit.Assume.assumeTrue;
+
+import android.platform.test.annotations.DisableFlags;
+import android.platform.test.annotations.EnableFlags;
 import android.platform.test.annotations.RequiresFlagsDisabled;
 import android.platform.test.annotations.RequiresFlagsEnabled;
 
+import com.google.common.collect.Sets;
+
 import org.junit.Test;
 import org.junit.runner.Description;
 
@@ -29,12 +36,15 @@
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
+import java.util.Objects;
 import java.util.Queue;
 import java.util.Set;
 
-import javax.annotation.Nullable;
+import javax.annotation.Nonnull;
 
 /**
  * Retrieves feature flag related annotations from a given {@code Description}.
@@ -50,27 +60,99 @@
 
     /** Gets all feature flag related annotations. */
     public static FlagAnnotations getFlagAnnotations(Description description) {
-        return new FlagAnnotations(
-                getAnnotation(RequiresFlagsEnabled.class, description),
-                getAnnotation(RequiresFlagsDisabled.class, description));
-    }
+        final Map<String, Boolean> requiresFlagValues =
+                getMergedFlagValues(sRequiresFlagsEnabled, sRequiresFlagsDisabled, description);
+        final Map<String, Boolean> setsFlagValues =
+                getMergedFlagValues(sEnableFlags, sDisableFlags, description);
 
-    @Nullable
-    private static <T extends Annotation> T getAnnotation(
-            Class<T> annotationType, Description description) {
-        T annotation = getAnnotation(annotationType, description.getAnnotations());
-        if (annotation != null) {
-            return annotation;
+        // A flag may be either required or set, but not both; reject any overlap of the two maps
+        Set<String> inconsistentFlags =
+                Sets.intersection(requiresFlagValues.keySet(), setsFlagValues.keySet());
+        if (!inconsistentFlags.isEmpty()) {
+            throw new AssertionError(
+                    "The following flags are both required and set: " + inconsistentFlags);
         }
-        Class<?> testClass = description.getTestClass();
-        return testClass == null
-                ? null
-                : getAnnotation(annotationType, List.of(testClass.getAnnotations()));
+
+        return new FlagAnnotations(requiresFlagValues, setsFlagValues);
     }
 
-    @Nullable
-    private static <T extends Annotation> T getAnnotation(
-            Class<T> annotationType, Collection<Annotation> annotations) {
+    private static Map<String, Boolean> getMergedFlagValues(
+            FlagsAnnotation<? extends Annotation> enabledAnnotation,
+            FlagsAnnotation<? extends Annotation> disabledAnnotation,
+            Description description) {
+        final Map<String, Boolean> methodFlagValues =
+                getFlagValues(
+                        description.getMethodName(),
+                        enabledAnnotation,
+                        disabledAnnotation,
+                        description.getAnnotations());
+        Class<?> testClass = description.getTestClass(); // may be null
+        final Map<String, Boolean> classFlagValues =
+                testClass == null
+                        ? new HashMap<>() // mutable: receives putAll below
+                        : getFlagValues(
+                                testClass.getName(),
+                                enabledAnnotation,
+                                disabledAnnotation,
+                                List.of(testClass.getAnnotations()));
+        Sets.SetView<String> doublyDefinedFlags =
+                Sets.intersection(classFlagValues.keySet(), methodFlagValues.keySet());
+        if (!doublyDefinedFlags.isEmpty()) {
+            List<String> mismatchedFlags =
+                    doublyDefinedFlags.stream()
+                            .filter(
+                                    flag ->
+                                            !Objects.equals(
+                                                    classFlagValues.get(flag),
+                                                    methodFlagValues.get(flag)))
+                            .toList();
+            if (!mismatchedFlags.isEmpty()) {
+                throw new AssertionError(
+                        "The following flags are required by "
+                                + description.getMethodName()
+                                + " and "
+                                + description.getClassName()
+                                + " to be both enabled and disabled: "
+                                + mismatchedFlags);
+            }
+        }
+        // Shared flags agree (checked above), so overlaying method values yields the union
+        classFlagValues.putAll(methodFlagValues);
+        return classFlagValues;
+    }
+
+    private static Map<String, Boolean> getFlagValues(
+            String annotationTarget, // may be null (Description.getMethodName())
+            @Nonnull FlagsAnnotation<? extends Annotation> enabledAnnotation,
+            @Nonnull FlagsAnnotation<? extends Annotation> disabledAnnotation,
+            @Nonnull Collection<Annotation> annotations) {
+        Set<String> enabledFlags = getFlagsForAnnotation(enabledAnnotation, annotations);
+        Set<String> disabledFlags = getFlagsForAnnotation(disabledAnnotation, annotations);
+        if (enabledFlags.isEmpty() && disabledFlags.isEmpty()) {
+            return new HashMap<>(); // mutable: getMergedFlagValues may putAll into it
+        }
+        Set<String> inconsistentFlags = Sets.intersection(enabledFlags, disabledFlags);
+        if (!inconsistentFlags.isEmpty()) {
+            throw new AssertionError(
+                    "The following flags are required by "
+                            + annotationTarget
+                            + " to be both enabled and disabled: "
+                            + inconsistentFlags);
+        }
+        Map<String, Boolean> result = new HashMap<>(); // mutable for the same reason
+        for (String enabledFlag : enabledFlags) {
+            result.put(enabledFlag, true);
+        }
+        for (String disabledFlag : disabledFlags) {
+            result.put(disabledFlag, false);
+        }
+        return result;
+    }
+
+    @Nonnull
+    private static <T extends Annotation> Set<String> getFlagsForAnnotation(
+            FlagsAnnotation<T> flagsAnnotation, Collection<Annotation> annotations) {
+        Class<T> annotationType = flagsAnnotation.mAnnotationType;
         List<T> results = new ArrayList<>();
         Queue<Annotation> annotationQueue = new ArrayDeque<>();
         Set<Class<? extends Annotation>> visitedAnnotations = new HashSet<>();
@@ -93,22 +175,121 @@
                             "Annotation %s has been specified multiple time: %s",
                             annotationType, results));
         }
-        return results.isEmpty() ? null : results.get(0);
+        return results.isEmpty() ? Set.of() : flagsAnnotation.getFlagsSet(results.get(0));
     }
 
     /** Contains all feature flag related annotations. */
     public static class FlagAnnotations {
-        /** Annotation for the flags that requires to be enabled. */
-        @Nullable public final RequiresFlagsEnabled mRequiresFlagsEnabled;
 
-        /** Annotation for the flags that requires to be disabled. */
-        @Nullable public final RequiresFlagsDisabled mRequiresFlagsDisabled;
+        /** Each required flag name, mapped to the value the test requires of it. */
+        public final @Nonnull Map<String, Boolean> mRequiredFlagValues;
+
+        /** Each flag name to be set, mapped to the value it should be set to. */
+        public final @Nonnull Map<String, Boolean> mSetFlagValues;
+
         FlagAnnotations(
-                RequiresFlagsEnabled requiresFlagsEnabled,
-                RequiresFlagsDisabled requiresFlagsDisabled) {
-            mRequiresFlagsEnabled = requiresFlagsEnabled;
-            mRequiresFlagsDisabled = requiresFlagsDisabled;
+                @Nonnull Map<String, Boolean> requiredFlagValues,
+                @Nonnull Map<String, Boolean> setFlagValues) {
+            mRequiredFlagValues = requiredFlagValues;
+            mSetFlagValues = setFlagValues;
+        }
+
+        /**
+         * Check that all @RequiresFlagsEnabled and @RequiresFlagsDisabled annotations match the
+         * values from the provider, and if this is not true, throw {@link
+         * org.junit.AssumptionViolatedException}
+         *
+         * @param valueProvider supplies the current value of each required flag
+         */
+        public void assumeAllRequiredFlagsMatchProvider(IFlagsValueProvider valueProvider) {
+            for (Map.Entry<String, Boolean> required : mRequiredFlagValues.entrySet()) {
+                final String flag = required.getKey();
+                if (required.getValue()) {
+                    assumeTrue(
+                            String.format("Flag %s required to be enabled, but is disabled", flag),
+                            valueProvider.getBoolean(flag));
+                } else {
+                    assumeFalse(
+                            String.format("Flag %s required to be disabled, but is enabled", flag),
+                            valueProvider.getBoolean(flag));
+                }
+            }
+        }
+
+        /**
+         * Check that all @EnableFlags and @DisableFlags annotations match the values contained in
+         * the parameterization (if present), and if this is not true, throw {@link
+         * org.junit.AssumptionViolatedException}
+         *
+         * @param params the parameterization to evaluate against; if null, nothing is checked
+         */
+        public void assumeAllSetFlagsMatchParameterization(FlagsParameterization params) {
+            if (params == null) return;
+            for (Map.Entry<String, Boolean> toSet : mSetFlagValues.entrySet()) {
+                final String flag = toSet.getKey();
+                final Boolean paramValue = params.mOverrides.get(flag);
+                if (paramValue == null) continue; // flag is not part of the parameterization
+                if (toSet.getValue()) {
+                    assumeTrue(
+                            String.format(
+                                    "Flag %s is enabled by annotation but disabled by the current"
+                                            + " FlagsParameterization; skipping test",
+                                    flag),
+                            paramValue);
+                } else {
+                    assumeFalse(
+                            String.format(
+                                    "Flag %s is disabled by annotation but enabled by the current"
+                                            + " FlagsParameterization; skipping test",
+                                    flag),
+                            paramValue);
+                }
+            }
         }
     }
+
+    private abstract static class FlagsAnnotation<T extends Annotation> {
+        final Class<T> mAnnotationType; // annotation type to search for
+
+        FlagsAnnotation(Class<T> type) {
+            mAnnotationType = type;
+        }
+
+        protected abstract String[] getFlags(T annotation); // flag names in the annotation
+
+        @Nonnull
+        Set<String> getFlagsSet(T annotation) {
+            String[] flags = getFlags(annotation); // may repeat names; Set.of would throw
+            return flags == null ? Set.of() : Set.copyOf(List.of(flags));
+        }
+    }
+
+    private static final FlagsAnnotation<RequiresFlagsEnabled> sRequiresFlagsEnabled =
+            new FlagsAnnotation<>(RequiresFlagsEnabled.class) {
+                @Override
+                protected String[] getFlags(RequiresFlagsEnabled annotation) {
+                    return annotation.value();
+                }
+            };
+    private static final FlagsAnnotation<RequiresFlagsDisabled> sRequiresFlagsDisabled =
+            new FlagsAnnotation<>(RequiresFlagsDisabled.class) {
+                @Override
+                protected String[] getFlags(RequiresFlagsDisabled annotation) {
+                    return annotation.value();
+                }
+            };
+    private static final FlagsAnnotation<EnableFlags> sEnableFlags =
+            new FlagsAnnotation<>(EnableFlags.class) {
+                @Override
+                protected String[] getFlags(EnableFlags annotation) {
+                    return annotation.value();
+                }
+            };
+    private static final FlagsAnnotation<DisableFlags> sDisableFlags =
+            new FlagsAnnotation<>(DisableFlags.class) {
+                @Override
+                protected String[] getFlags(DisableFlags annotation) {
+                    return annotation.value();
+                }
+            };
 }
diff --git a/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/CheckFlagsRule.java b/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/CheckFlagsRule.java
index c061440..0e9adcb 100644
--- a/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/CheckFlagsRule.java
+++ b/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/CheckFlagsRule.java
@@ -16,11 +16,6 @@
 
 package android.platform.test.flag.junit;
 
-import static org.junit.Assume.assumeTrue;
-
-import android.platform.test.annotations.RequiresFlagsDisabled;
-import android.platform.test.annotations.RequiresFlagsEnabled;
-
 import org.junit.rules.TestRule;
 import org.junit.runner.Description;
 import org.junit.runners.model.Statement;
@@ -48,29 +43,9 @@
             public void evaluate() throws Throwable {
                 AnnotationsRetriever.FlagAnnotations flagAnnotations =
                         AnnotationsRetriever.getFlagAnnotations(description);
-                RequiresFlagsEnabled requiresFlagsEnabled = flagAnnotations.mRequiresFlagsEnabled;
-                RequiresFlagsDisabled requiresFlagsDisabled =
-                        flagAnnotations.mRequiresFlagsDisabled;
                 mFlagsValueProvider.setUp();
                 try {
-                    if (requiresFlagsEnabled != null) {
-                        for (String flag : requiresFlagsEnabled.value()) {
-                            assumeTrue(
-                                    String.format(
-                                            "Flag %s required to be enabled, but is disabled",
-                                            flag),
-                                    mFlagsValueProvider.getBoolean(flag));
-                        }
-                    }
-                    if (requiresFlagsDisabled != null) {
-                        for (String flag : requiresFlagsDisabled.value()) {
-                            assumeTrue(
-                                    String.format(
-                                            "Flag %s required to be disabled, but is enabled",
-                                            flag),
-                                    !mFlagsValueProvider.getBoolean(flag));
-                        }
-                    }
+                    flagAnnotations.assumeAllRequiredFlagsMatchProvider(mFlagsValueProvider);
                 } finally {
                     mFlagsValueProvider.tearDownBeforeTest();
                 }
diff --git a/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/FlagsParameterization.java b/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/FlagsParameterization.java
new file mode 100644
index 0000000..3bcb8c2
--- /dev/null
+++ b/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/FlagsParameterization.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.platform.test.flag.junit;
+
+import static java.util.Objects.requireNonNull;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+import javax.annotation.Nonnull;
+
+/** An immutable set of aconfig flag values that can be used for parameterized testing. */
+public final class FlagsParameterization {
+    public final Map<String, Boolean> mOverrides; // unmodifiable copy: flag name -> value
+
+    /** Creates a parameterization from an unmodifiable copy of {@code overrides}. */
+    public FlagsParameterization(Map<String, Boolean> overrides) {
+        mOverrides = Map.copyOf(overrides);
+    }
+
+    @Override
+    public String toString() {
+        if (mOverrides.isEmpty()) {
+            return "EMPTY";
+        }
+        StringBuilder sb = new StringBuilder();
+        for (Map.Entry<String, Boolean> entry : new TreeMap<>(mOverrides).entrySet()) {
+            if (sb.length() != 0) {
+                sb.append(',');
+            }
+            sb.append(entry.getKey()).append('=').append(entry.getValue());
+        }
+        return sb.toString();
+    }
+
+    /**
+     * Determines whether the dependency <code>alpha dependsOn beta</code> is met for the defined
+     * values.
+     *
+     * @param alpha a flag which must be defined in this object
+     * @param beta a flag which must be defined in this object
+     * @return true in all cases except when alpha is enabled but beta is disabled.
+     */
+    public boolean isDependencyMet(String alpha, String beta) {
+        boolean alphaEnabled = requireNonNull(mOverrides.get(alpha), alpha + " is not defined");
+        boolean betaEnabled = requireNonNull(mOverrides.get(beta), beta + " is not defined");
+        return betaEnabled || !alphaEnabled;
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        if (other == null) return false;
+        if (other == this) return true;
+        if (!(other instanceof FlagsParameterization)) return false;
+        return mOverrides.equals(((FlagsParameterization) other).mOverrides);
+    }
+
+    @Override
+    public int hashCode() {
+        return mOverrides.hashCode();
+    }
+
+    /**
+     * Produces a list containing every combination of boolean values for the given flags.
+     *
+     * @return a list of size 2^N for N provided flags.
+     */
+    @Nonnull
+    public static List<FlagsParameterization> allCombinationsOf(@Nonnull String... flagNames) {
+        List<Map<String, Boolean>> currentList = List.of(new HashMap<>());
+        for (String flagName : flagNames) {
+            List<Map<String, Boolean>> next = new ArrayList<>(currentList.size() * 2);
+            for (Map<String, Boolean> current : currentList) {
+                // copy the current map and add this flag as disabled
+                Map<String, Boolean> plusDisabled = new HashMap<>(current);
+                plusDisabled.put(flagName, false);
+                next.add(plusDisabled);
+                // reuse the current map for the enabled variant (copied above for disabled)
+                current.put(flagName, true);
+                next.add(current);
+            }
+            currentList = next;
+        }
+        List<FlagsParameterization> result = new ArrayList<>();
+        for (Map<String, Boolean> valuesMap : currentList) {
+            result.add(new FlagsParameterization(valuesMap));
+        }
+        return result;
+    }
+
+    /**
+     * Produces a list containing the flag parameterizations where each flag is turned on in the
+     * given sequence.
+     *
+     * <p><code>progressionOf("a", "b", "c")</code> produces the following parameterizations:
+     *
+     * <ul>
+     *   <li><code>{"a": false, "b": false, "c": false}</code>
+     *   <li><code>{"a": true, "b": false, "c": false}</code>
+     *   <li><code>{"a": true, "b": true, "c": false}</code>
+     *   <li><code>{"a": true, "b": true, "c": true}</code>
+     * </ul>
+     *
+     * @return a list of size N+1 for N provided flags.
+     */
+    @Nonnull
+    public static List<FlagsParameterization> progressionOf(@Nonnull String... flagNames) {
+        final List<FlagsParameterization> result = new ArrayList<>();
+        final Map<String, Boolean> currentMap = new HashMap<>();
+        for (String flagName : flagNames) {
+            currentMap.put(flagName, false);
+        }
+        result.add(new FlagsParameterization(currentMap));
+        for (String flagName : flagNames) {
+            currentMap.put(flagName, true);
+            result.add(new FlagsParameterization(currentMap)); // constructor copies the map
+        }
+        return result;
+    }
+}
diff --git a/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/SetFlagsRule.java b/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/SetFlagsRule.java
index e814b33..0dd4fff 100644
--- a/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/SetFlagsRule.java
+++ b/libraries/flag-helpers/junit/src_base/android/platform/test/flag/junit/SetFlagsRule.java
@@ -21,6 +21,7 @@
 import android.platform.test.flag.util.FlagSetException;
 
 import com.google.common.base.CaseFormat;
+import com.google.common.collect.Sets;
 
 import org.junit.rules.TestRule;
 import org.junit.runner.Description;
@@ -28,9 +29,13 @@
 
 import java.lang.reflect.Field;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
 
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 /** A {@link TestRule} that helps to set flag values in unit test. */
 public final class SetFlagsRule implements TestRule {
@@ -48,7 +53,12 @@
     // Store value for the scope of each test method
     private final Map<Class<?>, Map<Flag, Boolean>> mFlagsClassToFlagDefaultMap = new HashMap<>();
 
+    // Flags fixed by the parameterization or annotations; enableFlags/disableFlags reject them
+    private final Set<String> mLockedFlagNames = new HashSet<>();
+
     private boolean mIsInitWithDefault = false;
+    private FlagsParameterization mFlagsParameterization; // null unless one was provided
+    private boolean mIsRuleEvaluating = false; // true only while the wrapped test runs
 
     /**
      * Enable default value for flags
@@ -91,6 +101,20 @@
     }
 
     public SetFlagsRule(DefaultInitValueType defaultType) {
+        this(defaultType, null);
+    }
+
+    /**
+     * Creates a rule that initializes flags according to {@code defaultType} and, if
+     * {@code flagsParameterization} is non-null, applies its values before each test and locks
+     * those flags against {@link #enableFlags} and {@link #disableFlags}.
+     *
+     * @param defaultType selects the default-initialization mode, e.g. {@code DEVICE_DEFAULT}
+     * @param flagsParameterization flag values to apply to each test, or {@code null} for none
+     */
+    public SetFlagsRule(
+            DefaultInitValueType defaultType,
+            @Nullable FlagsParameterization flagsParameterization) {
         switch (defaultType) {
             case DEVICE_DEFAULT:
                 mIsInitWithDefault = true;
@@ -98,6 +114,34 @@
             default:
                 break;
         }
+        mFlagsParameterization = flagsParameterization;
+        if (flagsParameterization != null) {
+            mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet());
+        }
+    }
+
+    /**
+     * Sets the FlagsParameterization to be used during this test. Its flags become locked and
+     * can no longer be changed with {@link #enableFlags} or {@link #disableFlags}.
+     *
+     * <p>Call this before any flag is set through this rule and while the rule is not evaluating;
+     * it may only be called if no parameterization was already given here or to the constructor.
+     *
+     * @throws NullPointerException if {@code flagsParameterization} is null
+     * @throws AssertionError if a parameterization was already provided, the rule is currently
+     *     evaluating, or flags were already set through this rule
+     */
+    public void setFlagsParameterization(@Nonnull FlagsParameterization flagsParameterization) {
+        Objects.requireNonNull(flagsParameterization, "FlagsParameterization cannot be cleared");
+        if (mFlagsParameterization != null) {
+            throw new AssertionError("FlagsParameterization cannot be overridden");
+        }
+        if (mIsRuleEvaluating) {
+            throw new AssertionError("Cannot set FlagsParameterization once the rule is running");
+        }
+        ensureFlagsAreUnset();
+        mFlagsParameterization = flagsParameterization;
+        mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet());
     }
 
     /**
@@ -108,6 +145,13 @@
      */
     public void enableFlags(String... fullFlagNames) {
+        // Reject locked flags before changing anything, so that a locked flag later in the list
+        // cannot leave the flags before it already enabled when the exception is thrown.
+        for (String fullFlagName : fullFlagNames) {
+            if (mLockedFlagNames.contains(fullFlagName)) {
+                throw new FlagSetException(fullFlagName, "Not allowed to change locked flags");
+            }
+        }
         for (String fullFlagName : fullFlagNames) {
             setFlagValue(fullFlagName, true);
         }
     }
@@ -120,6 +160,13 @@
      */
     public void disableFlags(String... fullFlagNames) {
+        // Reject locked flags before changing anything, so that a locked flag later in the list
+        // cannot leave the flags before it already disabled when the exception is thrown.
+        for (String fullFlagName : fullFlagNames) {
+            if (mLockedFlagNames.contains(fullFlagName)) {
+                throw new FlagSetException(fullFlagName, "Not allowed to change locked flags");
+            }
+        }
         for (String fullFlagName : fullFlagNames) {
             setFlagValue(fullFlagName, false);
         }
     }
@@ -155,6 +198,17 @@
         return featureFlagsClass.cast(fakeFlagsImplInstance);
     }
 
+    /**
+     * Throws an {@link AssertionError} if {@code mFlagsClassToFakeFlagsImpl} is non-empty, which
+     * indicates that some flags have already been set.
+     */
+    private void ensureFlagsAreUnset() {
+        if (mFlagsClassToFakeFlagsImpl.isEmpty()) {
+            return;
+        }
+        throw new AssertionError("Some flags were set before the rule was initialized");
+    }
+
     @Override
     public Statement apply(Statement base, Description description) {
         return new Statement() {
@@ -162,10 +211,38 @@
             public void evaluate() throws Throwable {
                 Throwable throwable = null;
                 try {
+                    AnnotationsRetriever.FlagAnnotations flagAnnotations =
+                            AnnotationsRetriever.getFlagAnnotations(description);
+                    // Required flags may not also be parameterized, and a parameterization that
+                    // contradicts @EnableFlags/@DisableFlags causes this test to be skipped.
+                    assertAnnotationsMatchParameterization(flagAnnotations, mFlagsParameterization);
+                    flagAnnotations.assumeAllSetFlagsMatchParameterization(mFlagsParameterization);
+                    if (mFlagsParameterization != null) {
+                        ensureFlagsAreUnset();
+                        for (Map.Entry<String, Boolean> pair :
+                                mFlagsParameterization.mOverrides.entrySet()) {
+                            setFlagValue(pair.getKey(), pair.getValue());
+                        }
+                    }
+                    for (Map.Entry<String, Boolean> pair :
+                            flagAnnotations.mSetFlagValues.entrySet()) {
+                        setFlagValue(pair.getKey(), pair.getValue());
+                    }
+                    // These locks are per-test; the finally block below releases them.
+                    mLockedFlagNames.addAll(flagAnnotations.mRequiredFlagValues.keySet());
+                    mLockedFlagNames.addAll(flagAnnotations.mSetFlagValues.keySet());
+                    mIsRuleEvaluating = true;
                     base.evaluate();
                 } catch (Throwable t) {
                     throwable = t;
                 } finally {
+                    mIsRuleEvaluating = false;
+                    // Release this test's annotation locks so they cannot leak into a later
+                    // evaluation of this rule; only the parameterization's flags stay locked.
+                    mLockedFlagNames.clear();
+                    if (mFlagsParameterization != null) {
+                        mLockedFlagNames.addAll(mFlagsParameterization.mOverrides.keySet());
+                    }
                     try {
                         resetFlags();
                     } catch (Throwable t) {
@@ -180,6 +248,31 @@
         };
     }
 
+    /**
+     * Asserts that no flag given a required value by {@code @RequiresFlagsEnabled} or
+     * {@code @RequiresFlagsDisabled} is also overridden by {@code parameterization}. Does nothing
+     * when {@code parameterization} is null.
+     */
+    private static void assertAnnotationsMatchParameterization(
+            AnnotationsRetriever.FlagAnnotations flagAnnotations,
+            @Nullable FlagsParameterization parameterization) {
+        if (parameterization == null) {
+            return;
+        }
+        Set<String> conflictingFlags =
+                Sets.intersection(
+                        parameterization.mOverrides.keySet(),
+                        flagAnnotations.mRequiredFlagValues.keySet());
+        if (conflictingFlags.isEmpty()) {
+            return;
+        }
+        throw new AssertionError(
+                "The following flags have required values (per @RequiresFlagsEnabled or"
+                        + " @RequiresFlagsDisabled) but they are part of the"
+                        + " FlagParameterization: "
+                        + conflictingFlags);
+    }
+
     private void setFlagValue(String fullFlagName, boolean value) {
         if (!fullFlagName.contains(".")) {
             throw new FlagSetException(
diff --git a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/AnnotationTestRuleHelper.java b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/AnnotationTestRuleHelper.java
new file mode 100644
index 0000000..596f951
--- /dev/null
+++ b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/AnnotationTestRuleHelper.java
@@ -0,0 +1,238 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.platform.test.flag.junit;
+
+import android.platform.test.annotations.DisableFlags;
+import android.platform.test.annotations.EnableFlags;
+import android.platform.test.annotations.RequiresFlagsDisabled;
+import android.platform.test.annotations.RequiresFlagsEnabled;
+
+import com.google.auto.value.AutoAnnotation;
+
+import org.junit.Assert;
+import org.junit.AssumptionViolatedException;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runners.model.Statement;
+
+import java.lang.annotation.Annotation;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Test helper that builds a {@link Description} carrying flag annotations, wraps test code in a
+ * {@link TestRule}, and checks whether the resulting test passes, fails, or is skipped.
+ */
+class AnnotationTestRuleHelper {
+    private final TestRule mTestRule;
+    private final List<Annotation> mMethodAnnotations = new ArrayList<>();
+    private Class<?> mTestClass;
+    private TestCode mTestCode;
+
+    AnnotationTestRuleHelper(TestRule testRule) {
+        mTestRule = testRule;
+    }
+
+    /** Sets the test class whose class-level annotations apply to the example test. */
+    AnnotationTestRuleHelper setTestClass(Class<?> testClass) {
+        mTestClass = testClass;
+        return this;
+    }
+
+    /** Add {@code @RequiresFlagsEnabled} with the given arguments to the test method. */
+    AnnotationTestRuleHelper addRequiresFlagsEnabled(String... values) {
+        mMethodAnnotations.add(createRequiresFlagsEnabled(values));
+        return this;
+    }
+
+    /** Add {@code @RequiresFlagsDisabled} with the given arguments to the test method. */
+    AnnotationTestRuleHelper addRequiresFlagsDisabled(String... values) {
+        mMethodAnnotations.add(createRequiresFlagsDisabled(values));
+        return this;
+    }
+
+    /** Add {@code @EnableFlags} with the given arguments to the test method. */
+    AnnotationTestRuleHelper addEnableFlags(String... values) {
+        mMethodAnnotations.add(createEnableFlags(values));
+        return this;
+    }
+
+    /** Add {@code @DisableFlags} with the given arguments to the test method. */
+    AnnotationTestRuleHelper addDisableFlags(String... values) {
+        mMethodAnnotations.add(createDisableFlags(values));
+        return this;
+    }
+
+    /** Builds the {@link Description} of the example test method. */
+    Description buildDescription() {
+        Annotation[] methodAnnotations = mMethodAnnotations.toArray(new Annotation[0]);
+        if (mTestClass == null) {
+            return Description.createTestDescription("testClass", "testMethod", methodAnnotations);
+        } else {
+            return Description.createTestDescription(mTestClass, "testMethod", methodAnnotations);
+        }
+    }
+
+    /** Sets the body of the example test; if never set, the test body does nothing. */
+    AnnotationTestRuleHelper setTestCode(TestCode testCode) {
+        mTestCode = testCode;
+        return this;
+    }
+
+    /**
+     * Applies the rule to the example test. The test code and {@link Description} are captured
+     * when this is called, so later calls to the setters do not affect the returned test.
+     */
+    PreparedTest prepareTest() {
+        final TestCode testCode = mTestCode;
+        Statement testStatement =
+                new Statement() {
+                    @Override
+                    public void evaluate() throws Throwable {
+                        if (testCode != null) {
+                            testCode.evaluate();
+                        }
+                    }
+                };
+        return new PreparedTest(mTestRule.apply(testStatement, buildDescription()));
+    }
+
+    /** A rule-wrapped example test whose outcome can be checked. */
+    static class PreparedTest {
+        private final Statement mStatement;
+
+        private PreparedTest(Statement statement) {
+            mStatement = statement;
+        }
+
+        /** Runs the test and returns whatever it threw, or null if it completed normally. */
+        Throwable runAndReturnFailure() {
+            try {
+                mStatement.evaluate();
+                return null;
+            } catch (Throwable throwable) {
+                return throwable;
+            }
+        }
+
+        /** Asserts that running the test throws exactly an {@link AssertionError}. */
+        void assertFails() {
+            assertFailsWithTypeAndMessage(AssertionError.class, null);
+        }
+
+        void assertFailsWithMessage(String expectedMessage) {
+            assertFailsWithTypeAndMessage(AssertionError.class, expectedMessage);
+        }
+
+        void assertFailsWithType(Class<? extends Throwable> expectedError) {
+            assertFailsWithTypeAndMessage(expectedError, null);
+        }
+
+        /**
+         * Asserts that running the test throws exactly {@code expectedError} (not a subclass),
+         * with a message containing {@code expectedMessage} unless that is null.
+         */
+        void assertFailsWithTypeAndMessage(
+                Class<? extends Throwable> expectedError, String expectedMessage) {
+            Throwable failure = runAndReturnFailure();
+            Assert.assertNotNull("Test was expected to fail", failure);
+            if (failure.getClass() != expectedError) {
+                throw new AssertionError(
+                        "Wrong error type; expected " + expectedError + " but was: " + failure,
+                        failure);
+            }
+            if (expectedMessage != null) {
+                String failureMessage = failure.getMessage();
+                if (failureMessage == null || !failureMessage.contains(expectedMessage)) {
+                    throw new AssertionError(
+                            "Failure message should contain \""
+                                    + expectedMessage
+                                    + "\" but was: "
+                                    + failureMessage,
+                            failure);
+                }
+            }
+        }
+
+        void assertSkipped() {
+            assertSkippedWithMessage(null);
+        }
+
+        /**
+         * Asserts that running the test throws exactly {@link AssumptionViolatedException}, with
+         * a message containing {@code expectedMessage} unless that is null.
+         */
+        void assertSkippedWithMessage(String expectedMessage) {
+            Throwable failure = runAndReturnFailure();
+            Assert.assertNotNull("Test was expected to be skipped but it ran and passed", failure);
+            if (failure.getClass() != AssumptionViolatedException.class) {
+                throw new AssertionError(
+                        "Test was expected to be skipped but it ran and failed: " + failure,
+                        failure);
+            }
+            if (expectedMessage != null) {
+                String skippedMessage = failure.getMessage();
+                if (skippedMessage == null || !skippedMessage.contains(expectedMessage)) {
+                    throw new AssertionError(
+                            "Test skip message should contain \""
+                                    + expectedMessage
+                                    + "\" but was: "
+                                    + skippedMessage,
+                            failure);
+                }
+            }
+        }
+
+        void assertPasses() {
+            try {
+                mStatement.evaluate();
+            } catch (Throwable failure) {
+                throw new AssertionError(
+                        "Test was expected to pass, but failed with error: " + failure, failure);
+            }
+        }
+    }
+
+    @AutoAnnotation
+    private static RequiresFlagsEnabled createRequiresFlagsEnabled(String[] value) {
+        return new AutoAnnotation_AnnotationTestRuleHelper_createRequiresFlagsEnabled(value);
+    }
+
+    @AutoAnnotation
+    private static RequiresFlagsDisabled createRequiresFlagsDisabled(String[] value) {
+        return new AutoAnnotation_AnnotationTestRuleHelper_createRequiresFlagsDisabled(value);
+    }
+
+    @AutoAnnotation
+    private static EnableFlags createEnableFlags(String[] value) {
+        return new AutoAnnotation_AnnotationTestRuleHelper_createEnableFlags(value);
+    }
+
+    @AutoAnnotation
+    private static DisableFlags createDisableFlags(String[] value) {
+        return new AutoAnnotation_AnnotationTestRuleHelper_createDisableFlags(value);
+    }
+
+    /**
+     * A variant of JUnit's {@link Statement} class that is an interface, so that it can be
+     * implemented with a lambda.
+     */
+    @FunctionalInterface
+    public interface TestCode {
+        void evaluate() throws Throwable;
+    }
+}
diff --git a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/AnnotationsRetrieverTest.java b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/AnnotationsRetrieverTest.java
index 9c8eab9..88e6916 100644
--- a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/AnnotationsRetrieverTest.java
+++ b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/AnnotationsRetrieverTest.java
@@ -17,7 +17,6 @@
 package android.platform.test.flag.junit;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
 
 import android.platform.test.annotations.RequiresFlagsDisabled;
 import android.platform.test.annotations.RequiresFlagsEnabled;
@@ -34,6 +33,7 @@
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
 import java.lang.annotation.Target;
+import java.util.Map;
 
 @RunWith(JUnit4.class)
 public class AnnotationsRetrieverTest {
@@ -57,6 +57,11 @@
     @RequiresFlagsDisabled({"flag3", "flag4"})
     static class TestClassHasAllAnnotations {}
 
+    // Requires flag1 to be both enabled and disabled, which the retriever must reject.
+    @RequiresFlagsEnabled("flag1")
+    @RequiresFlagsDisabled("flag1")
+    static class TestClassHasConflictingAnnotations {}
+
     private final RequiresFlagsEnabled mRequiresFlagsEnabled =
             createRequiresFlagsEnabled(new String[]{"flag5"});
 
@@ -64,29 +68,38 @@
             createRequiresFlagsDisabled(new String[]{"flag6"});
 
     @Test
-    public void getFlagAnnotations_noAnnotation() {
+    public void noAnnotation() {
         AnnotationsRetriever.FlagAnnotations flagAnnotations =
                 getFlagAnnotations(TestClassHasNoAnnotation.class);
 
-        assertNull(flagAnnotations.mRequiresFlagsEnabled);
-        assertNull(flagAnnotations.mRequiresFlagsDisabled);
+        assertEquals(Map.of(), flagAnnotations.mRequiredFlagValues);
     }
 
     @Test
-    public void getFlagAnnotations_oneAnnotationFromMethod() {
+    public void oneAnnotationFromMethod() {
+        AnnotationsRetriever.FlagAnnotations enabledOnly =
+                getFlagAnnotations(TestClassHasNoAnnotation.class, mRequiresFlagsEnabled);
+        AnnotationsRetriever.FlagAnnotations disabledOnly =
+                getFlagAnnotations(TestClassHasNoAnnotation.class, mRequiresFlagsDisabled);
+
+        assertEquals(Map.of("flag5", true), enabledOnly.mRequiredFlagValues);
+        assertEquals(Map.of("flag6", false), disabledOnly.mRequiredFlagValues);
+    }
+
+    // Method-level requirements are added to the class-level ones rather than replacing them.
+    @Test
+    public void methodAnnotationsMergeWithClass() {
         AnnotationsRetriever.FlagAnnotations flagAnnotations1 =
                 getFlagAnnotations(TestClassHasRequiresFlagsEnabled.class, mRequiresFlagsEnabled);
         AnnotationsRetriever.FlagAnnotations flagAnnotations2 =
                 getFlagAnnotations(TestClassHasRequiresFlagsDisabled.class, mRequiresFlagsDisabled);
 
-        assertNull(flagAnnotations1.mRequiresFlagsDisabled);
         assertEquals(
-                flagAnnotations1.mRequiresFlagsEnabled,
-                createRequiresFlagsEnabled(new String[]{"flag5"}));
-        assertNull(flagAnnotations2.mRequiresFlagsEnabled);
+                Map.of("flag1", true, "flag2", true, "flag5", true),
+                flagAnnotations1.mRequiredFlagValues);
         assertEquals(
-                flagAnnotations2.mRequiresFlagsDisabled,
-                createRequiresFlagsDisabled(new String[]{"flag6"}));
+                Map.of("flag3", false, "flag4", false, "flag6", false),
+                flagAnnotations2.mRequiredFlagValues);
     }
 
     @Test
@@ -96,18 +108,24 @@
         AnnotationsRetriever.FlagAnnotations flagAnnotations2 =
                 getFlagAnnotations(TestClassHasRequiresFlagsDisabled.class);
 
-        assertNull(flagAnnotations1.mRequiresFlagsDisabled);
-        assertEquals(
-                flagAnnotations1.mRequiresFlagsEnabled,
-                createRequiresFlagsEnabled(new String[]{"flag1", "flag2"}));
-        assertNull(flagAnnotations2.mRequiresFlagsEnabled);
-        assertEquals(
-                flagAnnotations2.mRequiresFlagsDisabled,
-                createRequiresFlagsDisabled(new String[]{"flag3", "flag4"}));
+        assertEquals(Map.of("flag1", true, "flag2", true), flagAnnotations1.mRequiredFlagValues);
+        assertEquals(Map.of("flag3", false, "flag4", false), flagAnnotations2.mRequiredFlagValues);
     }
 
     @Test
-    public void getFlagAnnotations_twoAnnotationsFromMethod() {
+    public void bothAnnotationsFromMethod() {
+        AnnotationsRetriever.FlagAnnotations methodOnly =
+                getFlagAnnotations(
+                        TestClassHasNoAnnotation.class,
+                        mRequiresFlagsEnabled,
+                        mRequiresFlagsDisabled);
+
+        assertEquals(Map.of("flag5", true, "flag6", false), methodOnly.mRequiredFlagValues);
+    }
+
+    // Both method-level requirements are combined with both class-level ones.
+    @Test
+    public void bothAnnotationsFromMethodMergesWithClass() {
         AnnotationsRetriever.FlagAnnotations flagAnnotations =
                 getFlagAnnotations(
                         TestClassHasAllAnnotations.class,
@@ -115,51 +132,78 @@
                         mRequiresFlagsDisabled);
 
         assertEquals(
-                flagAnnotations.mRequiresFlagsEnabled,
-                createRequiresFlagsEnabled(new String[]{"flag5"}));
-        assertEquals(
-                flagAnnotations.mRequiresFlagsDisabled,
-                createRequiresFlagsDisabled(new String[]{"flag6"}));
+                Map.of(
+                        "flag1", true, "flag2", true, "flag3", false, "flag4", false, "flag5", true,
+                        "flag6", false),
+                flagAnnotations.mRequiredFlagValues);
     }
 
     @Test
-    public void getFlagAnnotations_twoAnnotationsFromClass() {
+    public void bothAnnotationsFromClass() {
         AnnotationsRetriever.FlagAnnotations flagAnnotations =
                 getFlagAnnotations(TestClassHasAllAnnotations.class);
 
         assertEquals(
-                flagAnnotations.mRequiresFlagsEnabled,
-                createRequiresFlagsEnabled(new String[]{"flag1", "flag2"}));
-        assertEquals(
-                flagAnnotations.mRequiresFlagsDisabled,
-                createRequiresFlagsDisabled(new String[]{"flag3", "flag4"}));
+                Map.of("flag1", true, "flag2", true, "flag3", false, "flag4", false),
+                flagAnnotations.mRequiredFlagValues);
     }
 
     @Test
-    public void getFlagAnnotations_twoAnnotationsFromMethodAndClass() {
+    public void bothAnnotationsFromClassAndOneFromMethod() {
         AnnotationsRetriever.FlagAnnotations flagAnnotations =
                 getFlagAnnotations(TestClassHasAllAnnotations.class, mRequiresFlagsEnabled);
 
         assertEquals(
-                flagAnnotations.mRequiresFlagsEnabled,
-                createRequiresFlagsEnabled(new String[]{"flag5"}));
-        assertEquals(
-                flagAnnotations.mRequiresFlagsDisabled,
-                createRequiresFlagsDisabled(new String[]{"flag3", "flag4"}));
+                Map.of("flag1", true, "flag2", true, "flag3", false, "flag4", false, "flag5", true),
+                flagAnnotations.mRequiredFlagValues);
+    }
+
+    @Test(expected = AssertionError.class)
+    public void conflictingClassAnnotationsThrows() {
+        getFlagAnnotations(TestClassHasConflictingAnnotations.class);
+    }
+
+    @Test(expected = AssertionError.class)
+    public void conflictingMethodAnnotationsThrows() {
+        getFlagAnnotations(
+                TestClassHasNoAnnotation.class,
+                createRequiresFlagsEnabled(new String[] {"flag1"}),
+                createRequiresFlagsDisabled(new String[] {"flag1"}));
+    }
+
+    // The method requires flag3 enabled and flag1 disabled; the class requires the opposite.
+    @Test(expected = AssertionError.class)
+    public void conflictingMethodAndClassAnnotationsThrows() {
+        getFlagAnnotations(
+                TestClassHasAllAnnotations.class,
+                createRequiresFlagsEnabled(new String[] {"flag3"}),
+                createRequiresFlagsDisabled(new String[] {"flag1"}));
     }
 
     @Test
-    public void getFlagAnnotations_recursively() {
+    public void getFlagAnnotationsRecursively() {
         AnnotationsRetriever.FlagAnnotations flagAnnotations =
                 getFlagAnnotations(
-                        TestClassHasAllAnnotations.class, createCompositeFlagRequirements());
+                        TestClassHasNoAnnotation.class, createCompositeFlagRequirements());
+
+        assertEquals(Map.of("flag1", true, "flag2", false), flagAnnotations.mRequiredFlagValues);
+    }
+
+    // The composite annotation requires flag2 disabled, but the class requires it enabled.
+    @Test(expected = AssertionError.class)
+    public void getFlagAnnotationsRecursivelyFailsOnOverride() {
+        getFlagAnnotations(TestClassHasAllAnnotations.class, createCompositeFlagRequirements());
+    }
+
+    @Test
+    public void getFlagAnnotationsRecursivelyMergesWithClass() {
+        AnnotationsRetriever.FlagAnnotations flagAnnotations =
+                getFlagAnnotations(
+                        TestClassHasRequiresFlagsDisabled.class, createCompositeFlagRequirements());
 
         assertEquals(
-                flagAnnotations.mRequiresFlagsEnabled,
-                createRequiresFlagsEnabled(new String[] {"flag1"}));
-        assertEquals(
-                flagAnnotations.mRequiresFlagsDisabled,
-                createRequiresFlagsDisabled(new String[] {"flag2"}));
+                Map.of("flag1", true, "flag2", false, "flag3", false, "flag4", false),
+                flagAnnotations.mRequiredFlagValues);
     }
 
     private AnnotationsRetriever.FlagAnnotations getFlagAnnotations(
diff --git a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/CheckFlagsRuleTest.java b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/CheckFlagsRuleTest.java
index 256fbf2..00d82b4 100644
--- a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/CheckFlagsRuleTest.java
+++ b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/CheckFlagsRuleTest.java
@@ -16,102 +16,244 @@
 
 package android.platform.test.flag.junit;
 
-import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.fail;
 
+import android.platform.test.annotations.DisableFlags;
+import android.platform.test.annotations.EnableFlags;
 import android.platform.test.annotations.RequiresFlagsDisabled;
 import android.platform.test.annotations.RequiresFlagsEnabled;
 import android.platform.test.flag.util.FlagReadException;
 
-import org.junit.BeforeClass;
-import org.junit.FixMethodOrder;
-import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
-import org.junit.runners.MethodSorters;
 
-import java.lang.reflect.Method;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.stream.Collectors;
-
-/**
- * Tests for {@code CheckFlagsRule}. Test MUST be ended with '_execute' if it is not going to be
- * skipped.
- */
+/** Tests for {@code CheckFlagsRule}. */
 @RunWith(JUnit4.class)
-@RequiresFlagsEnabled("flag1")
-@FixMethodOrder(MethodSorters.NAME_ASCENDING)
-public class CheckFlagsRuleTest {
-    private static final List<String> EXPECTED_TESTS_EXECUTED =
-            Arrays.stream(CheckFlagsRuleTest.class.getDeclaredMethods())
-                    .map(Method::getName)
-                    .filter(methodName -> methodName.endsWith("_execute"))
-                    .sorted()
-                    .collect(Collectors.toList());
+public final class CheckFlagsRuleTest {
+    private final CheckFlagsRule mRule =
+            new CheckFlagsRule(
+                    new IFlagsValueProvider() {
+                        @Override
+                        public boolean getBoolean(String flag) throws FlagReadException {
+                            switch (flag) {
+                                case "flag0":
+                                    return false;
+                                case "flag1":
+                                case "flag2":
+                                    return true;
+                                default:
+                                    throw new FlagReadException(flag, "flag not defined");
+                            }
+                        }
+                    });
 
-    private static final List<String> ACTUAL_TESTS_EXECUTED = new ArrayList<>();
-
-    private final IFlagsValueProvider mFlagsValueProvider =
-            new IFlagsValueProvider() {
-                @Override
-                public boolean getBoolean(String flag) throws FlagReadException {
-                    switch (flag) {
-                        case "flag1":
-                            return true;
-                        case "flag2":
-                            return false;
-                        default:
-                            throw new FlagReadException("flag3", "expected boolean but a String");
-                    }
-                }
-            };
-
-    @Rule public final CheckFlagsRule mCheckFlagsRule = new CheckFlagsRule(mFlagsValueProvider);
-
-    @BeforeClass
-    public static void clearList() {
-        ACTUAL_TESTS_EXECUTED.clear();
+    @Test
+    public void emptyTestWithoutAnnotationsPasses() {
+        new AnnotationTestRuleHelper(mRule).prepareTest().assertPasses();
     }
 
     @Test
-    public void noAnnotation_execute() {
-        ACTUAL_TESTS_EXECUTED.add("noAnnotation_execute");
+    public void emptyFailingWithoutAnnotationsFails() {
+        new AnnotationTestRuleHelper(mRule)
+                .setTestCode(
+                        () -> {
+                            fail();
+                        })
+                .prepareTest()
+                .assertFails();
     }
 
     @Test
-    @RequiresFlagsEnabled("flag2")
-    public void methodAnnotationOverrideClassAnnotation_skip() {
-        // Should be skipped.
-        fail();
+    public void usingEnableFlagsOnClassForDifferentFlagPasses() {
+        @EnableFlags("flag0")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .addRequiresFlagsEnabled("flag1")
+                .prepareTest()
+                .assertPasses();
     }
 
     @Test
-    @RequiresFlagsEnabled("flag1")
-    @RequiresFlagsDisabled("flag2")
-    public void requireBothEnabledAndDisabledFlags_execute() {
-        ACTUAL_TESTS_EXECUTED.add("requireBothEnabledAndDisabledFlags_execute");
+    public void usingDisableFlagsOnClassForDifferentFlagPasses() {
+        @DisableFlags("flag0")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .addRequiresFlagsEnabled("flag1")
+                .prepareTest()
+                .assertPasses();
     }
 
     @Test
-    @RequiresFlagsDisabled("flag1")
-    public void requiredEnabledFlagDisabled_skip() {
-        // Should be skipped.
-        fail();
+    public void usingEnableFlagsOnMethodForDifferentFlagPasses() {
+        new AnnotationTestRuleHelper(mRule)
+                .addEnableFlags("flag0")
+                .addRequiresFlagsEnabled("flag1")
+                .prepareTest()
+                .assertPasses();
     }
 
     @Test
-    @RequiresFlagsEnabled({"flag1", "flag2"})
-    public void requiredDisabledFlagEnabled_skip() {
-        // Should be skipped.
-        fail();
+    public void usingDisableFlagsOnMethodForDifferentFlagPasses() {
+        new AnnotationTestRuleHelper(mRule)
+                .addDisableFlags("flag0")
+                .addRequiresFlagsEnabled("flag1")
+                .prepareTest()
+                .assertPasses();
     }
 
     @Test
-    public void zLastTest_checkExecutedTests() { // Starts the method name with 'z' so that
-        // it will be the last test to get executed.
-        assertArrayEquals(EXPECTED_TESTS_EXECUTED.toArray(), ACTUAL_TESTS_EXECUTED.toArray());
+    public void usingEnableFlagsOnClassForSameFlagFails() {
+        @EnableFlags("flag1")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .addRequiresFlagsEnabled("flag1")
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void usingDisableFlagsOnClassForSameFlagFails() {
+        @DisableFlags("flag1")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .addRequiresFlagsEnabled("flag1")
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void usingEnableFlagsOnMethodForSameFlagFails() {
+        new AnnotationTestRuleHelper(mRule)
+                .addEnableFlags("flag1")
+                .addRequiresFlagsEnabled("flag1")
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void usingDisableFlagsOnMethodForSameFlagFails() {
+        new AnnotationTestRuleHelper(mRule)
+                .addDisableFlags("flag1")
+                .addRequiresFlagsEnabled("flag1")
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void usingRequiresFlagsEnabledFlag1OnClassPasses() {
+        @RequiresFlagsEnabled("flag1")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void usingRequiresFlagsDisabledFlag1OnClassSkipped() {
+        @RequiresFlagsDisabled("flag1")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .prepareTest()
+                .assertSkipped();
+    }
+
+    @Test
+    public void usingRequiresFlagsEnabledFlag1OnMethodPasses() {
+        new AnnotationTestRuleHelper(mRule)
+                .addRequiresFlagsEnabled("flag1")
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void usingRequiresFlagsDisabledFlag1OnMethodSkipped() {
+        new AnnotationTestRuleHelper(mRule)
+                .addRequiresFlagsDisabled("flag1")
+                .prepareTest()
+                .assertSkipped();
+    }
+
+    @Test
+    public void requiringNonexistentFlagEnabledFails() {
+        new AnnotationTestRuleHelper(mRule)
+                .addRequiresFlagsEnabled("nonexistent")
+                .prepareTest()
+                .assertFailsWithTypeAndMessage(FlagReadException.class, "nonexistent");
+    }
+
+    @Test
+    public void requiringNonexistentFlagDisabledFails() {
+        @RequiresFlagsDisabled("nonexistent")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .prepareTest()
+                .assertFailsWithTypeAndMessage(FlagReadException.class, "nonexistent");
+    }
+
+    @Test
+    public void conflictingClassAnnotationsFails() {
+        @RequiresFlagsEnabled({"flag1", "flag0"})
+        @RequiresFlagsDisabled("flag1")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void conflictingMethodAnnotationsFails() {
+        new AnnotationTestRuleHelper(mRule)
+                .addRequiresFlagsEnabled("flag1", "flag0")
+                .addRequiresFlagsDisabled("flag1")
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void conflictingAnnotationsAcrossMethodAndClassFails() {
+        @RequiresFlagsDisabled("flag1")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .addRequiresFlagsEnabled("flag1", "flag0")
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void canDuplicateFlagAcrossMethodAndClassAnnotations() {
+        @RequiresFlagsEnabled("flag1")
+        class SomeClass {}
+        new AnnotationTestRuleHelper(mRule)
+                .setTestClass(SomeClass.class)
+                .addRequiresFlagsEnabled("flag1", "flag2")
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void requiringAllFlagsEnabledSkipped() {
+        new AnnotationTestRuleHelper(mRule)
+                .addRequiresFlagsEnabled("flag0", "flag1", "flag2")
+                .prepareTest()
+                .assertSkipped();
+    }
+
+    @Test
+    public void mixedRequirementsWithOneMissedSkipped() {
+        new AnnotationTestRuleHelper(mRule)
+                .addRequiresFlagsEnabled("flag1")
+                .addRequiresFlagsDisabled("flag0", "flag2")
+                .prepareTest()
+                .assertSkipped();
     }
 }
diff --git a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/FlagsParameterizationTest.java b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/FlagsParameterizationTest.java
new file mode 100644
index 0000000..589c652
--- /dev/null
+++ b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/FlagsParameterizationTest.java
@@ -0,0 +1,163 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.platform.test.flag.junit;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** Unit tests for {@code FlagsParameterization}. */
+@RunWith(JUnit4.class)
+public final class FlagsParameterizationTest {
+    @Test
+    public void toStringIsOrderIndependent() {
+        FlagsParameterization params1 = new FlagsParameterization(Map.of("a", true, "b", false));
+        FlagsParameterization params2 = new FlagsParameterization(Map.of("b", false, "a", true));
+        assertEquals(params1.toString(), params2.toString());
+        assertEquals(params1, params2);
+        // Equal objects must also have equal hash codes (java.lang.Object#hashCode contract).
+        assertEquals(params1.hashCode(), params2.hashCode());
+    }
+
+    @Test
+    public void toStringIsDifferent() {
+        FlagsParameterization params1 = new FlagsParameterization(Map.of("a", true, "b", false));
+        FlagsParameterization params2 = new FlagsParameterization(Map.of("a", true, "b", true));
+        assertNotEquals(params1.toString(), params2.toString());
+        assertNotEquals(params1, params2);
+    }
+
+    @Test
+    public void toStringIsNotEmpty() {
+        FlagsParameterization params = new FlagsParameterization(Map.of());
+        assertNotEquals("", params.toString());
+    }
+
+    @Test
+    public void overridesCannotBeChanged() {
+        Map<String, Boolean> original = new HashMap<>();
+        original.put("foo", true);
+        FlagsParameterization params = new FlagsParameterization(original);
+        // Scope the expectation to the mutation itself, so an exception thrown while
+        // constructing the parameterization cannot make this test pass by accident.
+        assertThrows(
+                UnsupportedOperationException.class, () -> params.mOverrides.put("foo", false));
+    }
+
+    @Test
+    public void overridesIsACopy() {
+        Map<String, Boolean> original = new HashMap<>();
+        original.put("foo", true);
+        FlagsParameterization params = new FlagsParameterization(original);
+        original.put("foo", false);
+        assertTrue(params.mOverrides.get("foo"));
+    }
+
+    @Test
+    public void allCombinationsWith0Flags() {
+        List<FlagsParameterization> actual = FlagsParameterization.allCombinationsOf();
+        List<FlagsParameterization> expected = List.of(new FlagsParameterization(Map.of()));
+        assertEquals(expected, actual);
+    }
+
+    @Test
+    public void allCombinationsWith1Flags() {
+        List<FlagsParameterization> actual = FlagsParameterization.allCombinationsOf("a");
+        List<FlagsParameterization> expected =
+                List.of(
+                        new FlagsParameterization(Map.of("a", false)),
+                        new FlagsParameterization(Map.of("a", true)));
+        assertEquals(expected, actual);
+    }
+
+    @Test
+    public void allCombinationsWith2Flags() {
+        List<FlagsParameterization> actual = FlagsParameterization.allCombinationsOf("a", "b");
+        List<FlagsParameterization> expected =
+                List.of(
+                        new FlagsParameterization(Map.of("a", false, "b", false)),
+                        new FlagsParameterization(Map.of("a", false, "b", true)),
+                        new FlagsParameterization(Map.of("a", true, "b", false)),
+                        new FlagsParameterization(Map.of("a", true, "b", true)));
+        assertEquals(expected, actual);
+    }
+
+    @Test
+    public void allCombinationsWith3Flags() {
+        List<FlagsParameterization> actual = FlagsParameterization.allCombinationsOf("a", "b", "c");
+        List<FlagsParameterization> expected =
+                List.of(
+                        new FlagsParameterization(Map.of("a", false, "b", false, "c", false)),
+                        new FlagsParameterization(Map.of("a", false, "b", false, "c", true)),
+                        new FlagsParameterization(Map.of("a", false, "b", true, "c", false)),
+                        new FlagsParameterization(Map.of("a", false, "b", true, "c", true)),
+                        new FlagsParameterization(Map.of("a", true, "b", false, "c", false)),
+                        new FlagsParameterization(Map.of("a", true, "b", false, "c", true)),
+                        new FlagsParameterization(Map.of("a", true, "b", true, "c", false)),
+                        new FlagsParameterization(Map.of("a", true, "b", true, "c", true)));
+        assertEquals(expected, actual);
+    }
+
+    @Test
+    public void progressionWith0Flags() {
+        List<FlagsParameterization> actual = FlagsParameterization.progressionOf();
+        List<FlagsParameterization> expected = List.of(new FlagsParameterization(Map.of()));
+        assertEquals(expected, actual);
+    }
+
+    @Test
+    public void progressionWith1Flags() {
+        List<FlagsParameterization> actual = FlagsParameterization.progressionOf("a");
+        List<FlagsParameterization> expected =
+                List.of(
+                        new FlagsParameterization(Map.of("a", false)),
+                        new FlagsParameterization(Map.of("a", true)));
+        assertEquals(expected, actual);
+    }
+
+    @Test
+    public void progressionWith2Flags() {
+        List<FlagsParameterization> actual = FlagsParameterization.progressionOf("a", "b");
+        List<FlagsParameterization> expected =
+                List.of(
+                        new FlagsParameterization(Map.of("a", false, "b", false)),
+                        new FlagsParameterization(Map.of("a", true, "b", false)),
+                        new FlagsParameterization(Map.of("a", true, "b", true)));
+        assertEquals(expected, actual);
+    }
+
+    @Test
+    public void progressionWith3Flags() {
+        List<FlagsParameterization> actual = FlagsParameterization.progressionOf("a", "b", "c");
+        List<FlagsParameterization> expected =
+                List.of(
+                        new FlagsParameterization(Map.of("a", false, "b", false, "c", false)),
+                        new FlagsParameterization(Map.of("a", true, "b", false, "c", false)),
+                        new FlagsParameterization(Map.of("a", true, "b", true, "c", false)),
+                        new FlagsParameterization(Map.of("a", true, "b", true, "c", true)));
+        assertEquals(expected, actual);
+    }
+}
diff --git a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/SetFlagsRuleAnnotationsTest.java b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/SetFlagsRuleAnnotationsTest.java
new file mode 100644
index 0000000..bc91058
--- /dev/null
+++ b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/SetFlagsRuleAnnotationsTest.java
@@ -0,0 +1,530 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.platform.test.flag.junit;
+
+import static android.platform.test.flag.junit.SetFlagsRule.DefaultInitValueType.DEVICE_DEFAULT;
+import static android.platform.test.flag.junit.SetFlagsRule.DefaultInitValueType.NULL_DEFAULT;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import android.platform.test.annotations.DisableFlags;
+import android.platform.test.annotations.EnableFlags;
+import android.platform.test.annotations.RequiresFlagsDisabled;
+import android.platform.test.annotations.RequiresFlagsEnabled;
+import android.platform.test.flag.util.FlagSetException;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.IOException;
+import java.util.Map;
+
+/** Unit tests for {@code SetFlagsRule} being used with annotations. */
+@RunWith(JUnit4.class)
+public final class SetFlagsRuleAnnotationsTest {
+
+    @Test
+    public void emptyTestWithoutAnnotationsPasses() {
+        new AnnotationTestRuleHelper(new SetFlagsRule()).prepareTest().assertPasses();
+    }
+
+    @Test
+    public void throwingTestWithoutAnnotationsThrows() {
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .setTestCode(
+                        () -> {
+                            throw new IOException("a test about foo");
+                        })
+                .prepareTest()
+                .assertFailsWithTypeAndMessage(IOException.class, "foo");
+    }
+
+    @Test
+    public void unsetFlagsWithNullDefaultPassIfNoFlagsInPackageAreSet() {
+        new AnnotationTestRuleHelper(new SetFlagsRule(NULL_DEFAULT))
+                .setTestCode(
+                        () -> {
+                            assertFalse(Flags.flagName3());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void setFlagsWithNullDefaultCanBeRead() {
+        new AnnotationTestRuleHelper(new SetFlagsRule(NULL_DEFAULT))
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void unsetFlagsWithNullDefaultFailToRead() {
+        new AnnotationTestRuleHelper(new SetFlagsRule(NULL_DEFAULT))
+                .addEnableFlags(Flags.FLAG_FLAG_NAME4)
+                .setTestCode(
+                        () -> {
+                            Flags.flagName3();
+                        })
+                .prepareTest()
+                .assertFailsWithType(NullPointerException.class);
+    }
+
+    @Test
+    public void classAnnotationsAreHandled() {
+        @EnableFlags(Flags.FLAG_FLAG_NAME3)
+        class SomeClass {}
+        new AnnotationTestRuleHelper(new SetFlagsRule(NULL_DEFAULT))
+                .setTestClass(SomeClass.class)
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void enablingFlagEnabledByAnnotationFails() {
+        SetFlagsRule setFlagsRule = new SetFlagsRule();
+        new AnnotationTestRuleHelper(setFlagsRule)
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            setFlagsRule.enableFlags(Flags.FLAG_FLAG_NAME3);
+                        })
+                .prepareTest()
+                .assertFailsWithType(FlagSetException.class);
+    }
+
+    @Test
+    public void disablingFlagEnabledByAnnotationFails() {
+        SetFlagsRule setFlagsRule = new SetFlagsRule();
+        new AnnotationTestRuleHelper(setFlagsRule)
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            setFlagsRule.disableFlags(Flags.FLAG_FLAG_NAME3);
+                        })
+                .prepareTest()
+                .assertFailsWithType(FlagSetException.class);
+    }
+
+    @Test
+    public void enablingFlagDisabledByAnnotationFails() {
+        SetFlagsRule setFlagsRule = new SetFlagsRule();
+        new AnnotationTestRuleHelper(setFlagsRule)
+                .addDisableFlags(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            setFlagsRule.enableFlags(Flags.FLAG_FLAG_NAME3);
+                        })
+                .prepareTest()
+                .assertFailsWithType(FlagSetException.class);
+    }
+
+    @Test
+    public void disablingFlagDisabledByAnnotationFails() {
+        SetFlagsRule setFlagsRule = new SetFlagsRule();
+        new AnnotationTestRuleHelper(setFlagsRule)
+                .addDisableFlags(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            setFlagsRule.disableFlags(Flags.FLAG_FLAG_NAME3);
+                        })
+                .prepareTest()
+                .assertFailsWithType(FlagSetException.class);
+    }
+
+    @Test
+    public void enablingDifferentFlagThanAnnotationPasses() {
+        SetFlagsRule setFlagsRule = new SetFlagsRule();
+        new AnnotationTestRuleHelper(setFlagsRule)
+                .addDisableFlags(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            setFlagsRule.enableFlags(Flags.FLAG_FLAG_NAME4);
+                            assertFalse(Flags.flagName3());
+                            assertTrue(Flags.flagName4());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void disablingDifferentFlagThanAnnotationPasses() {
+        SetFlagsRule setFlagsRule = new SetFlagsRule();
+        new AnnotationTestRuleHelper(setFlagsRule)
+                .addDisableFlags(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            setFlagsRule.disableFlags(Flags.FLAG_FLAG_NAME4);
+                            assertFalse(Flags.flagName3());
+                            assertFalse(Flags.flagName4());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void conflictingClassAnnotationsFails() {
+        @EnableFlags({Flags.FLAG_FLAG_NAME3, Flags.FLAG_FLAG_NAME4})
+        @DisableFlags(Flags.FLAG_FLAG_NAME3)
+        class SomeClass {}
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .setTestClass(SomeClass.class)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void conflictingMethodAnnotationsFails() {
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3, Flags.FLAG_FLAG_NAME4)
+                .addDisableFlags(Flags.FLAG_FLAG_NAME3)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void conflictingAnnotationsAcrossMethodAndClassFails() {
+        @DisableFlags(Flags.FLAG_FLAG_NAME3)
+        class SomeClass {}
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .setTestClass(SomeClass.class)
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3, Flags.FLAG_FLAG_NAME4)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void canDuplicateFlagAcrossMethodAndClassAnnotations() {
+        @EnableFlags(Flags.FLAG_FLAG_NAME3)
+        class SomeClass {}
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .setTestClass(SomeClass.class)
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3, Flags.FLAG_FLAG_NAME4)
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                            assertTrue(Flags.flagName4());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void canSetFlagsWithClass() {
+        @EnableFlags(Flags.FLAG_FLAG_NAME3)
+        @DisableFlags(Flags.FLAG_FLAG_NAME4)
+        class SomeClass {}
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .setTestClass(SomeClass.class)
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                            assertFalse(Flags.flagName4());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void canSetFlagsWithMethod() {
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3)
+                .addDisableFlags(Flags.FLAG_FLAG_NAME4)
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                            assertFalse(Flags.flagName4());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void canSetFlagsWithClassAndMethod() {
+        @EnableFlags(Flags.FLAG_FLAG_NAME3)
+        @DisableFlags(Flags.FLAG_FLAG_NAME4)
+        class SomeClass {}
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .setTestClass(SomeClass.class)
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3)
+                .addDisableFlags(Flags.FLAG_FLAG_NAME4)
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                            assertFalse(Flags.flagName4());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void requiresFlagsOnClassPasses() {
+        @RequiresFlagsEnabled(Flags.FLAG_FLAG_NAME3)
+        @RequiresFlagsDisabled(Flags.FLAG_FLAG_NAME4)
+        class SomeClass {}
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .setTestClass(SomeClass.class)
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void requiresFlagsOnMethodPasses() {
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .addRequiresFlagsEnabled(Flags.FLAG_FLAG_NAME3)
+                .addRequiresFlagsDisabled(Flags.FLAG_FLAG_NAME4)
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void settingByAnnotationRequiredFlagsFails() {
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .addRequiresFlagsEnabled(Flags.FLAG_FLAG_NAME3)
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void settingByAnnotationRequiredDisabledFlagsFails() {
+        new AnnotationTestRuleHelper(new SetFlagsRule())
+                .addRequiresFlagsDisabled(Flags.FLAG_FLAG_NAME4)
+                .addDisableFlags(Flags.FLAG_FLAG_NAME4)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void settingDirectlyRequiredFlagsFails() {
+        SetFlagsRule rule = new SetFlagsRule();
+        new AnnotationTestRuleHelper(rule)
+                .addRequiresFlagsEnabled(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            rule.enableFlags(Flags.FLAG_FLAG_NAME3);
+                        })
+                .prepareTest()
+                .assertFailsWithType(FlagSetException.class);
+    }
+
+    @Test
+    public void settingDirectlyRequiredDisabledFlagsFails() {
+        SetFlagsRule rule = new SetFlagsRule();
+        new AnnotationTestRuleHelper(rule)
+                .addRequiresFlagsDisabled(Flags.FLAG_FLAG_NAME4)
+                .setTestCode(
+                        () -> {
+                            rule.disableFlags(Flags.FLAG_FLAG_NAME4);
+                        })
+                .prepareTest()
+                .assertFailsWithType(FlagSetException.class);
+    }
+
+    @Test
+    public void paramEnabledFlagsGetEnabled() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, true));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void paramDisabledFlagsGetDisabled() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, false));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .setTestCode(
+                        () -> {
+                            assertFalse(Flags.flagName3());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void paramSetFlagsGetSet() {
+        FlagsParameterization params =
+                new FlagsParameterization(
+                        Map.of(Flags.FLAG_FLAG_NAME3, true, Flags.FLAG_FLAG_NAME4, false));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                            assertFalse(Flags.flagName4());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void settingAnyFlagBeforeParameterizationIsAppliedFails() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, true));
+        SetFlagsRule setFlagsRule = new SetFlagsRule(DEVICE_DEFAULT, params);
+        setFlagsRule.enableFlags(Flags.FLAG_FLAG_NAME4);
+        new AnnotationTestRuleHelper(setFlagsRule).prepareTest().assertFails();
+    }
+
+    @Test
+    public void settingParameterizedFlagFails() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, true));
+        SetFlagsRule setFlagsRule = new SetFlagsRule(DEVICE_DEFAULT, params);
+        new AnnotationTestRuleHelper(setFlagsRule)
+                .setTestCode(
+                        () -> {
+                            setFlagsRule.enableFlags(Flags.FLAG_FLAG_NAME3);
+                        })
+                .prepareTest()
+                .assertFailsWithType(FlagSetException.class);
+    }
+
+    @Test
+    public void settingNonParameterizedFlagDirectlyWorks() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, true));
+        SetFlagsRule setFlagsRule = new SetFlagsRule(DEVICE_DEFAULT, params);
+        new AnnotationTestRuleHelper(setFlagsRule)
+                .setTestCode(
+                        () -> {
+                            setFlagsRule.enableFlags(Flags.FLAG_FLAG_NAME4);
+                            assertTrue(Flags.flagName3());
+                            assertTrue(Flags.flagName4());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void settingNonParameterizedFlagByAnnotationWorks() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, true));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .addEnableFlags(Flags.FLAG_FLAG_NAME4)
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                            assertTrue(Flags.flagName4());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void paramEnabledFlagsCantBeRequiredEnabledByAnnotation() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, true));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .addRequiresFlagsEnabled(Flags.FLAG_FLAG_NAME3)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void paramDisabledFlagsCantBeRequiredEnabledByAnnotation() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, false));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .addRequiresFlagsEnabled(Flags.FLAG_FLAG_NAME3)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void paramEnabledFlagsCantBeRequiredDisabledByAnnotation() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, true));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .addRequiresFlagsDisabled(Flags.FLAG_FLAG_NAME3)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void paramDisabledFlagsCantBeRequiredDisabledByAnnotation() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, false));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .addRequiresFlagsDisabled(Flags.FLAG_FLAG_NAME3)
+                .prepareTest()
+                .assertFails();
+    }
+
+    @Test
+    public void paramEnabledFlagsRunWhenAnnotationEnablesFlag() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, true));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            assertTrue(Flags.flagName3());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void paramDisabledFlagsRunWhenAnnotationDisablesFlag() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, false));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .addDisableFlags(Flags.FLAG_FLAG_NAME3)
+                .setTestCode(
+                        () -> {
+                            assertFalse(Flags.flagName3());
+                        })
+                .prepareTest()
+                .assertPasses();
+    }
+
+    @Test
+    public void paramDisabledFlagsSkipWhenAnnotationEnablesFlag() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, false));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .addEnableFlags(Flags.FLAG_FLAG_NAME3)
+                .prepareTest()
+                .assertSkipped();
+    }
+
+    @Test
+    public void paramEnabledFlagsSkipWhenAnnotationDisablesFlag() {
+        FlagsParameterization params =
+                new FlagsParameterization(Map.of(Flags.FLAG_FLAG_NAME3, true));
+        new AnnotationTestRuleHelper(new SetFlagsRule(DEVICE_DEFAULT, params))
+                .addDisableFlags(Flags.FLAG_FLAG_NAME3)
+                .prepareTest()
+                .assertSkipped();
+    }
+}
diff --git a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/SetFlagsRuleTest.java b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/SetFlagsRuleTest.java
index 6197dae..352299c 100644
--- a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/SetFlagsRuleTest.java
+++ b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/SetFlagsRuleTest.java
@@ -27,7 +27,7 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
-/** Unit tests for {@code ResetFlagsRule}. */
+/** Unit tests for {@code SetFlagsRule}. */
 @RunWith(Parameterized.class)
 public final class SetFlagsRuleTest {
 
diff --git a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/example/ExampleCheckFlagsRuleTest.java b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/example/ExampleCheckFlagsRuleTest.java
new file mode 100644
index 0000000..7c548ca
--- /dev/null
+++ b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/example/ExampleCheckFlagsRuleTest.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.platform.test.flag.junit.example;
+
+import static org.junit.Assert.fail;
+
+import android.platform.test.annotations.RequiresFlagsDisabled;
+import android.platform.test.annotations.RequiresFlagsEnabled;
+import android.platform.test.flag.junit.CheckFlagsRule;
+import android.platform.test.flag.junit.IFlagsValueProvider;
+import android.platform.test.flag.util.FlagReadException;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Example of how to write a test gated on flag values using {@link CheckFlagsRule}. */
+@RunWith(JUnit4.class)
+public class ExampleCheckFlagsRuleTest {
+
+    /**
+     * Rule backed by a fake provider where "flag0" is disabled and "flag1" is enabled.
+     * A real test would instead use {@code DeviceFlagsValueProvider.createCheckFlagsRule()}.
+     */
+    @Rule
+    public final CheckFlagsRule mCheckFlagsRule =
+            new CheckFlagsRule(
+                    new IFlagsValueProvider() {
+                        @Override
+                        public boolean getBoolean(String flag) throws FlagReadException {
+                            switch (flag) {
+                                case "flag0":
+                                    return false;
+                                case "flag1":
+                                    return true;
+                                default:
+                                    throw new FlagReadException(flag, "unknown flag");
+                            }
+                        }
+                    });
+
+    @Test
+    public void noAnnotation_execute() {
+        // No flag requirements, so this test runs and passes.
+    }
+
+    @Test
+    @RequiresFlagsEnabled("flag0") // flag0 is disabled by the provider
+    public void requiredDisabledFlagEnabled_skip() {
+        fail("Test should be skipped");
+    }
+
+    @Test
+    @RequiresFlagsEnabled("flag1")
+    @RequiresFlagsDisabled("flag0")
+    public void requireBothEnabledAndDisabledFlags_execute() {
+        // Both requirements match the provider's values, so this test runs and passes.
+    }
+
+    @Test
+    @RequiresFlagsDisabled("flag1") // flag1 is enabled by the provider
+    public void requiredEnabledFlagDisabled_skip() {
+        fail("Test should be skipped");
+    }
+
+    @Test
+    @RequiresFlagsEnabled({"flag0", "flag1"}) // every listed flag must be enabled; flag0 is not
+    public void requiredDisabledFlagEnabledWithOthers_skip() {
+        fail("Test should be skipped");
+    }
+}
diff --git a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/example/ExampleFlagsParameterizedTest.java b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/example/ExampleFlagsParameterizedTest.java
new file mode 100644
index 0000000..b645110
--- /dev/null
+++ b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/example/ExampleFlagsParameterizedTest.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.platform.test.flag.junit.example;
+
+import static android.platform.test.flag.junit.SetFlagsRule.DefaultInitValueType.NULL_DEFAULT;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import android.platform.test.annotations.DisableFlags;
+import android.platform.test.annotations.EnableFlags;
+import android.platform.test.flag.junit.Flags;
+import android.platform.test.flag.junit.FlagsParameterization;
+import android.platform.test.flag.junit.SetFlagsRule;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import java.util.List;
+
+/**
+ * Example of how to write a test using {@link SetFlagsRule}, {@link FlagsParameterization} and the
+ * annotations {@link EnableFlags}, {@link DisableFlags}.
+ */
+@RunWith(Parameterized.class)
+public class ExampleFlagsParameterizedTest {
+
+    @Parameters(name = "{0}")
+    public static List<FlagsParameterization> getFlags() {
+        return FlagsParameterization.allCombinationsOf(
+                Flags.FLAG_FLAG_NAME3, Flags.FLAG_FLAG_NAME4);
+    }
+
+    @Rule public final SetFlagsRule mSetFlagsRule;
+
+    public ExampleFlagsParameterizedTest(FlagsParameterization flags) {
+        mSetFlagsRule = new SetFlagsRule(NULL_DEFAULT, flags);
+    }
+
+    // assertNotNull is used to call out when a flag is accessible
+    // but will have different values depending on the parameterization.
+
+    @Test
+    public void runTestWithAllFlagCombinations() {
+        assertNotNull(Flags.flagName3());
+        assertNotNull(Flags.flagName4());
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_FLAG_NAME3) // conflicting parameterizations are skipped
+    public void runTestWithFlag3Enabled() {
+        assertTrue(Flags.flagName3());
+        assertNotNull(Flags.flagName4());
+    }
+
+    @Test
+    @DisableFlags(Flags.FLAG_FLAG_NAME4) // conflicting parameterizations are skipped
+    public void runTestWithFlag4Disabled() {
+        assertNotNull(Flags.flagName3());
+        assertFalse(Flags.flagName4());
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_FLAG_NAME3)
+    @DisableFlags(Flags.FLAG_FLAG_NAME4)
+    public void runTestWithFlag3EnabledAndFlag4Disabled() {
+        assertTrue(Flags.flagName3());
+        assertFalse(Flags.flagName4());
+    }
+
+    @Test
+    @EnableFlags({Flags.FLAG_FLAG_NAME3, Flags.FLAG_FLAG_NAME4})
+    public void runTestWithTwoFlagsEnabled() {
+        assertTrue(Flags.flagName3());
+        assertTrue(Flags.flagName4());
+    }
+}
diff --git a/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/example/ExampleSetFlagsRuleTest.java b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/example/ExampleSetFlagsRuleTest.java
new file mode 100644
index 0000000..2f6e20b
--- /dev/null
+++ b/libraries/flag-helpers/junit/test/src/android/platform/test/flag/junit/example/ExampleSetFlagsRuleTest.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.platform.test.flag.junit.example;
+
+import static android.platform.test.flag.junit.SetFlagsRule.DefaultInitValueType.DEVICE_DEFAULT;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import android.platform.test.annotations.DisableFlags;
+import android.platform.test.annotations.EnableFlags;
+import android.platform.test.flag.junit.Flags;
+import android.platform.test.flag.junit.SetFlagsRule;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Example of how to write a test using {@link SetFlagsRule} and the annotations {@link
+ * EnableFlags}, {@link DisableFlags}.
+ */
+@RunWith(JUnit4.class)
+public class ExampleSetFlagsRuleTest {
+
+    // NOTE: Flags.flagName3() defaults to false
+    // NOTE: Flags.flagName4() defaults to true
+
+    @Rule public final SetFlagsRule mSetFlagsRule = new SetFlagsRule(DEVICE_DEFAULT);
+
+    @Test
+    public void runTestWithDefaults() {
+        assertFalse(Flags.flagName3());
+        assertTrue(Flags.flagName4());
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_FLAG_NAME3)
+    public void runTestWithFlag3Enabled() {
+        assertTrue(Flags.flagName3());
+        assertTrue(Flags.flagName4()); // not annotated, so it keeps its default
+    }
+
+    @Test
+    @DisableFlags(Flags.FLAG_FLAG_NAME4)
+    public void runTestWithFlag4Disabled() {
+        assertFalse(Flags.flagName3()); // not annotated, so it keeps its default
+        assertFalse(Flags.flagName4());
+    }
+
+    @Test
+    @EnableFlags(Flags.FLAG_FLAG_NAME3)
+    @DisableFlags(Flags.FLAG_FLAG_NAME4)
+    public void runTestWithFlag3EnabledAndFlag4Disabled() {
+        assertTrue(Flags.flagName3());
+        assertFalse(Flags.flagName4());
+    }
+
+    @Test
+    @EnableFlags({Flags.FLAG_FLAG_NAME3, Flags.FLAG_FLAG_NAME4})
+    public void runTestWithTwoFlagsEnabled() {
+        assertTrue(Flags.flagName3());
+        assertTrue(Flags.flagName4());
+    }
+}