Add `classifier_threshold` flag to control the threshold for classifier values returned by the model.

Test: atest PhFlagsTest TopicsManagerTest
Fixes: 242628225
Change-Id: I41f369a611557c7f65451cd09c0db7bc42309ab3
diff --git a/adservices/service-core/java/com/android/adservices/service/Flags.java b/adservices/service-core/java/com/android/adservices/service/Flags.java
index 0d1a569..99025fb 100644
--- a/adservices/service-core/java/com/android/adservices/service/Flags.java
+++ b/adservices/service-core/java/com/android/adservices/service/Flags.java
@@ -123,6 +123,14 @@
         return CLASSIFIER_NUMBER_OF_TOP_LABELS;
     }
 
+    /** Threshold value for classification values. */
+    float CLASSIFIER_THRESHOLD = 0.0f;
+
+    /** Returns the threshold value for classification values. */
+    default float getClassifierThreshold() {
+        return CLASSIFIER_THRESHOLD;
+    }
+
     /* The default period for the Maintenance job. */
     long MAINTENANCE_JOB_PERIOD_MS = 86_400_000; // 1 day.
 
diff --git a/adservices/service-core/java/com/android/adservices/service/PhFlags.java b/adservices/service-core/java/com/android/adservices/service/PhFlags.java
index 48d38f0..a81cf3a 100644
--- a/adservices/service-core/java/com/android/adservices/service/PhFlags.java
+++ b/adservices/service-core/java/com/android/adservices/service/PhFlags.java
@@ -52,6 +52,7 @@
     // Topics classifier keys
     static final String KEY_CLASSIFIER_TYPE = "classifier_type";
     static final String KEY_CLASSIFIER_NUMBER_OF_TOP_LABELS = "classifier_number_of_top_labels";
+    static final String KEY_CLASSIFIER_THRESHOLD = "classifier_threshold";
 
     // Measurement keys
     static final String KEY_MEASUREMENT_EVENT_MAIN_REPORTING_JOB_PERIOD_MS =
@@ -320,6 +321,15 @@
     }
 
     @Override
+    public float getClassifierThreshold() {
+        // The priority of applying the flag values: PH (DeviceConfig) and then hard-coded value.
+        return DeviceConfig.getFloat(
+                DeviceConfig.NAMESPACE_ADSERVICES,
+                /* flagName */ KEY_CLASSIFIER_THRESHOLD,
+                /* defaultValue */ CLASSIFIER_THRESHOLD);
+    }
+
+    @Override
     public long getMaintenanceJobPeriodMs() {
         // The priority of applying the flag values: SystemProperties, PH (DeviceConfig) and then
         // hard-coded value.
diff --git a/adservices/service-core/java/com/android/adservices/service/topics/classifier/OnDeviceClassifier.java b/adservices/service-core/java/com/android/adservices/service/topics/classifier/OnDeviceClassifier.java
index 6fd0e3f..d69f384 100644
--- a/adservices/service-core/java/com/android/adservices/service/topics/classifier/OnDeviceClassifier.java
+++ b/adservices/service-core/java/com/android/adservices/service/topics/classifier/OnDeviceClassifier.java
@@ -155,14 +155,24 @@
         // Limit the number of entries to first MAX_LABELS_PER_APP.
         // TODO(b/235435229): Evaluate the strategy to use first x elements.
         int numberOfTopLabels = FlagsFactory.getFlags().getClassifierNumberOfTopLabels();
+        float classifierThresholdValue = FlagsFactory.getFlags().getClassifierThreshold();
+        LogUtil.i(
+                "numberOfTopLabels = %s\n classifierThresholdValue = %s",
+                numberOfTopLabels, classifierThresholdValue);
         return classifications.stream()
                 .sorted((c1, c2) -> Float.compare(c2.getScore(), c1.getScore())) // Reverse sorted.
+                .filter(category -> isAboveThreshold(category, classifierThresholdValue))
                 .map(OnDeviceClassifier::convertCategoryLabelToTopicId)
                 .map(this::createTopic)
                 .limit(numberOfTopLabels)
                 .collect(Collectors.toList());
     }
 
