blob: 38fcbb55f3a1144f034a99a9c7caf27485151d11 [file] [log] [blame]
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.platform.test.flag.junit;
import static org.junit.Assume.assumeFalse;
import android.platform.test.flag.util.FlagReadException;
import android.platform.test.flag.util.FlagSetException;
import com.google.common.base.CaseFormat;
import com.google.common.collect.Sets;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiPredicate;
import java.util.function.Predicate;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
/**
 * A {@link TestRule} that helps to set flag values in unit tests.
 *
 * <p>While a test runs, the static {@code FEATURE_FLAGS} field of every aconfig {@code Flags}
 * class touched by the test is replaced with a {@code FakeFeatureFlagsImpl}; flag values are
 * written into that fake, and the original implementation is put back when the test finishes.
 */
public final class SetFlagsRule implements TestRule {
    private static final String FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME = "FakeFeatureFlagsImpl";
    private static final String REAL_FEATURE_FLAGS_IMPL_CLASS_NAME = "FeatureFlagsImpl";
    private static final String CUSTOM_FEATURE_FLAGS_CLASS_NAME = "CustomFeatureFlags";
    private static final String FEATURE_FLAGS_CLASS_NAME = "FeatureFlags";
    private static final String FEATURE_FLAGS_FIELD_NAME = "FEATURE_FLAGS";
    private static final String FLAGS_CLASS_NAME = "Flags";
    private static final String FLAG_CONSTANT_PREFIX = "FLAG_";
    private static final String SET_FLAG_METHOD_NAME = "setFlag";
    private static final String RESET_ALL_METHOD_NAME = "resetAll";
    private static final String IS_FLAG_READ_ONLY_OPTIMIZED_METHOD_NAME = "isFlagReadOnlyOptimized";

    // Store instances for entire life of a SetFlagsRule instance
    private final Map<Class<?>, Object> mFlagsClassToFakeFlagsImpl = new HashMap<>();
    private final Map<Class<?>, Object> mFlagsClassToRealFlagsImpl = new HashMap<>();

    // Store classes that are currently mutated by this rule
    private final Set<Class<?>> mMutatedFlagsClasses = new HashSet<>();

    // Any flags added to this list cannot be set imperatively (i.e. with enableFlags/disableFlags)
    private final Set<String> mLockedFlagNames = new HashSet<>();

    // listener to be called before setting a flag
    private final Listener mListener;

    // TODO(322377082): remove repackage prefix list
    private static final String[] REPACKAGE_PREFIX_LIST =
            new String[] {
                "", "com.android.internal.hidden_from_bootclasspath.",
            };

    // Cache from a flag's package name to every (possibly repackaged) package that has a Flags
    // class for it.
    private final Map<String, Set<String>> mPackageToRepackage = new HashMap<>();

    // True when fakes are seeded with the current FeatureFlags instance (DEVICE_DEFAULT).
    private final boolean mIsInitWithDefault;

    private FlagsParameterization mFlagsParameterization;

    // True only while the wrapped statement (setup + test) is running.
    private boolean mIsRuleEvaluating = false;

    /** How the values of flags that a test does not set explicitly are initialized. */
    public enum DefaultInitValueType {
        /**
         * Initialize flag value as null
         *
         * <p>Flag value need to be set before using
         */
        NULL_DEFAULT,
        /**
         * Initialize flag value with the default value from the device
         *
         * <p>If flag value is not overridden by adb, then the default value is from the release
         * configuration when the test is built.
         */
        DEVICE_DEFAULT,
    }

    /** Creates a rule using {@link DefaultInitValueType#DEVICE_DEFAULT}. */
    public SetFlagsRule() {
        this(DefaultInitValueType.DEVICE_DEFAULT);
    }

    /**
     * Creates a rule with the given default initialization.
     *
     * @param defaultType how flags that are not explicitly set get their values
     */
    public SetFlagsRule(DefaultInitValueType defaultType) {
        this(defaultType, null);
    }

    /**
     * Creates a rule using {@link DefaultInitValueType#DEVICE_DEFAULT} that applies the given
     * parameterization to every test. Parameterized flags are locked against imperative changes.
     *
     * @param flagsParameterization the flag values to apply, or {@code null} for none
     */
    public SetFlagsRule(@Nullable FlagsParameterization flagsParameterization) {
        this(DefaultInitValueType.DEVICE_DEFAULT, flagsParameterization);
    }

    /**
     * Creates a rule with the given default initialization and parameterization.
     *
     * @param defaultType how flags that are not explicitly set get their values
     * @param flagsParameterization the flag values to apply, or {@code null} for none
     */
    public SetFlagsRule(
            DefaultInitValueType defaultType,
            @Nullable FlagsParameterization flagsParameterization) {
        this(defaultType, flagsParameterization, null);
    }

