move CI related functions out of run_test.py (#61124)

Summary:
run_test.py currently does lots of downloading and test file/suite/case parsing. It doesn't work well outside of the CI environment

Restructured the run_test.py and created tools/test/test_selections.py and move all test selection logic (reordering, categorizing slow test, creating shards)

Follow up PRs should:
- refactor those file read/write logic entangled inside test_selections.py into stats/ folder
- restructure and add network independent test logics to test_test_selections.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61124

Test Plan:
- tools/test
- CI

Related PR:
This follows the refactoring example in: https://github.com/pytorch/pytorch/issues/60373

Reviewed By: malfet

Differential Revision: D29558981

Pulled By: walterddr

fbshipit-source-id: 7f0fd9b4720a918d82918766c002295e8df04169
diff --git a/test/run_test.py b/test/run_test.py
index a54b029..ae8e1cf 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -2,9 +2,7 @@
 
 import argparse
 import copy
-import csv
 from datetime import datetime
-import json
 import modulefinder
 import os
 import shutil
@@ -17,20 +15,22 @@
 from torch.utils import cpp_extension
 from torch.testing._internal.common_utils import FILE_SCHEMA, IS_IN_CI, TEST_WITH_ROCM, shell, set_cwd
 import torch.distributed as dist
-from typing import Dict, Optional, Tuple, List, Any
-from typing_extensions import TypedDict
+from typing import Dict, Optional, List
 
 try:
+    # using tools/ to optimize test run.
     sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
-    from tools.stats.s3_stat_parser import (
-        get_previous_reports_for_branch,
-        get_previous_reports_for_pr,
-        Report,
-        HAVE_BOTO3)
-    from tools.testing.test_selections import calculate_shards
+    from tools.testing.test_selections import (
+        export_S3_test_times,
+        get_shard_based_on_S3,
+        get_slow_tests_based_on_S3,
+        get_specified_test_cases,
+        get_reordered_tests
+    )
+    HAVE_TEST_SELECTION_TOOLS = True
 except ImportError:
-    print("Unable to import s3_stat_parser from tools. Running without S3 stats...")
-    HAVE_BOTO3 = False
+    HAVE_TEST_SELECTION_TOOLS = False
+    print("Unable to import test_selections from tools/testing. Running without test selection stats...")
 
 
 TESTS = [
@@ -408,130 +408,6 @@
     print(message, file=sys.stderr)
 
 
-# Convert something like pytorch_windows_vs2019_py36_cuda10.1_build to pytorch_windows_vs2019_py36_cuda10.1
-def get_stripped_CI_job() -> str:
-    job = os.environ.get("JOB_BASE_NAME", "").rstrip('0123456789')
-    if job.endswith('_slow_test'):
-        job = job[:len(job) - len('_slow_test')]
-    elif job.endswith('_test') or job.endswith('-test'):
-        job = job[:len(job) - len('_test')]
-    elif job.endswith('_build') or job.endswith('-build'):
-        job = job[:len(job) - len('_build')]
-    return job
-
-
-def calculate_job_times(reports: List["Report"]) -> Dict[str, float]:
-    # an entry will be like ("test_file_name" -> (current_avg, # values))
-    jobs_to_times: Dict[str, Tuple[float, int]] = dict()
-    for report in reports:
-        assert report.get('format_version') == 2, "S3 format currently handled is version 2 only"
-        files: Dict[str, Any] = report['files']
-        for name, test_file in files.items():
-            if name not in jobs_to_times:
-                jobs_to_times[name] = (test_file['total_seconds'], 1)
-            else:
-                curr_avg, curr_count = jobs_to_times[name]
-                new_count = curr_count + 1
-                new_avg = (curr_avg * curr_count + test_file['total_seconds']) / new_count
-                jobs_to_times[name] = (new_avg, new_count)
-
-    # This is no longer needed after https://github.com/pytorch/pytorch/pull/60604,
-    # TODO remove this once viable/strict move pass the merged commit.
-    # if there's 'test_cpp_extensions_aot' entry in jobs_to_times, add 'test_cpp_extensions_aot_ninja'
-    # and 'test_cpp_extensions_aot_no_ninja' duplicate entries to ease future computation since
-    # test_cpp_extensions_aot_no_ninja and test_cpp_extensions_aot_ninja are Python test jobs that
-    # both use the test_cpp_extensions_aot.py file.
-    if 'test_cpp_extensions_aot' in jobs_to_times:
-        jobs_to_times['test_cpp_extensions_aot_ninja'] = jobs_to_times['test_cpp_extensions_aot']
-        jobs_to_times['test_cpp_extensions_aot_no_ninja'] = jobs_to_times['test_cpp_extensions_aot']
-    return {job: time for job, (time, _) in jobs_to_times.items()}
-
-
-def pull_job_times_from_S3() -> Dict[str, float]:
-    if HAVE_BOTO3:
-        ci_job_prefix = get_stripped_CI_job()
-        s3_reports: List["Report"] = get_previous_reports_for_branch('origin/viable/strict', ci_job_prefix)
-    else:
-        print('Uh oh, boto3 is not found. Either it is not installed or we failed to import s3_stat_parser.')
-        print('If not installed, please install boto3 for automatic sharding and test categorization.')
-        s3_reports = []
-
-    if len(s3_reports) == 0:
-        print('Gathered no reports from S3. Please proceed without them.')
-        return dict()
-
-    return calculate_job_times(s3_reports)
-
-
-def get_past_job_times() -> Dict[str, float]:
-    if os.path.exists(TEST_TIMES_FILE):
-        with open(TEST_TIMES_FILE) as file:
-            test_times_json: JobTimeJSON = json.load(file)
-
-        curr_commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'], encoding="ascii").strip()
-        file_commit = test_times_json.get('commit', '')
-        curr_ci_job = get_stripped_CI_job()
-        file_ci_job = test_times_json.get('JOB_BASE_NAME', 'N/A')
-        if curr_commit != file_commit:
-            print(f'Current test times file is from different commit {file_commit}.')
-        elif curr_ci_job != file_ci_job:
-            print(f'Current test times file is for different CI job {file_ci_job}.')
-        else:
-            print(f'Found stats for current commit: {curr_commit} and job: {curr_ci_job}. Proceeding with those values.')
-            return test_times_json.get('job_times', {})
-
-        # Found file, but commit or CI job in JSON doesn't match
-        print(f'Overwriting current file with stats based on current commit: {curr_commit} and CI job: {curr_ci_job}')
-
-    job_times = pull_job_times_from_S3()
-    print(f'Exporting S3 test stats to {TEST_TIMES_FILE}.')
-    export_S3_test_times(TEST_TIMES_FILE, job_times)
-
-    return job_times
-
-
-class JobTimeJSON(TypedDict):
-    commit: str
-    job_times: Dict[str, float]
-
-
-def get_job_times_json(job_times: Dict[str, float]) -> JobTimeJSON:
-    return {
-        'commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'], encoding="ascii").strip(),
-        'JOB_BASE_NAME': get_stripped_CI_job(),
-        'job_times': job_times,
-    }
-
-
-def get_shard(which_shard: int, num_shards: int, tests: List[str]) -> List[str]:
-    jobs_to_times = get_past_job_times()
-
-    # Got no stats from S3, returning early to save runtime
-    if len(jobs_to_times) == 0:
-        print('Gathered no stats from S3. Proceeding with default sharding plan.')
-        return tests[which_shard - 1 :: num_shards]
-
-    shards = calculate_shards(num_shards, tests, jobs_to_times)
-    _, tests_from_shard = shards[which_shard - 1]
-    return tests_from_shard
-
-
-def get_slow_tests_based_on_S3() -> List[str]:
-    jobs_to_times: Dict[str, float] = get_past_job_times()
-
-    # Got no stats from S3, returning early to save runtime
-    if len(jobs_to_times) == 0:
-        print('Gathered no stats from S3. No new slow tests calculated.')
-        return []
-
-    slow_tests: List[str] = []
-    for test in TESTS:
-        if test in jobs_to_times and test not in TARGET_DET_LIST:
-            if jobs_to_times[test] > SLOW_TEST_THRESHOLD:
-                slow_tests.append(test)
-    return slow_tests
-
-
 def get_test_case_args(test_module, using_pytest) -> List[str]:
     args = []
     # if test_module not specified or specified with '__all__' then run all tests
