blob: 795997d4e9d8844530d1a5bb8948ac3ae4c68a3d [file] [log] [blame]
#!/usr/bin/env python
from __future__ import print_function
import argparse
import os
import shlex
import shutil
import signal
import subprocess
import sys
import tempfile
import torch
from torch.utils import cpp_extension
TESTS = [
'autograd',
'cpp_extensions',
'cuda',
'dataloader',
'distributed',
'distributions',
'indexing',
'jit',
'legacy_nn',
'multiprocessing',
'nccl',
'nn',
'optim',
'sparse',
'torch',
'utils',
]
WINDOWS_BLACKLIST = [
'distributed',
]
DISTRIBUTED_TESTS_CONFIG = {
'tcp': {
'WORLD_SIZE': '3'
},
'gloo': {
'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3'
},
'nccl': {
'WORLD_SIZE': '2'
},
'mpi': {},
}
# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python
SIGNALS_TO_NAMES_DICT = dict((getattr(signal, n), n) for n in dir(signal)
if n.startswith('SIG') and '_' not in n)
def print_to_stderr(message):
print(message, file=sys.stderr)
def shell(command, cwd):
sys.stdout.flush()
sys.stderr.flush()
return subprocess.call(
shlex.split(command), universal_newlines=True, cwd=cwd)
def get_shell_output(command):
return subprocess.check_output(shlex.split(command)).decode().strip()
def run_test(python, test_module, test_directory, options):
verbose = '--verbose' if options.verbose else ''
return shell('{} {} {}'.format(python, test_module, verbose),
test_directory)
def test_cpp_extensions(python, test_module, test_directory, options):
try:
cpp_extension.verify_ninja_availability()
except RuntimeError:
print(
'Ninja is not available. Skipping C++ extensions test. '
"Install ninja with 'pip install ninja' or 'conda install ninja'.")
return 0
return_code = shell('{} setup.py install --root ./install'.format(python),
os.path.join(test_directory, 'cpp_extensions'))
if return_code != 0:
return return_code
python_path = os.environ.get('PYTHONPATH', '')
try:
cpp_extensions = os.path.join(test_directory, 'cpp_extensions')
if sys.platform == 'win32':
install_directory = os.path.join(cpp_extensions, 'install')
install_directories = get_shell_output(
'where -r "{}" *.pyd'.format(install_directory)).split('\r\n')
assert install_directories, 'install_directory must not be empty'
if len(install_directories) >= 1:
install_directory = install_directories[0]
install_directory = os.path.dirname(install_directory)
split_char = ';'
else:
install_directory = get_shell_output(
"find {}/install -name *-packages".format(cpp_extensions))
split_char = ':'
assert install_directory, 'install_directory must not be empty'
install_directory = os.path.join(test_directory, install_directory)
os.environ['PYTHONPATH'] = '{}{}{}'.format(install_directory,
split_char, python_path)
return run_test(python, test_module, test_directory, options)
finally:
os.environ['PYTHONPATH'] = python_path
def test_distributed(python, test_module, test_directory, options):
mpi_available = subprocess.call('command -v mpiexec', shell=True) == 0
if options.verbose and not mpi_available:
print_to_stderr(
'MPI not available -- MPI backend tests will be skipped')
for backend, env_vars in DISTRIBUTED_TESTS_CONFIG.items():
if backend == 'mpi' and not mpi_available:
continue
for with_init_file in {True, False}:
tmp_dir = tempfile.mkdtemp()
if options.verbose:
with_init = ' with file init_method' if with_init_file else ''
print_to_stderr(
'Running distributed tests for the {} backend{}'.format(
backend, with_init))
os.environ['TEMP_DIR'] = tmp_dir
os.environ['BACKEND'] = backend
os.environ['INIT_METHOD'] = 'env://'
os.environ.update(env_vars)
if with_init_file:
init_method = 'file://{}/shared_init_file'.format(tmp_dir)
os.environ['INIT_METHOD'] = init_method
try:
os.mkdir(os.path.join(tmp_dir, 'barrier'))
os.mkdir(os.path.join(tmp_dir, 'test_dir'))
if backend == 'mpi':
mpiexec = 'mpiexec -n 3 --noprefix {}'.format(python)
return_code = run_test(mpiexec, test_module,
test_directory, options)
else:
return_code = run_test(python, test_module, test_directory,
options)
if return_code != 0:
return return_code
finally:
shutil.rmtree(tmp_dir)
return 0
CUSTOM_HANDLERS = {
'cpp_extensions': test_cpp_extensions,
'distributed': test_distributed,
}
def parse_args():
parser = argparse.ArgumentParser(
description='Run the PyTorch unit test suite',
epilog='where TESTS is any of: {}'.format(', '.join(TESTS)))
parser.add_argument(
'-v',
'--verbose',
action='store_true',
help='print verbose information and test-by-test results')
parser.add_argument(
'-p', '--python', help='the python interpreter to execute tests with')
parser.add_argument(
'-c', '--coverage', action='store_true', help='enable coverage')
parser.add_argument(
'-i',
'--include',
nargs='+',
choices=TESTS,
default=TESTS,
metavar='TESTS',
help='select a set of tests to include (defaults to ALL tests)')
parser.add_argument(
'-x',
'--exclude',
nargs='+',
choices=TESTS,
metavar='TESTS',
default=[],
help='select a set of tests to exclude')
parser.add_argument(
'-f',
'--first',
choices=TESTS,
metavar='TESTS',
help='select the test to start from (excludes previous tests)')
parser.add_argument(
'-l',
'--last',
choices=TESTS,
metavar='TESTS',
help='select the last test to run (excludes following tests)')
parser.add_argument(
'--ignore-win-blacklist',
action='store_true',
help='always run blacklisted windows tests')
return parser.parse_args()
def get_python_command(options):
if options.coverage:
return 'coverage run --parallel-mode --source torch'
elif options.python:
return options.python
else:
return os.environ.get('PYCMD', 'python')
def get_selected_tests(options):
selected_tests = options.include
for test in options.exclude:
if test in selected_tests:
selected_tests.remove(test)
if options.first:
first_index = selected_tests.index(options.first)
selected_tests = selected_tests[first_index:]
if options.last:
last_index = selected_tests.index(options.last)
selected_tests = selected_tests[:last_index + 1]
if sys.platform == 'win32' and not options.ignore_win_blacklist:
ostype = os.environ.get('MSYSTEM')
target_arch = os.environ.get('VSCMD_ARG_TGT_ARCH')
if ostype != 'MINGW64' or target_arch != 'x64':
WINDOWS_BLACKLIST.append('cpp_extensions')
for test in WINDOWS_BLACKLIST:
if test in selected_tests:
print_to_stderr('Excluding {} on Windows'.format(test))
selected_tests.remove(test)
return selected_tests
def main():
options = parse_args()
python = get_python_command(options)
test_directory = os.path.dirname(os.path.abspath(__file__))
selected_tests = get_selected_tests(options)
if options.verbose:
print_to_stderr('Selected tests: {}'.format(', '.join(selected_tests)))
if options.coverage:
shell('coverage erase')
for test in selected_tests:
test_module = 'test_{}.py'.format(test)
print_to_stderr('Running {} ...'.format(test_module))
handler = CUSTOM_HANDLERS.get(test, run_test)
return_code = handler(python, test_module, test_directory, options)
assert isinstance(return_code, int) and not isinstance(
return_code, bool), 'Return code should be an integer'
if return_code != 0:
message = '{} failed!'.format(test_module)
if return_code < 0:
# subprocess.Popen returns the child process' exit signal as
# return code -N, where N is the signal number.
signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
message += ' Received signal: {}'.format(signal_name)
raise RuntimeError(message)
if options.coverage:
shell('coverage combine')
shell('coverage html')
if __name__ == '__main__':
main()