    private SetFlagsRule(
            DefaultInitValueType defaultType,
            @Nullable FlagsParameterization flagsParameterization,
            @Nullable Listener listener) {
        mIsInitWithDefault = defaultType == DefaultInitValueType.DEVICE_DEFAULT;
        mFlagsParameterization = flagsParameterization;
        if (flagsParameterization != null) {
            mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet());
        }
        mListener = listener;
    }

    /**
     * Set the FlagsParameterization to be used during this test. This cannot be used to override a
     * previous call, and cannot be called once the rule has been evaluated.
     *
     * @throws NullPointerException if {@code flagsParameterization} is null
     * @throws AssertionError if a parameterization was already set or the rule is running
     * @throws IllegalStateException if any flag has already been set by this rule
     */
    public void setFlagsParameterization(@Nonnull FlagsParameterization flagsParameterization) {
        Objects.requireNonNull(flagsParameterization, "FlagsParameterization cannot be cleared");
        if (mFlagsParameterization != null) {
            throw new AssertionError("FlagsParameterization cannot be overridden");
        }
        if (mIsRuleEvaluating) {
            throw new AssertionError("Cannot set FlagsParameterization once the rule is running");
        }
        ensureFlagsAreUnset();
        mFlagsParameterization = flagsParameterization;
        mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet());
    }

    /**
     * Enables the given flags.
     *
     * @param fullFlagNames The name of the flags in the flag class with the format
     *     {packageName}.{flagName}
     * @throws IllegalStateException if called outside of test or setup code
     * @deprecated Annotate your test or class with <code>@EnableFlags(String...)</code> instead
     */
    @Deprecated
    public void enableFlags(String... fullFlagNames) {
        setFlagValuesImperatively(true, fullFlagNames);
    }

    /**
     * Disables the given flags.
     *
     * @param fullFlagNames The name of the flags in the flag class with the format
     *     {packageName}.{flagName}
     * @throws IllegalStateException if called outside of test or setup code
     * @deprecated Annotate your test or class with <code>@DisableFlags(String...)</code> instead
     */
    @Deprecated
    public void disableFlags(String... fullFlagNames) {
        setFlagValuesImperatively(false, fullFlagNames);
    }

    /**
     * Shared implementation of {@link #enableFlags} and {@link #disableFlags}: sets every named
     * flag to {@code value}, refusing flags that are locked by annotations or parameterization.
     */
    private void setFlagValuesImperatively(boolean value, String... fullFlagNames) {
        if (!mIsRuleEvaluating) {
            throw new IllegalStateException("Not allowed to set flags outside test and setup code");
        }
        for (String fullFlagName : fullFlagNames) {
            if (mLockedFlagNames.contains(fullFlagName)) {
                throw new FlagSetException(fullFlagName, "Not allowed to change locked flags");
            }
            setFlagValue(fullFlagName, value);
        }
    }

    /** Throws if this rule has already created any fake FeatureFlags implementation. */
    private void ensureFlagsAreUnset() {
        if (!mFlagsClassToFakeFlagsImpl.isEmpty()) {
            throw new IllegalStateException("Some flags were set before the rule was initialized");
        }
    }

    @Override
    public Statement apply(Statement base, Description description) {
        return new Statement() {
            @Override
            public void evaluate() throws Throwable {
                Throwable throwable = null;
                try {
                    if (mListener != null) {
                        mListener.onStartedEvaluating();
                    }
                    AnnotationsRetriever.FlagAnnotations flagAnnotations =
                            AnnotationsRetriever.getFlagAnnotations(description);
                    assertAnnotationsMatchParameterization(flagAnnotations, mFlagsParameterization);
                    flagAnnotations.assumeAllSetFlagsMatchParameterization(mFlagsParameterization);
                    // Parameterized values are applied first, then annotation values.
                    if (mFlagsParameterization != null) {
                        ensureFlagsAreUnset();
                        for (Map.Entry<String, Boolean> pair :
                                mFlagsParameterization.mOverrides.entrySet()) {
                            setFlagValue(pair.getKey(), pair.getValue());
                        }
                    }
                    for (Map.Entry<String, Boolean> pair :
                            flagAnnotations.mSetFlagValues.entrySet()) {
                        setFlagValue(pair.getKey(), pair.getValue());
                    }
                    // Flags pinned by annotations must not be changed imperatively later.
                    mLockedFlagNames.addAll(flagAnnotations.mRequiredFlagValues.keySet());
                    mLockedFlagNames.addAll(flagAnnotations.mSetFlagValues.keySet());
                    mIsRuleEvaluating = true;
                    base.evaluate();
                } catch (Throwable t) {
                    throwable = t;
                } finally {
                    mIsRuleEvaluating = false;
                    try {
                        resetFlags();
                    } catch (Throwable t) {
                        if (throwable != null) {
                            t.addSuppressed(throwable);
                        }
                        throwable = t;
                    }
                    try {
                        if (mListener != null) {
                            mListener.onFinishedEvaluating();
                        }
                    } catch (Throwable t) {
                        if (throwable != null) {
                            t.addSuppressed(throwable);
                        }
                        throwable = t;
                    }
                }
                if (throwable != null) throw throwable;
            }
        };
    }

    /**
     * Fails if any flag with a required value (from the RequiresFlags annotations) is also part
     * of the parameterization. Does nothing when {@code parameterization} is null.
     */
    private static void assertAnnotationsMatchParameterization(
            AnnotationsRetriever.FlagAnnotations flagAnnotations,
            FlagsParameterization parameterization) {
        if (parameterization == null) return;
        Set<String> parameterizedFlags = parameterization.mOverrides.keySet();
        Set<String> requiredFlags = flagAnnotations.mRequiredFlagValues.keySet();
        // Assert that NO Annotation-Required flag is in the parameterization
        Set<String> parameterizedAndRequiredFlags =
                Sets.intersection(parameterizedFlags, requiredFlags);
        if (!parameterizedAndRequiredFlags.isEmpty()) {
            throw new AssertionError(
                    "The following flags have required values (per @RequiresFlagsEnabled or"
                            + " @RequiresFlagsDisabled) but they are part of the"
                            + " FlagParameterization: "
                            + parameterizedAndRequiredFlags);
        }
    }

    /**
     * Sets {@code fullFlagName} to {@code value} in every (original or repackaged) Flags class
     * that references it.
     *
     * @throws FlagSetException if the name has no '.' or no matching Flags class exists
     */
    private void setFlagValue(String fullFlagName, boolean value) {
        if (!fullFlagName.contains(".")) {
            throw new FlagSetException(
                    fullFlagName, "Flag name is not the expected format {packageName}.{flagName}.");
        }
        // Get all packages containing Flags referencing the same fullFlagName.
        Set<String> packageSet = getPackagesContainsFlag(fullFlagName);
        for (String packageName : packageSet) {
            setFlagValue(Flag.createFlag(fullFlagName, packageName), value);
        }
    }

    private Set<String> getPackagesContainsFlag(String fullFlagName) {
        return getAllPackagesForFlag(fullFlagName, mPackageToRepackage);
    }

    /**
     * Returns every package, built from each entry of {@code REPACKAGE_PREFIX_LIST} plus the
     * flag's package, that has a loadable Flags class. Results are cached in
     * {@code packageToRepackage}.
     *
     * @throws FlagSetException if no such package exists
     */
    private static Set<String> getAllPackagesForFlag(
            String fullFlagName, Map<String, Set<String>> packageToRepackage) {
        String packageName = Flag.getFlagPackageName(fullFlagName);
        Set<String> packageSet = packageToRepackage.getOrDefault(packageName, new HashSet<>());
        if (!packageSet.isEmpty()) {
            return packageSet;
        }
        for (String prefix : REPACKAGE_PREFIX_LIST) {
            String repackagedName = String.format("%s%s", prefix, packageName);
            String flagClassName = String.format("%s.%s", repackagedName, FLAGS_CLASS_NAME);
            try {
                // Only probe for existence; initialize=false avoids running static initializers.
                Class.forName(flagClassName, false, SetFlagsRule.class.getClassLoader());
                packageSet.add(repackagedName);
            } catch (ClassNotFoundException ignored) {
                // Not every prefix exists; an error is thrown below if none of them do.
            }
        }
        packageToRepackage.put(packageName, packageSet);
        if (packageSet.isEmpty()) {
            throw new FlagSetException(
                    fullFlagName,
                    "Cannot find package containing Flags class referencing to this flag.");
        }
        return packageSet;
    }

    /**
     * Installs (once per Flags class) the fake FeatureFlags implementation and sets the flag in
     * it. Skips the test via an assumption if the flag is read-only in an optimized build.
     */
    private void setFlagValue(Flag flag, boolean value) {
        if (mListener != null) {
            mListener.onBeforeSetFlag(flag, value);
        }
        Class<?> flagsClass = getFlagClassFromFlag(flag);
        Object fakeFlagsImplInstance = getOrCreateFakeFlagsImp(flagsClass);
        if (!mMutatedFlagsClasses.contains(flagsClass)) {
            // Replace FeatureFlags in Flags class with FakeFeatureFlagsImpl
            replaceFlagsImpl(flagsClass, fakeFlagsImplInstance);
            mMutatedFlagsClasses.add(flagsClass);
        }
        // If the test is trying to set the flag value on a read_only flag in an optimized build,
        // skip this test, since it is not a valid testing case.
        // The reason for skipping instead of throwing an error here is that all read_write flags
        // may be changed to read_only in the final release configuration. Thus the test could
        // still be executed in other release configurations.
        // TODO(b/337449119): SetFlagsRule should still run tests that are consistent with the
        // read-only values of flags. But be careful, if a ClassRule exists, the value returned by
        // the original FeatureFlags instance may be overridden, and reading it may not be allowed.
        boolean isOptimized = verifyFlagReadOnlyAndOptimized(fakeFlagsImplInstance, flag);
        assumeFalse(
                String.format(
                        "Flag %s is read_only, and the code is optimized. "
                                + " The flag value should not be modified on this build"
                                + " Skip this test.",
                        flag.fullFlagName()),
                isOptimized);
        // Set desired flag value in the FakeFeatureFlagsImpl
        setFlagValueInFakeFeatureFlagsImpl(fakeFlagsImplInstance, flag, value);
    }

    /** Loads the Flags class that holds {@code flag}, wrapping a missing class in a FlagSetException. */
    private static Class<?> getFlagClassFromFlag(Flag flag) {
        String className = flag.flagsClassName();
        try {
            return Class.forName(className);
        } catch (ClassNotFoundException e) {
            throw new FlagSetException(
                    flag.fullFlagName(),
                    String.format(
                            "Can not load the Flags class %s to set its values. Please check the "
                                    + "flag name and ensure that the aconfig auto generated "
                                    + "library is in the dependency.",
                            className),
                    e);
        }
    }

    /**
     * Loads a Flags class by its fully qualified name.
     *
     * @throws FlagSetException if the simple name is not {@code Flags} or the class is missing
     */
    private static Class<?> getFlagClassFromFlagsClassName(String className) {
        if (!className.endsWith("." + FLAGS_CLASS_NAME)) {
            throw new FlagSetException(
                    className,
                    "Can not watch this Flags class because it is not named 'Flags'. Please ensure"
                            + " your @UsesFlags() annotations only reference the Flags classes.");
        }
        try {
            return Class.forName(className);
        } catch (ClassNotFoundException e) {
            throw new FlagSetException(
                    className,
                    "Cannot load this Flags class to set its values. Please check the flag name and"
                            + " ensure that the aconfig auto generated library is in the dependency.",
                    e);
        }
    }

    /**
     * Reads a flag by reflectively calling its camel-cased accessor on {@code featureFlagsImpl}.
     *
     * @throws FlagReadException if the accessor is missing, fails, or does not return a Boolean
     */
    private boolean getFlagValue(Object featureFlagsImpl, Flag flag) {
        // Must be consistent with method name in aconfig auto generated code.
        String methodName = getFlagMethodName(flag);
        String fullFlagName = flag.fullFlagName();
        try {
            Object result =
                    featureFlagsImpl.getClass().getMethod(methodName).invoke(featureFlagsImpl);
            if (result instanceof Boolean) {
                return (Boolean) result;
            }
            // A null result must still produce a FlagReadException, not an NPE.
            String typeName = result == null ? "null" : result.getClass().getSimpleName();
            throw new FlagReadException(
                    fullFlagName, String.format("Flag type is %s, not boolean", typeName));
        } catch (NoSuchMethodException e) {
            throw new FlagReadException(
                    fullFlagName,
                    String.format(
                            "No method %s in the Flags class %s to read the flag value. Please"
                                    + " check the flag name.",
                            methodName, featureFlagsImpl.getClass().getName()),
                    e);
        } catch (ReflectiveOperationException e) {
            throw new FlagReadException(
                    fullFlagName,
                    String.format(
                            "Fail to get value of flag %s from instance %s",
                            fullFlagName, featureFlagsImpl.getClass().getName()),
                    e);
        }
    }

    /** Converts the flag's lower_underscore simple name to its lowerCamel accessor name. */
    private String getFlagMethodName(Flag flag) {
        return CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, flag.simpleFlagName());
    }

    /** Calls {@code setFlag(String, boolean)} on the fake implementation. */
    private void setFlagValueInFakeFeatureFlagsImpl(
            Object fakeFeatureFlagsImpl, Flag flag, boolean value) {
        String fullFlagName = flag.fullFlagName();
        try {
            fakeFeatureFlagsImpl
                    .getClass()
                    .getMethod(SET_FLAG_METHOD_NAME, String.class, boolean.class)
                    .invoke(fakeFeatureFlagsImpl, fullFlagName, value);
        } catch (NoSuchMethodException e) {
            throw new FlagSetException(
                    fullFlagName,
                    String.format(
                            "Flag implementation %s is not fake implementation",
                            fakeFeatureFlagsImpl.getClass().getName()),
                    e);
        } catch (ReflectiveOperationException e) {
            throw new FlagSetException(fullFlagName, e);
        }
    }

    /**
     * Asks the fake implementation whether {@code flag} is read-only and optimized. Returns
     * {@code false} when the fake has no such method and is a FakeFeatureFlagsImpl or
     * CustomFeatureFlags.
     */
    private static boolean verifyFlagReadOnlyAndOptimized(Object fakeFeatureFlagsImpl, Flag flag) {
        String fullFlagName = flag.fullFlagName();
        try {
            boolean result =
                    (Boolean)
                            fakeFeatureFlagsImpl
                                    .getClass()
                                    .getMethod(
                                            IS_FLAG_READ_ONLY_OPTIMIZED_METHOD_NAME, String.class)
                                    .invoke(fakeFeatureFlagsImpl, fullFlagName);
            return result;
        } catch (NoSuchMethodException e) {
            // If the flag is generated under exported mode, then it doesn't have this method
            String simpleClassName = fakeFeatureFlagsImpl.getClass().getSimpleName();
            if (simpleClassName.equals(FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME)) {
                return false;
            }
            if (simpleClassName.equals(CUSTOM_FEATURE_FLAGS_CLASS_NAME)) {
                return false;
            }
            throw new FlagSetException(
                    fullFlagName,
                    String.format(
                            "Cannot check whether flag is optimized. "
                                    + "Flag implementation %s is not fake implementation",
                            fakeFeatureFlagsImpl.getClass().getName()),
                    e);
        } catch (ReflectiveOperationException e) {
            throw new FlagSetException(fullFlagName, e);
        }
    }

    /**
     * Returns the cached fake implementation for {@code flagsClass}, creating it on first use.
     * The current FEATURE_FLAGS value is remembered for restoration and, with DEVICE_DEFAULT,
     * passed to the fake's constructor; otherwise {@code null} is passed.
     */
    @Nonnull
    private Object getOrCreateFakeFlagsImp(Class<?> flagsClass) {
        Object fakeFlagsImplInstance = mFlagsClassToFakeFlagsImpl.get(flagsClass);
        if (fakeFlagsImplInstance != null) {
            return fakeFlagsImplInstance;
        }
        String packageName = flagsClass.getPackageName();
        String fakeClassName =
                String.format("%s.%s", packageName, FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME);
        String interfaceName = String.format("%s.%s", packageName, FEATURE_FLAGS_CLASS_NAME);
        Object realFlagsImplInstance = readFlagsImpl(flagsClass);
        mFlagsClassToRealFlagsImpl.put(flagsClass, realFlagsImplInstance);
        try {
            Class<?> flagImplClass = Class.forName(fakeClassName);
            Class<?> flagInterface = Class.forName(interfaceName);
            fakeFlagsImplInstance =
                    flagImplClass
                            .getConstructor(flagInterface)
                            .newInstance(mIsInitWithDefault ? realFlagsImplInstance : null);
        } catch (ReflectiveOperationException e) {
            throw new UnsupportedOperationException(
                    String.format(
                            "Cannot create FakeFeatureFlagsImpl in Flags class %s.",
                            flagsClass.getName()),
                    e);
        }
        mFlagsClassToFakeFlagsImpl.put(flagsClass, fakeFlagsImplInstance);
        return fakeFlagsImplInstance;
    }

    /** Writes {@code flagsImplInstance} into the static FEATURE_FLAGS field of {@code flagsClass}. */
    private static void replaceFlagsImpl(Class<?> flagsClass, Object flagsImplInstance) {
        Field featureFlagsField = getFeatureFlagsField(flagsClass);
        try {
            featureFlagsField.set(null, flagsImplInstance);
        } catch (IllegalAccessException e) {
            throw new UnsupportedOperationException(
                    String.format(
                            "Cannot replace FeatureFlagsImpl to %s.",
                            flagsImplInstance.getClass().getName()),
                    e);
        }
    }

    /** Reads the static FEATURE_FLAGS field of {@code flagsClass}. */
    private static Object readFlagsImpl(Class<?> flagsClass) {
        Field featureFlagsField = getFeatureFlagsField(flagsClass);
        try {
            return featureFlagsField.get(null);
        } catch (IllegalAccessException e) {
            throw new UnsupportedOperationException(
                    String.format(
                            "Cannot get FeatureFlags from Flags class %s.", flagsClass.getName()),
                    e);
        }
    }

    /** Returns the FEATURE_FLAGS field of {@code flagsClass}, made accessible. */
    private static Field getFeatureFlagsField(Class<?> flagsClass) {
        Field featureFlagsField;
        try {
            featureFlagsField = flagsClass.getDeclaredField(FEATURE_FLAGS_FIELD_NAME);
        } catch (ReflectiveOperationException e) {
            throw new UnsupportedOperationException(
                    String.format(
                            "Cannot store FeatureFlagsImpl in Flag %s.", flagsClass.getName()),
                    e);
        }
        featureFlagsField.setAccessible(true);
        return featureFlagsField;
    }

    /**
     * Puts the original FeatureFlags back into every Flags class mutated by this rule and calls
     * {@code resetAll()} on each fake.
     *
     * <p>Every class is attempted even if an earlier one fails. FEATURE_FLAGS is a static field,
     * so a class left holding a fake would leak into later tests. The first failure is thrown,
     * with later ones attached as suppressed exceptions.
     *
     * @throws FlagSetException if restoring or resetting any class failed
     */
    private void resetFlags() {
        FlagSetException failure = null;
        for (Class<?> flagsClass : mMutatedFlagsClasses) {
            try {
                Object fakeFlagsImplInstance = mFlagsClassToFakeFlagsImpl.get(flagsClass);
                Object flagsImplInstance = mFlagsClassToRealFlagsImpl.get(flagsClass);
                // Replace FeatureFlags in Flags class with real FeatureFlagsImpl
                replaceFlagsImpl(flagsClass, flagsImplInstance);
                fakeFlagsImplInstance
                        .getClass()
                        .getMethod(RESET_ALL_METHOD_NAME)
                        .invoke(fakeFlagsImplInstance);
            } catch (Exception e) {
                FlagSetException wrapped = new FlagSetException(flagsClass.getName(), e);
                if (failure == null) {
                    failure = wrapped;
                } else {
                    failure.addSuppressed(wrapped);
                }
            }
        }
        mMutatedFlagsClasses.clear();
        if (failure != null) {
            throw failure;
        }
    }

    /** An interface that provides hooks to the ClassRule. */
    private interface Listener {
        /** Called before a flag is set. */
        void onBeforeSetFlag(SetFlagsRule.Flag flag, boolean value);

        /** Called after the rule has started evaluating for a test. */
        void onStartedEvaluating();

        /** Called after the rule has finished evaluating for a test. */
        void onFinishedEvaluating();
    }

    /**
     * A {@code @ClassRule} which adds extra consistency checks for SetFlagsRule.
     *
     * <ul>
     *   <li>Requires that tests monitor the Flags class of any flag that is set.
     *   <li>Fails a test if a flag that is set was read before the test started.
     * </ul>
     *
     * <p>While the rule runs, each watched Flags class has its FEATURE_FLAGS field replaced by a
     * CustomFeatureFlags watcher that records flag reads and delegates to the real
     * implementation.
     */
    public static class ClassRule implements TestRule {
        /** The flags classes that are requested to be watched during construction. */
        private final Set<Class<?>> mGlobalFlagsClassesToWatch = new HashSet<>();

        /** The flags packages that are allowed to be set, for quick per-flag lookup */
        private final Set<String> mSettableFlagsPackages = new HashSet<>();

        /** The mapping from the Flags classes to the real implementations */
        private final Map<Class<?>, Object> mFlagsClassToRealFlagsImpl = new HashMap<>();

        /** The mapping from the Flags classes to the watcher implementations */
        private final Map<Class<?>, Object> mFlagsClassToWatcherImpl = new HashMap<>();

        /** The flags classes that have actually been mutated */
        private final Set<Class<?>> mMutatedFlagsClasses = new HashSet<>();

        /** The flag values set by class annotations */
        private final Map<String, Boolean> mClassLevelSetFlagValues = new ConcurrentHashMap<>();

        /**
         * The individual flags which have been read from prior to tests starting, mapped to the
         * stack trace of the first read.
         */
        private final Map<String, FirstFlagRead> mFirstReadOutsideTestsByFlag =
                new ConcurrentHashMap<>();

        /**
         * The individual flags which have been read from within a test, mapped to the stack trace
         * of the first read.
         */
        private final Map<String, FirstFlagRead> mFirstReadWithinTestByFlag =
                new ConcurrentHashMap<>();

        /** repackage cache */
        private final Map<String, Set<String>> mPackageToRepackage = new HashMap<>();

        /** The depth of the ClassRule evaluating on potentially nested suites */
        private int mSuiteRunDepth = 0;

        /** Whether the SetFlagsRule is evaluating for a test */
        private boolean mIsTestRunning = false;

        /** Listener to be notified of events in any created SetFlagsRule */
        private final SetFlagsRule.Listener mListener =
                new SetFlagsRule.Listener() {
                    @Override
                    public void onBeforeSetFlag(SetFlagsRule.Flag flag, boolean value) {
                        if (!mIsTestRunning) {
                            throw new IllegalStateException("Inner rule should be running!");
                        }
                        assertFlagCanBeSet(flag, value);
                    }

                    @Override
                    public void onStartedEvaluating() {
                        if (mSuiteRunDepth == 0) {
                            throw new IllegalStateException("Outer rule should be running!");
                        }
                        if (mIsTestRunning) {
                            throw new IllegalStateException("Inner rule is still running!");
                        }
                        mIsTestRunning = true;
                    }

                    @Override
                    public void onFinishedEvaluating() {
                        if (!mIsTestRunning) {
                            throw new IllegalStateException("Inner rule did not start!");
                        }
                        mIsTestRunning = false;
                        // Clear per-test reads before the check below, so a failing check cannot
                        // leave stale reads that would wrongly lock flags in the next test.
                        mFirstReadWithinTestByFlag.clear();
                        checkAllFlagsWatchersRestored();
                    }
                };

        /** Typical constructor takes an initial list flags classes to watch */
        public ClassRule(Class<?>... flagsClasses) {
            for (Class<?> flagsClass : flagsClasses) {
                mGlobalFlagsClassesToWatch.add(flagsClass);
            }
        }

        /**
         * Creates a SetFlagsRule which will work as normal, but additionally enforce the guarantees
         * about not setting flags that were read within the ClassRule
         */
        public SetFlagsRule createSetFlagsRule() {
            return createSetFlagsRule(null);
        }

        /**
         * Creates a SetFlagsRule with parameterization which will work as normal, but additionally
         * enforce the guarantees about not setting flags that were read within the ClassRule
         */
        public SetFlagsRule createSetFlagsRule(
                @Nullable FlagsParameterization flagsParameterization) {
            return new SetFlagsRule(
                    DefaultInitValueType.DEVICE_DEFAULT, flagsParameterization, mListener);
        }

        private boolean isFlagsClassMonitored(SetFlagsRule.Flag flag) {
            return mSettableFlagsPackages.contains(flag.flagPackageName());
        }

        /**
         * Throws a FlagSetException if {@code flag} was already read (in this test or before
         * tests started), if its package is not monitored, or if {@code value} contradicts a
         * class-level annotation.
         */
        private void assertFlagCanBeSet(SetFlagsRule.Flag flag, boolean value) {
            Exception firstReadWithinTest = mFirstReadWithinTestByFlag.get(flag.fullFlagName());
            if (firstReadWithinTest != null) {
                throw new FlagSetException(
                        flag.fullFlagName(),
                        "This flag was locked when it was read earlier in this test. To fix this"
                                + " error, always use @EnableFlags() and @DisableFlags() to set"
                                + " flags, which ensures flags are set before even any"
                                + " @Before-annotated setup methods.",
                        firstReadWithinTest);
            }
            Exception firstReadOutsideTest = mFirstReadOutsideTestsByFlag.get(flag.fullFlagName());
            if (firstReadOutsideTest != null) {
                throw new FlagSetException(
                        flag.fullFlagName(),
                        "This flag was locked when it was read outside of the test code; likely"
                                + " during initialization of the test class. To fix this error,"
                                + " move test fixture initialization code into your"
                                + " @Before-annotated setup method, and ensure you are using"
                                + " @EnableFlags() and @DisableFlags() to set flags.",
                        firstReadOutsideTest);
            }
            if (!isFlagsClassMonitored(flag)) {
                throw new FlagSetException(
                        flag.fullFlagName(),
                        "This flag's class is not monitored. Always use @EnableFlags() and"
                                + " @DisableFlags() on the class or method instead of"
                                + " .enableFlags() or .disableFlags() to prevent this error. When"
                                + " using FlagsParameterization, add `@UsesFlags("
                                + flag.flagPackageName()
                                + ".Flags.class)` to the test class. As a last resort, pass the"
                                + " Flags class to the constructor of your"
                                + " SetFlagsRule.ClassRule.");
            }
            // Detect errors where the rule messed up and set the wrong flag value.
            Boolean classLevelValue = mClassLevelSetFlagValues.get(flag.fullFlagName());
            if (classLevelValue != null && classLevelValue != value) {
                throw new FlagSetException(
                        flag.fullFlagName(),
                        "This flag's value was set at the class level to a different value.");
            }
        }

        private void checkInstanceOfRealFlagsImpl(Object actual) {
            if (!actual.getClass().getSimpleName().equals(REAL_FEATURE_FLAGS_IMPL_CLASS_NAME)) {
                throw new IllegalStateException(
                        String.format(
                                "Wrong impl type during setup: '%s' is not a %s",
                                actual, REAL_FEATURE_FLAGS_IMPL_CLASS_NAME));
            }
        }

        private void checkSameAs(Object expected, Object actual) {
            if (expected != actual) {
                throw new IllegalStateException(
                        String.format(
                                "Wrong impl instance found during teardown: expected %s but was %s",
                                expected, actual));
            }
        }

        /**
         * Returns the cached watcher for {@code flagsClass}, creating it on first use. The current
         * implementation must be a FeatureFlagsImpl; it is remembered for teardown.
         */
        private Object getOrCreateFlagReadWatcher(Class<?> flagsClass) {
            Object watcher = mFlagsClassToWatcherImpl.get(flagsClass);
            if (watcher != null) {
                return watcher;
            }
            Object flagsImplInstance = readFlagsImpl(flagsClass);
            // strict mode: ensure that the current impl is the real impl
            checkInstanceOfRealFlagsImpl(flagsImplInstance);
            // save the real impl for restoration later
            mFlagsClassToRealFlagsImpl.put(flagsClass, flagsImplInstance);
            watcher = newFlagReadWatcher(flagsClass, flagsImplInstance);
            mFlagsClassToWatcherImpl.put(flagsClass, watcher);
            return watcher;
        }

        /** Records the first read of {@code flagName}, inside or outside a running test. */
        private void recordFlagRead(String flagName) {
            if (mIsTestRunning) {
                mFirstReadWithinTestByFlag.computeIfAbsent(flagName, FirstFlagRead::new);
            } else {
                mFirstReadOutsideTestsByFlag.computeIfAbsent(flagName, FirstFlagRead::new);
            }
        }

        /**
         * Builds a CustomFeatureFlags for {@code flagsClass}. It answers class-level annotated
         * flags directly and records every other read before delegating to the real impl.
         */
        private Object newFlagReadWatcher(Class<?> flagsClass, Object flagsImplInstance) {
            String packageName = flagsClass.getPackageName();
            String customClassName =
                    String.format("%s.%s", packageName, CUSTOM_FEATURE_FLAGS_CLASS_NAME);
            BiPredicate<String, Predicate<Object>> getValueImpl =
                    (flagName, predicate) -> {
                        // Flags set at the class level pose no consistency risk
                        Boolean value = mClassLevelSetFlagValues.get(flagName);
                        if (value != null) {
                            return value;
                        }
                        recordFlagRead(flagName);
                        return predicate.test(flagsImplInstance);
                    };
            try {
                Class<?> customFlagsClass = Class.forName(customClassName);
                return customFlagsClass.getConstructor(BiPredicate.class).newInstance(getValueImpl);
            } catch (ReflectiveOperationException e) {
                throw new UnsupportedOperationException(
                        String.format(
                                "Cannot create CustomFeatureFlags in Flags class %s.",
                                flagsClass.getName()),
                        e);
            }
        }

        /** Get the package name of the flags in this class. This is the non-repackaged name. */
        private String getFlagPackageName(Class<?> flagsClass) {
            String classPackageName = flagsClass.getPackageName();
            String shortestPackageName = classPackageName;
            for (String prefix : REPACKAGE_PREFIX_LIST) {
                if (prefix.isEmpty()) continue;
                if (classPackageName.startsWith(prefix)) {
                    String unprefixedPackage = classPackageName.substring(prefix.length());
                    if (unprefixedPackage.length() < shortestPackageName.length()) {
                        shortestPackageName = unprefixedPackage;
                    }
                }
            }
            return shortestPackageName;
        }

        private void setupClassLevelFlagValues(Description description) {
            mClassLevelSetFlagValues.putAll(
                    AnnotationsRetriever.getFlagAnnotations(description).mSetFlagValues);
        }

        /**
         * Installs watchers on the constructor-provided Flags classes, on the classes named by
         * UsesFlags annotations, and on every package of flags set via annotations.
         */
        private void setupFlagsWatchers(Description description) {
            // Start with the static list of Flags classes to watch
            Set<Class<?>> flagsClassesToWatch = new HashSet<>(mGlobalFlagsClassesToWatch);
            // Collect the Flags classes from @UsedFlags annotations on the Descriptor
            Set<String> usedFlagsClasses = AnnotationsRetriever.getAllUsedFlagsClasses(description);
            for (String flagsClassName : usedFlagsClasses) {
                flagsClassesToWatch.add(getFlagClassFromFlagsClassName(flagsClassName));
            }
            // Now setup watchers on the provided Flags classes
            for (Class<?> flagsClass : flagsClassesToWatch) {
                setupFlagsWatcher(flagsClass, getFlagPackageName(flagsClass));
            }
            // Get all annotated flags and then the distinct packages for each flag
            Set<String> setFlags = AnnotationsRetriever.getAllAnnotationSetFlags(description);
            Set<String> extraFlagPackages = new HashSet<>();
            for (String setFlag : setFlags) {
                extraFlagPackages.add(Flag.getFlagPackageName(setFlag));
            }
            // Do not bother with flags that are already monitored
            extraFlagPackages.removeAll(mSettableFlagsPackages);
            // Expand packages to all repackaged versions, stored as Flag objects
            Set<Flag> extraWildcardFlags = new HashSet<>();
            for (String extraFlagPackage : extraFlagPackages) {
                String fullFlagName = extraFlagPackage + ".*";
                Set<String> packages = getAllPackagesForFlag(fullFlagName, mPackageToRepackage);
                for (String packageName : packages) {
                    Flag flag = Flag.createFlag(fullFlagName, packageName);
                    extraWildcardFlags.add(flag);
                }
            }
            // Set up watchers for each wildcard flag
            for (Flag flag : extraWildcardFlags) {
                Class<?> flagsClass = getFlagClassFromFlag(flag);
                setupFlagsWatcher(flagsClass, flag.flagPackageName());
            }
        }

        /** Replaces FEATURE_FLAGS of {@code flagsClass} with its watcher and marks the package settable. */
        private void setupFlagsWatcher(Class<?> flagsClass, String flagPackageName) {
            if (mMutatedFlagsClasses.contains(flagsClass)) {
                throw new IllegalStateException(
                        String.format("Flags class %s is already mutated", flagsClass.getName()));
            }
            Object watcher = getOrCreateFlagReadWatcher(flagsClass);
            replaceFlagsImpl(flagsClass, watcher);
            mMutatedFlagsClasses.add(flagsClass);
            mSettableFlagsPackages.add(flagPackageName);
        }

        /**
         * Restores the real FeatureFlags implementation in every watched Flags class and clears
         * per-suite state, then verifies no inner SetFlagsRule is still active.
         *
         * <p>Every class is restored even if a strict-mode check fails for an earlier one.
         * FEATURE_FLAGS is static, so a watcher left behind would break every later test class
         * in the process. The first failure is thrown, with later ones attached as suppressed
         * exceptions.
         *
         * @throws IllegalStateException if a watcher was not in place, restoring failed, or an
         *     inner rule did not finish cleanly
         */
        private void teardownFlagsWatchers() {
            RuntimeException failure = null;
            for (Class<?> flagsClass : mMutatedFlagsClasses) {
                try {
                    // strict mode: ensure that the watcher is still in place
                    Object watcher = readFlagsImpl(flagsClass);
                    checkSameAs(mFlagsClassToWatcherImpl.get(flagsClass), watcher);
                } catch (RuntimeException e) {
                    failure = addTeardownFailure(failure, e);
                }
                try {
                    // Replace FeatureFlags in Flags class with real FeatureFlagsImpl
                    replaceFlagsImpl(flagsClass, mFlagsClassToRealFlagsImpl.get(flagsClass));
                } catch (RuntimeException e) {
                    failure = addTeardownFailure(failure, e);
                }
            }
            mMutatedFlagsClasses.clear();
            mSettableFlagsPackages.clear();
            mFirstReadOutsideTestsByFlag.clear();
            if (failure != null) {
                throw failure;
            }
            if (mIsTestRunning) {
                throw new IllegalStateException("An inner SetFlagsRule is still running");
            }
            if (!mFirstReadWithinTestByFlag.isEmpty()) {
                throw new IllegalStateException("An inner SetFlagsRule did not fully clean up");
            }
        }

        /**
         * Accumulates a teardown failure. IllegalStateExceptions are kept as-is and anything else
         * is wrapped in one. Failures after the first are added to it as suppressed exceptions.
         */
        private static RuntimeException addTeardownFailure(
                @Nullable RuntimeException firstFailure, RuntimeException failure) {
            RuntimeException wrapped =
                    failure instanceof IllegalStateException
                            ? failure
                            : new IllegalStateException("Failed to teardown Flags watchers", failure);
            if (firstFailure == null) {
                return wrapped;
            }
            firstFailure.addSuppressed(wrapped);
            return firstFailure;
        }

        /** Verifies each mutated Flags class still holds this rule's watcher. */
        private void checkAllFlagsWatchersRestored() {
            for (Class<?> flagsClass : mMutatedFlagsClasses) {
                Object watcher = readFlagsImpl(flagsClass);
                checkSameAs(mFlagsClassToWatcherImpl.get(flagsClass), watcher);
            }
        }

        @Override
        public Statement apply(Statement base, Description description) {
            return new Statement() {
                @Override
                public void evaluate() throws Throwable {
                    Throwable throwable = null;
                    final int initialDepth = mSuiteRunDepth;
                    try {
                        mSuiteRunDepth++;
                        // Only the outermost suite installs and removes the watchers.
                        if (initialDepth == 0) {
                            setupFlagsWatchers(description);
                            setupClassLevelFlagValues(description);
                        }
                        base.evaluate();
                    } catch (Throwable t) {
                        throwable = t;
                    } finally {
                        mSuiteRunDepth--;
                        try {
                            if (initialDepth == 0) {
                                mClassLevelSetFlagValues.clear();
                                teardownFlagsWatchers();
                            }
                            if (mSuiteRunDepth != initialDepth) {
                                throw new IllegalStateException(
                                        String.format(
                                                "Evaluations were not correctly nested: initial"
                                                        + " depth was %d but final depth was %d",
                                                initialDepth, mSuiteRunDepth));
                            }
                        } catch (Throwable t) {
                            if (throwable != null) {
                                t.addSuppressed(throwable);
                            }
                            throwable = t;
                        }
                    }
                    if (throwable != null) throw throwable;
                }
            };
        }
    }

    /** Captures the stack trace of the first read of a flag, for use as an exception cause. */
    private static class FirstFlagRead extends Exception {
        FirstFlagRead(String flagName) {
            super(String.format("Flag '%s' was first read at this location:", flagName));
        }
    }

    /**
     * A flag name split into its package and simple name, plus the package of the (possibly
     * repackaged) Flags class that holds it.
     */
    private static class Flag {
        private static final String PACKAGE_NAME_SIMPLE_NAME_SEPARATOR = ".";
        private final String mFullFlagName;
        private final String mFlagPackageName;
        private final String mClassPackageName;
        private final String mSimpleFlagName;

        /**
         * Returns the part of {@code fullFlagName} before its last '.'.
         *
         * @throws IllegalArgumentException if {@code fullFlagName} contains no '.'
         */
        public static String getFlagPackageName(String fullFlagName) {
            int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR);
            if (index < 0) {
                throw invalidFlagName(fullFlagName);
            }
            return fullFlagName.substring(0, index);
        }

        /** Creates a flag whose Flags class lives in the flag's own (non-repackaged) package. */
        public static Flag createFlag(String fullFlagName) {
            return createFlag(fullFlagName, getFlagPackageName(fullFlagName));
        }

        /**
         * Creates a flag whose Flags class lives in {@code classPackageName}.
         *
         * @throws IllegalArgumentException if either name contains no '.'
         */
        public static Flag createFlag(String fullFlagName, String classPackageName) {
            if (!fullFlagName.contains(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR)
                    || !classPackageName.contains(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR)) {
                throw invalidFlagName(fullFlagName);
            }
            int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR);
            String flagPackageName = fullFlagName.substring(0, index);
            String simpleFlagName = fullFlagName.substring(index + 1);
            return new Flag(fullFlagName, flagPackageName, classPackageName, simpleFlagName);
        }

        private static IllegalArgumentException invalidFlagName(String fullFlagName) {
            return new IllegalArgumentException(
                    String.format(
                            "Flag %s is invalid. The format should be {packageName}"
                                    + ".{simpleFlagName}",
                            fullFlagName));
        }

        private Flag(
                String fullFlagName,
                String flagPackageName,
                String classPackageName,
                String simpleFlagName) {
            this.mFullFlagName = fullFlagName;
            this.mFlagPackageName = flagPackageName;
            this.mClassPackageName = classPackageName;
            this.mSimpleFlagName = simpleFlagName;
        }

        public String fullFlagName() {
            return mFullFlagName;
        }

        public String flagPackageName() {
            return mFlagPackageName;
        }

        public String classPackageName() {
            return mClassPackageName;
        }

        public String simpleFlagName() {
            return mSimpleFlagName;
        }

        public String flagsClassName() {
            return String.format("%s.%s", classPackageName(), FLAGS_CLASS_NAME);
        }

        // Flag is stored in HashSets, so equality is by value. The package and simple name are
        // derived from the full name, so comparing these two fields is sufficient.
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Flag)) {
                return false;
            }
            Flag other = (Flag) o;
            return mFullFlagName.equals(other.mFullFlagName)
                    && mClassPackageName.equals(other.mClassPackageName);
        }

        @Override
        public int hashCode() {
            return Objects.hash(mFullFlagName, mClassPackageName);
        }
    }
}