@@ -962,7 +838,9 @@
         which_shard, num_shards = options.shard
         assert which_shard <= num_shards, "Selected shard must be less than or equal to total number of shards"
         assert num_shards <= len(selected_tests), f"Number of shards must be less than {len(selected_tests)}"
-        selected_tests = get_shard(which_shard, num_shards, selected_tests)
+        # TODO: fix this to use test_times_filename, but currently this is not working
+        # because setting the export arg immeidately halts the test execution.
+        selected_tests = get_shard_based_on_S3(which_shard, num_shards, selected_tests, TEST_TIMES_FILE)
 
     return selected_tests
 
@@ -1118,127 +996,22 @@
         message += f' Received signal: {signal_name}'
     return message
 
-def export_S3_test_times(test_times_filename: str, test_times: Dict[str, float]) -> None:
-    if os.path.exists(test_times_filename):
-        print(f'Overwriting existent file: {test_times_filename}')
-    with open(test_times_filename, 'w+') as file:
-        job_times_json = get_job_times_json(test_times)
-        json.dump(job_times_json, file, indent='    ', separators=(',', ': '))
-        file.write('\n')
-
-
-def load_specified_test_cases(filename: str) -> None:
-    if not os.path.exists(filename):
-        print(f'Could not find specified tests file: {filename}. Proceeding with default behavior.')
-        return
-
-    # The below encoding is utf-8-sig because utf-8 doesn't properly handle the byte-order-mark character
-    with open(filename, mode='r', encoding="utf-8-sig") as csv_file:
-        csv_reader = csv.DictReader(csv_file)
-        line_count = 0
-        global SPECIFIED_TEST_CASES_DICT
-        for row in csv_reader:
-            line_count += 1
-            if line_count == 1:
-                if 'test_filename' not in row or 'test_case_name' not in row:
-                    print('Data is missing necessary columns for test specification. Proceeding with default behavior.')
-                    return
-            test_filename = row['test_filename']
-            test_case_name = row['test_case_name']
-            if test_filename not in TESTS:
-                print(f'Specified test_filename {test_filename} not found in TESTS. Skipping.')
-                continue
-            if test_filename not in SPECIFIED_TEST_CASES_DICT:
-                SPECIFIED_TEST_CASES_DICT[test_filename] = []
-            SPECIFIED_TEST_CASES_DICT[test_filename].append(test_case_name)
-        print(f'Processed {line_count} test cases.')
-
-
-def query_changed_test_files() -> List[str]:
-    cmd = ["git", "diff", "--name-only", "origin/master", "HEAD"]
-    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-
-    if proc.returncode != 0:
-        raise RuntimeError("Unable to get changed files")
-
-    lines = proc.stdout.decode().strip().split("\n")
-    lines = [line.strip() for line in lines]
-    return lines
-
-def query_failure_test_module(reports: List[Tuple["Report", str]]) -> List[str]:
-    test_modules = []
-    if len(reports) == 0 or len(reports[0]) == 0:
-        return test_modules
-    report = reports[0][0]
-    assert report.get('format_version') == 2, "S3 format currently handled is version 2 only"
-    files: Dict[str, Any] = report['files']
-    for fname, file in files.items():
-        contains_failure = any(
-            any(case['status'] == 'errored' or case['status'] == 'failed'
-                for _, case in suite['cases'].items())
-            for _, suite in file['suites'].items())
-        if contains_failure:
-            test_modules.append(fname)
-    return test_modules
-
-
-def reorder_tests(tests: List[str]) -> List[str]:
-    prioritized_tests = []
-    # Try using historic stats from PR.
-    if ENABLE_PR_HISTORY_REORDERING and HAVE_BOTO3:
-        pr_number = os.environ.get("CIRCLE_PR_NUMBER", "")
-        if len(pr_number):
-            ci_job_prefix = get_stripped_CI_job()
-            s3_reports: List[Tuple["Report", str]] = get_previous_reports_for_pr(
-                pr_number, ci_job_prefix)
-            prioritized_tests = query_failure_test_module(s3_reports)
-            print("Prioritized test from previous CI info.")
-
-    # Using file changes priority if no stats found from previous PR.
-    if len(prioritized_tests) == 0:
-        try:
-            changed_files = query_changed_test_files()
-        except Exception:
-            # If unable to get changed files from git, quit without doing any sorting
-            return tests
-
-        prefix = f"test{os.path.sep}"
-        prioritized_tests = [f for f in changed_files if f.startswith(prefix) and f.endswith(".py")]
-        prioritized_tests = [f[len(prefix):] for f in prioritized_tests]
-        prioritized_tests = [f[:-len(".py")] for f in prioritized_tests]
-        print("Prioritized test from test file changes.")
-
-    bring_to_front = []
-    the_rest = []
-
-    for test in tests:
-        if test in prioritized_tests:
-            bring_to_front.append(test)
-        else:
-            the_rest.append(test)
-    if len(tests) == len(bring_to_front) + len(the_rest):
-        print(f"reordering tests for PR:\n"
-              f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n")
-        return bring_to_front + the_rest
-    else:
-        print(f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
-              f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n")
-        return tests
-
 
 def main():
     options = parse_args()
 
+    # TODO: move this export & download function in tools/ folder
     test_times_filename = options.export_past_test_times
     if test_times_filename:
         print(f'Exporting past test times from S3 to {test_times_filename}, no tests will be run.')
-        export_S3_test_times(test_times_filename, pull_job_times_from_S3())
+        export_S3_test_times(test_times_filename)
         return
 
     specified_test_cases_filename = options.run_specified_test_cases
     if specified_test_cases_filename:
         print(f'Loading specified test cases to run from {specified_test_cases_filename}.')
-        load_specified_test_cases(specified_test_cases_filename)
+        global SPECIFIED_TEST_CASES_DICT
+        SPECIFIED_TEST_CASES_DICT = get_specified_test_cases(specified_test_cases_filename, TESTS)
 
     test_directory = os.path.dirname(os.path.abspath(__file__))
     selected_tests = get_selected_tests(options)
@@ -1253,7 +1026,7 @@
         selected_tests = filter(lambda test_name: "jit" in test_name, TESTS)
 
     if options.determine_from is not None and os.path.exists(options.determine_from):
-        slow_tests = get_slow_tests_based_on_S3()
+        slow_tests = get_slow_tests_based_on_S3(TESTS, TARGET_DET_LIST, SLOW_TEST_THRESHOLD)
         print('Added the following tests to target_det tests as calculated based on S3:')
         print(slow_tests)
         with open(options.determine_from, 'r') as fh:
@@ -1270,7 +1043,7 @@
         sys.path.remove('test')
 
     if IS_IN_CI:
-        selected_tests = reorder_tests(selected_tests)
+        selected_tests = get_reordered_tests(selected_tests, ENABLE_PR_HISTORY_REORDERING)
 
     has_failed = False
     failure_messages = []
diff --git a/tools/test/test_test_selection.py b/tools/test/test_test_selections.py
similarity index 100%
rename from tools/test/test_test_selection.py
rename to tools/test/test_test_selections.py
diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py
index 7b6d089..c504f87 100644
--- a/tools/testing/test_selections.py
+++ b/tools/testing/test_selections.py
@@ -1,4 +1,64 @@
-from typing import Dict, Tuple, List
+import csv
+import json
+import os
+import subprocess
+
+from tools.stats.s3_stat_parser import (
+    get_previous_reports_for_branch,
+    get_previous_reports_for_pr,
+    Report, Version2Report,
+    HAVE_BOTO3)
+
+from typing import Any, Dict, List, Optional, Tuple, cast
+from typing_extensions import TypedDict
+
+class JobTimeJSON(TypedDict):
+    commit: str
+    JOB_BASE_NAME: str
+    job_times: Dict[str, float]
+
+
+def _get_stripped_CI_job() -> str:
+    """E.g. convert 'pytorch_windows_vs2019_py36_cuda10.1_build' to 'pytorch_windows_vs2019_py36_cuda10.1'.
+    """
+    job = os.environ.get("JOB_BASE_NAME", "").rstrip('0123456789')
+    if job.endswith('_slow_test'):
+        job = job[:len(job) - len('_slow_test')]
+    elif job.endswith('_test') or job.endswith('-test'):
+        job = job[:len(job) - len('_test')]
+    elif job.endswith('_build') or job.endswith('-build'):
+        job = job[:len(job) - len('_build')]
+    return job
+
+
+def _get_job_times_json(job_times: Dict[str, float]) -> JobTimeJSON:
+    return {
+        'commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'], encoding="ascii").strip(),
+        'JOB_BASE_NAME': _get_stripped_CI_job(),
+        'job_times': job_times,
+    }
+
+
+def _calculate_job_times(reports: List["Report"]) -> Dict[str, float]:
+    """Compute test runtime by filename: ("test_file_name" -> (current_avg, # values))
+    """
+    jobs_to_times: Dict[str, Tuple[float, int]] = dict()
+    for report in reports:
+        v_report = cast(Version2Report, report)
+        assert 'format_version' in v_report.keys() and v_report.get('format_version') == 2, \
+            "S3 format currently handled is version 2 only"
+        files: Dict[str, Any] = v_report['files']
+        for name, test_file in files.items():
+            if name not in jobs_to_times:
+                jobs_to_times[name] = (test_file['total_seconds'], 1)
+            else:
+                curr_avg, curr_count = jobs_to_times[name]
+                new_count = curr_count + 1
+                new_avg = (curr_avg * curr_count + test_file['total_seconds']) / new_count
+                jobs_to_times[name] = (new_avg, new_count)
+
+    return {job: time for job, (time, _) in jobs_to_times.items()}
+
 
 def calculate_shards(num_shards: int, tests: List[str], job_times: Dict[str, float]) -> List[Tuple[float, List[str]]]:
     filtered_job_times: Dict[str, float] = dict()