+    // Filter category above the required threshold.
+    private static boolean isAboveThreshold(Category category, float classifierThresholdValue) {
+        return category.getScore() >= classifierThresholdValue;
+    }
+
     // Converts Category Label to TopicId. Expects label to be labelId of the classified category.
     // Returns -1 if conversion to int fails for the label.
     private static int convertCategoryLabelToTopicId(Category category) {
diff --git a/adservices/tests/unittest/service-core/src/com/android/adservices/service/PhFlagsTest.java b/adservices/tests/unittest/service-core/src/com/android/adservices/service/PhFlagsTest.java
index a04e05c..02b0946 100644
--- a/adservices/tests/unittest/service-core/src/com/android/adservices/service/PhFlagsTest.java
+++ b/adservices/tests/unittest/service-core/src/com/android/adservices/service/PhFlagsTest.java
@@ -19,6 +19,7 @@
 import static com.android.adservices.service.Flags.ADID_KILL_SWITCH;
 import static com.android.adservices.service.Flags.APPSETID_KILL_SWITCH;
 import static com.android.adservices.service.Flags.CLASSIFIER_NUMBER_OF_TOP_LABELS;
+import static com.android.adservices.service.Flags.CLASSIFIER_THRESHOLD;
 import static com.android.adservices.service.Flags.DEFAULT_CLASSIFIER_TYPE;
 import static com.android.adservices.service.Flags.DISABLE_FLEDGE_ENROLLMENT_CHECK;
 import static com.android.adservices.service.Flags.DISABLE_TOPICS_ENROLLMENT_CHECK;
@@ -107,6 +108,7 @@
 import static com.android.adservices.service.PhFlags.KEY_ADID_KILL_SWITCH;
 import static com.android.adservices.service.PhFlags.KEY_APPSETID_KILL_SWITCH;
 import static com.android.adservices.service.PhFlags.KEY_CLASSIFIER_NUMBER_OF_TOP_LABELS;
+import static com.android.adservices.service.PhFlags.KEY_CLASSIFIER_THRESHOLD;
 import static com.android.adservices.service.PhFlags.KEY_CLASSIFIER_TYPE;
 import static com.android.adservices.service.PhFlags.KEY_DISABLE_FLEDGE_ENROLLMENT_CHECK;
 import static com.android.adservices.service.PhFlags.KEY_DISABLE_TOPICS_ENROLLMENT_CHECK;
@@ -350,6 +352,23 @@
     }
 
     @Test
+    public void testGetClassifierThreshold() {
+        // Without any overriding, the value is the hard coded constant.
+        assertThat(FlagsFactory.getFlags().getClassifierThreshold())
+                .isEqualTo(CLASSIFIER_THRESHOLD);
+
+        float phOverridingValue = 0.3f;
+        DeviceConfig.setProperty(
+                DeviceConfig.NAMESPACE_ADSERVICES,
+                KEY_CLASSIFIER_THRESHOLD,
+                Float.toString(phOverridingValue),
+                /* makeDefault */ false);
+
+        Flags phFlags = FlagsFactory.getFlags();
+        assertThat(phFlags.getClassifierThreshold()).isEqualTo(phOverridingValue);
+    }
+
+    @Test
     public void testGetMaintenanceJobPeriodMs() {
         // Without any overriding, the value is the hard coded constant.
         assertThat(FlagsFactory.getFlags().getMaintenanceJobPeriodMs())
diff --git a/adservices/tests/unittest/service-core/src/com/android/adservices/service/topics/classifier/OnDeviceClassifierTest.java b/adservices/tests/unittest/service-core/src/com/android/adservices/service/topics/classifier/OnDeviceClassifierTest.java
index ec78896..9e8d17b 100644
--- a/adservices/tests/unittest/service-core/src/com/android/adservices/service/topics/classifier/OnDeviceClassifierTest.java
+++ b/adservices/tests/unittest/service-core/src/com/android/adservices/service/topics/classifier/OnDeviceClassifierTest.java
@@ -187,6 +187,29 @@
     }
 
     @Test
+    public void testClassify_successfulClassifications_overrideClassifierThreshold() {
+        // Check getClassification for sample descriptions.
+        String appPackage1 = "com.example.adservices.samples.topics.sampleapp1";
+        ImmutableMap<String, AppInfo> appInfoMap =
+                ImmutableMap.<String, AppInfo>builder()
+                        .put(appPackage1, new AppInfo("appName1", "Sample app description."))
+                        .build();
+        ImmutableSet<String> appPackages = ImmutableSet.of(appPackage1);
+        when(mPackageManagerUtil.getAppInformation(eq(appPackages))).thenReturn(appInfoMap);
+        // Override classifierThreshold.
+        float overrideThreshold = 0.1f;
+        setClassifierThreshold(overrideThreshold);
+
+        ImmutableMap<String, List<Topic>> classifications =
+                mOnDeviceClassifier.classify(appPackages);
+
+        verify(mPackageManagerUtil).getAppInformation(eq(appPackages));
+        assertThat(classifications).hasSize(1);
+        // Expecting 2 values greater than 0.1 threshold.
+        assertThat(classifications.get(appPackage1)).hasSize(2);
+    }
+
+    @Test
     public void testClassify_successfulClassificationsForUpdatedAppDescription() {
         // Check getClassification for sample descriptions.
         String appPackage1 = "com.example.adservices.samples.topics.sampleapp1";
@@ -347,4 +370,12 @@
                 Integer.toString(overrideValue),
                 /* makeDefault */ false);
     }
+
+    private void setClassifierThreshold(float overrideValue) {
+        DeviceConfig.setProperty(
+                DeviceConfig.NAMESPACE_ADSERVICES,
+                "classifier_threshold",
+                Float.toString(overrideValue),
+                /* makeDefault */ false);
+    }
 }