Add `classifier_threshold` flag to control the threshold for classifier values returned by the model.
Test: atest PhFlagsTest TopicsManagerTest
Fixes: 242628225
Change-Id: I41f369a611557c7f65451cd09c0db7bc42309ab3
diff --git a/adservices/service-core/java/com/android/adservices/service/Flags.java b/adservices/service-core/java/com/android/adservices/service/Flags.java
index 0d1a569..99025fb 100644
--- a/adservices/service-core/java/com/android/adservices/service/Flags.java
+++ b/adservices/service-core/java/com/android/adservices/service/Flags.java
@@ -123,6 +123,14 @@
return CLASSIFIER_NUMBER_OF_TOP_LABELS;
}
+ /** Threshold value for classification values. */
+ float CLASSIFIER_THRESHOLD = 0.0f;
+
+ /** Returns the threshold value for classification values. */
+ default float getClassifierThreshold() {
+ return CLASSIFIER_THRESHOLD;
+ }
+
/* The default period for the Maintenance job. */
long MAINTENANCE_JOB_PERIOD_MS = 86_400_000; // 1 day.
diff --git a/adservices/service-core/java/com/android/adservices/service/PhFlags.java b/adservices/service-core/java/com/android/adservices/service/PhFlags.java
index 48d38f0..a81cf3a 100644
--- a/adservices/service-core/java/com/android/adservices/service/PhFlags.java
+++ b/adservices/service-core/java/com/android/adservices/service/PhFlags.java
@@ -52,6 +52,7 @@
// Topics classifier keys
static final String KEY_CLASSIFIER_TYPE = "classifier_type";
static final String KEY_CLASSIFIER_NUMBER_OF_TOP_LABELS = "classifier_number_of_top_labels";
+ static final String KEY_CLASSIFIER_THRESHOLD = "classifier_threshold";
// Measurement keys
static final String KEY_MEASUREMENT_EVENT_MAIN_REPORTING_JOB_PERIOD_MS =
@@ -320,6 +321,15 @@
}
@Override
+ public float getClassifierThreshold() {
+ // The priority of applying the flag values: PH (DeviceConfig) and then hard-coded value.
+ return DeviceConfig.getFloat(
+ DeviceConfig.NAMESPACE_ADSERVICES,
+ /* flagName */ KEY_CLASSIFIER_THRESHOLD,
+ /* defaultValue */ CLASSIFIER_THRESHOLD);
+ }
+
+ @Override
public long getMaintenanceJobPeriodMs() {
// The priority of applying the flag values: SystemProperties, PH (DeviceConfig) and then
// hard-coded value.
diff --git a/adservices/service-core/java/com/android/adservices/service/topics/classifier/OnDeviceClassifier.java b/adservices/service-core/java/com/android/adservices/service/topics/classifier/OnDeviceClassifier.java
index 6fd0e3f..d69f384 100644
--- a/adservices/service-core/java/com/android/adservices/service/topics/classifier/OnDeviceClassifier.java
+++ b/adservices/service-core/java/com/android/adservices/service/topics/classifier/OnDeviceClassifier.java
@@ -155,14 +155,24 @@
// Limit the number of entries to first MAX_LABELS_PER_APP.
// TODO(b/235435229): Evaluate the strategy to use first x elements.
int numberOfTopLabels = FlagsFactory.getFlags().getClassifierNumberOfTopLabels();
+ float classifierThresholdValue = FlagsFactory.getFlags().getClassifierThreshold();
+ LogUtil.i(
+ "numberOfTopLabels = %s\n classifierThresholdValue = %s",
+ numberOfTopLabels, classifierThresholdValue);
return classifications.stream()
.sorted((c1, c2) -> Float.compare(c2.getScore(), c1.getScore())) // Reverse sorted.
+ .filter(category -> isAboveThreshold(category, classifierThresholdValue))
.map(OnDeviceClassifier::convertCategoryLabelToTopicId)
.map(this::createTopic)
.limit(numberOfTopLabels)
.collect(Collectors.toList());
}
+ // Filter category above the required threshold.
+ private static boolean isAboveThreshold(Category category, float classifierThresholdValue) {
+ return category.getScore() >= classifierThresholdValue;
+ }
+
// Converts Category Label to TopicId. Expects label to be labelId of the classified category.
// Returns -1 if conversion to int fails for the label.
private static int convertCategoryLabelToTopicId(Category category) {
diff --git a/adservices/tests/unittest/service-core/src/com/android/adservices/service/PhFlagsTest.java b/adservices/tests/unittest/service-core/src/com/android/adservices/service/PhFlagsTest.java
index a04e05c..02b0946 100644
--- a/adservices/tests/unittest/service-core/src/com/android/adservices/service/PhFlagsTest.java
+++ b/adservices/tests/unittest/service-core/src/com/android/adservices/service/PhFlagsTest.java
@@ -19,6 +19,7 @@
import static com.android.adservices.service.Flags.ADID_KILL_SWITCH;
import static com.android.adservices.service.Flags.APPSETID_KILL_SWITCH;
import static com.android.adservices.service.Flags.CLASSIFIER_NUMBER_OF_TOP_LABELS;
+import static com.android.adservices.service.Flags.CLASSIFIER_THRESHOLD;
import static com.android.adservices.service.Flags.DEFAULT_CLASSIFIER_TYPE;
import static com.android.adservices.service.Flags.DISABLE_FLEDGE_ENROLLMENT_CHECK;
import static com.android.adservices.service.Flags.DISABLE_TOPICS_ENROLLMENT_CHECK;
@@ -107,6 +108,7 @@
import static com.android.adservices.service.PhFlags.KEY_ADID_KILL_SWITCH;
import static com.android.adservices.service.PhFlags.KEY_APPSETID_KILL_SWITCH;
import static com.android.adservices.service.PhFlags.KEY_CLASSIFIER_NUMBER_OF_TOP_LABELS;
+import static com.android.adservices.service.PhFlags.KEY_CLASSIFIER_THRESHOLD;
import static com.android.adservices.service.PhFlags.KEY_CLASSIFIER_TYPE;
import static com.android.adservices.service.PhFlags.KEY_DISABLE_FLEDGE_ENROLLMENT_CHECK;
import static com.android.adservices.service.PhFlags.KEY_DISABLE_TOPICS_ENROLLMENT_CHECK;
@@ -350,6 +352,23 @@
}
@Test
+ public void testGetClassifierThreshold() {
+ // Without any overriding, the value is the hard coded constant.
+ assertThat(FlagsFactory.getFlags().getClassifierThreshold())
+ .isEqualTo(CLASSIFIER_THRESHOLD);
+
+ float phOverridingValue = 0.3f;
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_ADSERVICES,
+ KEY_CLASSIFIER_THRESHOLD,
+ Float.toString(phOverridingValue),
+ /* makeDefault */ false);
+
+ Flags phFlags = FlagsFactory.getFlags();
+ assertThat(phFlags.getClassifierThreshold()).isEqualTo(phOverridingValue);
+ }
+
+ @Test
public void testGetMaintenanceJobPeriodMs() {
// Without any overriding, the value is the hard coded constant.
assertThat(FlagsFactory.getFlags().getMaintenanceJobPeriodMs())
diff --git a/adservices/tests/unittest/service-core/src/com/android/adservices/service/topics/classifier/OnDeviceClassifierTest.java b/adservices/tests/unittest/service-core/src/com/android/adservices/service/topics/classifier/OnDeviceClassifierTest.java
index ec78896..9e8d17b 100644
--- a/adservices/tests/unittest/service-core/src/com/android/adservices/service/topics/classifier/OnDeviceClassifierTest.java
+++ b/adservices/tests/unittest/service-core/src/com/android/adservices/service/topics/classifier/OnDeviceClassifierTest.java
@@ -187,6 +187,29 @@
}
@Test
+ public void testClassify_successfulClassifications_overrideClassifierThreshold() {
+ // Check getClassification for sample descriptions.
+ String appPackage1 = "com.example.adservices.samples.topics.sampleapp1";
+ ImmutableMap<String, AppInfo> appInfoMap =
+ ImmutableMap.<String, AppInfo>builder()
+ .put(appPackage1, new AppInfo("appName1", "Sample app description."))
+ .build();
+ ImmutableSet<String> appPackages = ImmutableSet.of(appPackage1);
+ when(mPackageManagerUtil.getAppInformation(eq(appPackages))).thenReturn(appInfoMap);
+ // Override classifierThreshold.
+ float overrideThreshold = 0.1f;
+ setClassifierThreshold(overrideThreshold);
+
+ ImmutableMap<String, List<Topic>> classifications =
+ mOnDeviceClassifier.classify(appPackages);
+
+ verify(mPackageManagerUtil).getAppInformation(eq(appPackages));
+ assertThat(classifications).hasSize(1);
+ // Expecting 2 values greater than 0.1 threshold.
+ assertThat(classifications.get(appPackage1)).hasSize(2);
+ }
+
+ @Test
public void testClassify_successfulClassificationsForUpdatedAppDescription() {
// Check getClassification for sample descriptions.
String appPackage1 = "com.example.adservices.samples.topics.sampleapp1";
@@ -347,4 +370,12 @@
Integer.toString(overrideValue),
/* makeDefault */ false);
}
+
+ private void setClassifierThreshold(float overrideValue) {
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_ADSERVICES,
+ "classifier_threshold",
+ Float.toString(overrideValue),
+ /* makeDefault */ false);
+ }
}