@@ -25,3 +85,202 @@
         sharded_jobs[index][1].append(job)
         index = (index + 1) % num_shards
     return sharded_jobs
+
+
+def _pull_job_times_from_S3() -> Dict[str, float]:
+    if HAVE_BOTO3:
+        ci_job_prefix = _get_stripped_CI_job()
+        s3_reports: List["Report"] = get_previous_reports_for_branch('origin/viable/strict', ci_job_prefix)
+    else:
+        print('Uh oh, boto3 is not found. Either it is not installed or we failed to import s3_stat_parser.')
+        print('If not installed, please install boto3 for automatic sharding and test categorization.')
+        s3_reports = []
+
+    if len(s3_reports) == 0:
+        print('Gathered no reports from S3. Please proceed without them.')
+        return dict()
+
+    return _calculate_job_times(s3_reports)
+
+
+def _query_past_job_times(test_times_file: Optional[str] = None) -> Dict[str, float]:
+    """Read historic test job times from a file.
+
+    If the file doesn't exist or isn't matching current commit. It will download data from S3 and exported it.
+    """
+    if test_times_file and os.path.exists(test_times_file):
+        with open(test_times_file) as file:
+            test_times_json: JobTimeJSON = json.load(file)
+
+        curr_commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'], encoding="ascii").strip()
+        file_commit = test_times_json.get('commit', '')
+        curr_ci_job = _get_stripped_CI_job()
+        file_ci_job = test_times_json.get('JOB_BASE_NAME', 'N/A')
+        if curr_commit != file_commit:
+            print(f'Current test times file is from different commit {file_commit}.')
+        elif curr_ci_job != file_ci_job:
+            print(f'Current test times file is for different CI job {file_ci_job}.')
+        else:
+            print(f'Found stats for current commit: {curr_commit} and job: {curr_ci_job}. Proceeding with those values.')
+            return test_times_json.get('job_times', {})
+
+        # Found file, but commit or CI job in JSON doesn't match
+        print(f'Overwriting current file with stats based on current commit: {curr_commit} and CI job: {curr_ci_job}')
+
+    job_times = export_S3_test_times(test_times_file)
+
+    return job_times
+
+
+def _query_failure_test_module(reports: List[Tuple["Report", str]]) -> List[str]:
+    test_modules: List[str] = []
+    if len(reports) == 0 or len(reports[0]) == 0:
+        return test_modules
+    report = reports[0][0]
+    v_report = cast(Version2Report, report)
+    assert 'format_version' in v_report.keys() and v_report.get('format_version') == 2, \
+        "S3 format currently handled is version 2 only"
+    files: Dict[str, Any] = v_report['files']
+    for fname, file in files.items():
+        contains_failure = any(
+            any(case['status'] == 'errored' or case['status'] == 'failed'
+                for _, case in suite['cases'].items())
+            for _, suite in file['suites'].items())
+        if contains_failure:
+            test_modules.append(fname)
+    return test_modules
+
+
+def _query_changed_test_files() -> List[str]:
+    cmd = ["git", "diff", "--name-only", "origin/master", "HEAD"]
+    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+    if proc.returncode != 0:
+        raise RuntimeError("Unable to get changed files")
+
+    lines = proc.stdout.decode().strip().split("\n")
+    lines = [line.strip() for line in lines]
+    return lines
+
+
+def get_shard_based_on_S3(which_shard: int, num_shards: int, tests: List[str], test_times_file: str) -> List[str]:
+    """Get sharded test allocation based on historic S3 data.
+    """
+    jobs_to_times = _query_past_job_times(test_times_file)
+
+    # Got no stats from S3, returning early to save runtime
+    if len(jobs_to_times) == 0:
+        print('Gathered no stats from S3. Proceeding with default sharding plan.')
+        return tests[which_shard - 1 :: num_shards]
+
+    shards = calculate_shards(num_shards, tests, jobs_to_times)
+    _, tests_from_shard = shards[which_shard - 1]
+    return tests_from_shard
+
+
+def get_slow_tests_based_on_S3(test_list: List[str], td_list: List[str], slow_test_threshold: int) -> List[str]:
+    """Get list of slow tests based on historic S3 data.
+    """
+    jobs_to_times: Dict[str, float] = _query_past_job_times()
+
+    # Got no stats from S3, returning early to save runtime
+    if len(jobs_to_times) == 0:
+        print('Gathered no stats from S3. No new slow tests calculated.')
+        return []
+
+    slow_tests: List[str] = []
+    for test in test_list:
+        if test in jobs_to_times and test not in td_list:
+            if jobs_to_times[test] > slow_test_threshold:
+                slow_tests.append(test)
+    return slow_tests
+
+
+def get_specified_test_cases(filename: str, tests: List[str]) -> Dict[str, List[str]]:
+    """Get test cases from a specified test case file. Usually exported manually or through CI system.
+    """
+    if not os.path.exists(filename):
+        print(f'Could not find specified tests file: {filename}. Proceeding with default behavior.')
+        return dict()
+
+    # The below encoding is utf-8-sig because utf-8 doesn't properly handle the byte-order-mark character
+    with open(filename, mode='r', encoding="utf-8-sig") as csv_file:
+        csv_reader = csv.DictReader(csv_file)
+        line_count = 0
+        specified_test_case_dict: Dict[str, List[str]] = dict()
+        for row in csv_reader:
+            line_count += 1
+            if line_count == 1:
+                if 'test_filename' not in row or 'test_case_name' not in row:
+                    print('Data is missing necessary columns for test specification. Proceeding with default behavior.')
+                    return dict()
+            test_filename = row['test_filename']
+            test_case_name = row['test_case_name']
+            if test_filename not in tests:
+                print(f'Specified test_filename {test_filename} not found in TESTS. Skipping.')
+                continue
+            if test_filename not in specified_test_case_dict:
+                specified_test_case_dict[test_filename] = []
+            specified_test_case_dict[test_filename].append(test_case_name)
+        print(f'Processed {line_count} test cases.')
+        return specified_test_case_dict
+
+
+def get_reordered_tests(tests: List[str], is_reordering_by_pr: bool) -> List[str]:
+    """Get the reordered test filename list based on github PR history or git changed file.
+    """
+    prioritized_tests = []
+    # Try using historic stats from PR.
+    if is_reordering_by_pr and HAVE_BOTO3:
+        pr_number = os.environ.get("CIRCLE_PR_NUMBER", "")
+        if len(pr_number):
+            ci_job_prefix = _get_stripped_CI_job()
+            s3_reports: List[Tuple["Report", str]] = get_previous_reports_for_pr(
+                pr_number, ci_job_prefix)
+            prioritized_tests = _query_failure_test_module(s3_reports)
+            print("Prioritized test from previous CI info.")
+
+    # Using file changes priority if no stats found from previous PR.
+    if len(prioritized_tests) == 0:
+        try:
+            changed_files = _query_changed_test_files()
+        except Exception:
+            # If unable to get changed files from git, quit without doing any sorting
+            return tests
+
+        prefix = f"test{os.path.sep}"
+        prioritized_tests = [f for f in changed_files if f.startswith(prefix) and f.endswith(".py")]
+        prioritized_tests = [f[len(prefix):] for f in prioritized_tests]
+        prioritized_tests = [f[:-len(".py")] for f in prioritized_tests]
+        print("Prioritized test from test file changes.")
+
+    bring_to_front = []
+    the_rest = []
+
+    for test in tests:
+        if test in prioritized_tests:
+            bring_to_front.append(test)
+        else:
+            the_rest.append(test)
+    if len(tests) == len(bring_to_front) + len(the_rest):
+        print(f"reordering tests for PR:\n"
+              f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n")
+        return bring_to_front + the_rest
+    else:
+        print(f"Something went wrong in CI reordering, expecting total of {len(tests)}:\n"
+              f"but found prioritized: {len(bring_to_front)}\nthe rest: {len(the_rest)}\n")
+        return tests
+
+
+# TODO Refactor this and unify with tools.stats.export_slow_tests
+def export_S3_test_times(test_times_filename: Optional[str] = None) -> Dict[str, float]:
+    test_times: Dict[str, float] = _pull_job_times_from_S3()
+    if test_times_filename is not None:
+        print(f'Exporting S3 test stats to {test_times_filename}.')
+        if os.path.exists(test_times_filename):
+            print(f'Overwriting existent file: {test_times_filename}')
+        with open(test_times_filename, 'w+') as file:
+            job_times_json = _get_job_times_json(test_times)
+            json.dump(job_times_json, file, indent='    ', separators=(',', ': '))
+            file.write('\n')
+    return test_times