When sharding tests, use the test order provided by the TestLoader.

Test sharding works by sorting the list of all tests to execute, and having
shard i take every ith test. (This has the appealing property that each shard
can determine which tests to run independently).

However, if the TestLoader returns the test cases in a different order, we
want to *ignore* that order for selecting tests to execute, but *respect* that
order for the actual execution of the tests, eg in the case of randomized test
ordering.

This change tweaks the sharding implementation to respect the test case order
returned by the TestLoader, and adds a new test.

I also cleaned up the use of `assertEquals` in favor of `assertEqual` while I
was here.

PiperOrigin-RevId: 300132733
Change-Id: Iab86aeb42163c40c24a60af2e779499d0696497d
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
index 67c2f9a..cd80dc2 100644
--- a/absl/testing/absltest.py
+++ b/absl/testing/absltest.py
@@ -2242,11 +2242,15 @@
 
   def getShardedTestCaseNames(testCaseClass):
     filtered_names = []
-    for testcase in sorted(delegate_get_names(testCaseClass)):
+    # We need to sort the list of tests in order to determine which tests this
+    # shard is responsible for; however, it's important to preserve the order
+    # returned by the base loader, e.g. in the case of randomized test ordering.
+    ordered_names = delegate_get_names(testCaseClass)
+    for testcase in sorted(ordered_names):
       bucket = next(bucket_iterator)
       if bucket == shard_index:
         filtered_names.append(testcase)
-    return filtered_names
+    return [x for x in ordered_names if x in filtered_names]
 
   base_loader.getTestCaseNames = getShardedTestCaseNames
   return base_loader
diff --git a/absl/testing/tests/absltest_sharding_test.py b/absl/testing/tests/absltest_sharding_test.py
index f32dc11..0b10f0f 100755
--- a/absl/testing/tests/absltest_sharding_test.py
+++ b/absl/testing/tests/absltest_sharding_test.py
@@ -43,7 +43,7 @@
     if self._shard_file is not None and os.path.exists(self._shard_file):
       os.unlink(self._shard_file)
 
-  def _run_sharded(self, total_shards, shard_index, shard_file=None):
+  def _run_sharded(self, total_shards, shard_index, shard_file=None, env=None):
     """Runs the py_test binary in a subprocess.
 
     Args:
@@ -51,12 +51,17 @@
       shard_index: int, the shard index.
       shard_file: string, if not 'None', the path to the shard file.
         This method asserts it is properly created.
+      env: Environment variables to be set for the py_test binary.
 
     Returns:
       (stdout, exit_code) tuple of (string, int).
     """
-    env = {'TEST_TOTAL_SHARDS': str(total_shards),
-           'TEST_SHARD_INDEX': str(shard_index)}
+    if env is None:
+      env = {}
+    env.update({
+        'TEST_TOTAL_SHARDS': str(total_shards),
+        'TEST_SHARD_INDEX': str(shard_index)
+    })
     if 'SYSTEMROOT' in os.environ:
       # This is used by the random module on Windows to locate crypto
       # libraries.
@@ -102,14 +107,14 @@
       combined_outerr.extend(method_list)
       exit_code_by_shard.append(exit_code)
 
-    self.assertEquals(1, len([x for x in exit_code_by_shard if x != 0]),
-                      'Expected exactly one failure')
+    self.assertLen([x for x in exit_code_by_shard if x != 0], 1,
+                   'Expected exactly one failure')
 
     # Test completeness and partition properties.
-    self.assertEquals(NUM_TEST_METHODS, len(combined_outerr),
-                      'Partition requirement not met')
-    self.assertEquals(NUM_TEST_METHODS, len(set(combined_outerr)),
-                      'Completeness requirement not met')
+    self.assertLen(combined_outerr, NUM_TEST_METHODS,
+                   'Partition requirement not met')
+    self.assertLen(set(combined_outerr), NUM_TEST_METHODS,
+                   'Completeness requirement not met')
 
     # Test balance:
     for i in range(len(outerr_by_shard)):
@@ -123,7 +128,7 @@
 
   def test_zero_shards(self):
     out, exit_code = self._run_sharded(0, 0)
-    self.assertEquals(1, exit_code)
+    self.assertEqual(1, exit_code)
     self.assertGreaterEqual(out.find('Bad sharding values. index=0, total=0'),
                             0, 'Bad output: %s' % (out))
 
@@ -136,6 +141,20 @@
   def test_with_ten_shards(self):
     self._assert_sharding_correctness(10)
 
+  def test_sharding_with_randomization(self):
+    # If we're both sharding *and* randomizing, we need to confirm that we
+    # randomize within the shard; we use two seeds to confirm we're seeing the
+    # same tests (sharding is consistent) in a different order.
+    tests_seen = []
+    for seed in ('7', '17'):
+      out, exit_code = self._run_sharded(
+          2, 0, env={'TEST_RANDOMIZE_ORDERING_SEED': seed})
+      self.assertEqual(0, exit_code)
+      tests_seen.append([x for x in out.splitlines() if x.startswith('class')])
+    first_tests, second_tests = tests_seen  # pylint: disable=unbalanced-tuple-unpacking
+    self.assertEqual(set(first_tests), set(second_tests))
+    self.assertNotEqual(first_tests, second_tests)
+
 
 if __name__ == '__main__':
   absltest.main()