Add a flag for smart test selection.

Also under smart test selection mode, aggresively seperate each test into
one invocation.

Bug: 396787299
Test: atest-dev atest_unittests
Change-Id: I7e6b42ae2e6e259b5d582c084d13db6846c3451f
diff --git a/atest/arg_parser.py b/atest/arg_parser.py
index 3250525..68c56c1 100644
--- a/atest/arg_parser.py
+++ b/atest/arg_parser.py
@@ -256,6 +256,15 @@
       ),
   )
   parser.add_argument(
+      '--smart-test-selection',
+      default=False,
+      action='store_true',
+      help=(
+          'Automatically select test classes based on correlation with code'
+          ' change, and run them.'
+      ),
+  )
+  parser.add_argument(
       '--use-modules-in',
       help=(
           'Force include MODULES-IN-* as build targets. Hint: This may solve'
diff --git a/atest/test_runners/atest_tf_test_runner.py b/atest/test_runners/atest_tf_test_runner.py
index faf565f..2db8d92 100644
--- a/atest/test_runners/atest_tf_test_runner.py
+++ b/atest/test_runners/atest_tf_test_runner.py
@@ -94,6 +94,9 @@
 # The environment variable for TF preparer incremental setup.
 _INCREMENTAL_SETUP_KEY = 'TF_PREPARER_INCREMENTAL_SETUP'
 
+# Smart test selection.
+_SMART_TEST_SELECTION = 'smart_test_selection'
+
 
 class Error(Exception):
   """Module-level error."""
@@ -196,6 +199,7 @@
     metrics.LocalDetectEvent(
         detect_type=DetectType.IS_MINIMAL_BUILD, result=int(self._minimal_build)
     )
+    self._smart_test_selection = extra_args.get(_SMART_TEST_SELECTION, False)
 
   def requires_device_update(
       self, test_infos: List[test_info.TestInfo]
@@ -225,42 +229,47 @@
         A list of TestRunnerInvocation instances.
     """
     invocations = []
-    device_test_infos, deviceless_test_infos = self._partition_tests(test_infos)
-    if deviceless_test_infos:
+    device_test_info_lists, deviceless_test_info_lists = self._partition_tests(
+        test_infos
+    )
+    if deviceless_test_info_lists:
       extra_args_for_deviceless_test = extra_args.copy()
       extra_args_for_deviceless_test.update({constants.HOST: True})
-      invocations.append(
-          TestRunnerInvocation(
-              test_runner=self,
-              extra_args=extra_args_for_deviceless_test,
-              test_infos=deviceless_test_infos,
-          )
-      )
-    if device_test_infos:
+      for temp_test_infos in deviceless_test_info_lists:
+        invocations.append(
+            TestRunnerInvocation(
+                test_runner=self,
+                extra_args=extra_args_for_deviceless_test,
+                test_infos=temp_test_infos,
+            )
+        )
+    if device_test_info_lists:
       extra_args_for_device_test = extra_args.copy()
       if rollout_control.tf_preparer_incremental_setup.is_enabled():
         extra_args_for_device_test.update({_INCREMENTAL_SETUP_KEY: True})
-      invocations.append(
-          TestRunnerInvocation(
-              test_runner=self,
-              extra_args=extra_args_for_device_test,
-              test_infos=device_test_infos,
-          )
-      )
+      for temp_test_infos in device_test_info_lists:
+        invocations.append(
+            TestRunnerInvocation(
+                test_runner=self,
+                extra_args=extra_args_for_device_test,
+                test_infos=temp_test_infos,
+            )
+        )
 
     return invocations
 
   def _partition_tests(
       self,
       test_infos: List[test_info.TestInfo],
-  ) -> (List[test_info.TestInfo], List[test_info.TestInfo]):
+  ) -> (List[List[test_info.TestInfo]], List[List[test_info.TestInfo]]):
     """Partition input tests into two lists based on whether it requires device.
 
     Args:
         test_infos: A list of TestInfos.
 
     Returns:
-        Two lists one contains device tests the other contains deviceless tests.
+        Two lists one contains device test info lists the other contains
+        deviceless test info lists.
     """
     device_test_infos = []
     deviceless_test_infos = []
@@ -272,7 +281,15 @@
       else:
         deviceless_test_infos.append(info)
 
-    return device_test_infos, deviceless_test_infos
+    return [
+        [info] for info in device_test_infos
+    ] if self._smart_test_selection or not device_test_infos else [
+        device_test_infos
+    ], [
+        [info] for info in deviceless_test_infos
+    ] if self._smart_test_selection or not deviceless_test_infos else [
+        deviceless_test_infos
+    ]
 
   def _try_set_gts_authentication_key(self):
     """Set GTS authentication key if it is available or exists.
diff --git a/atest/test_runners/atest_tf_test_runner_unittest.py b/atest/test_runners/atest_tf_test_runner_unittest.py
index 3437888..95191ee 100755
--- a/atest/test_runners/atest_tf_test_runner_unittest.py
+++ b/atest/test_runners/atest_tf_test_runner_unittest.py
@@ -1290,6 +1290,57 @@
     )
     self.assertFalse(constants.HOST in invocations[0]._extra_args)
 
+  def test_create_invocations_with_smart_test_selection_returns_multiple_invocations(
+      self,
+  ):
+    tr = atf_tr.AtestTradefedTestRunner(
+        results_dir=uc.TEST_INFO_DIR,
+        extra_args={
+            constants.HOST: False,
+            atf_tr._SMART_TEST_SELECTION: True,
+        },
+    )
+    tr.module_info = module_info.ModuleInfo(
+        name_to_module_info={
+            'device_test_1': (
+                module_info_unittest_base.device_driven_test_module(
+                    name='device_test_1'
+                )
+            ),
+            'device_test_2': (
+                module_info_unittest_base.host_driven_device_test_module(
+                    name='device_test_2'
+                )
+            ),
+            'deviceless_test_1': (
+                module_info_unittest_base.robolectric_test_module(
+                    name='deviceless_test_1'
+                )
+            ),
+            'deviceless_test_2': (
+                module_info_unittest_base.robolectric_test_module(
+                    name='deviceless_test_2'
+                )
+            ),
+        }
+    )
+    test_info_device_1 = test_info_of('device_test_1')
+    test_info_device_2 = test_info_of('device_test_2')
+    test_info_deviceless_1 = test_info_of('deviceless_test_1')
+    test_info_deviceless_2 = test_info_of('deviceless_test_2')
+
+    invocations = tr.create_invocations(
+        {},
+        [
+            test_info_device_1,
+            test_info_device_2,
+            test_info_deviceless_1,
+            test_info_deviceless_2,
+        ],
+    )
+
+    self.assertEqual(len(invocations), 4)
+
   def test_create_invocations_returns_invocation_only_for_deviceless_tests(
       self,
   ):