| import sys |
| import os |
| import re |
| import inspect |
| import argparse |
| import unittest |
| import warnings |
| import random |
| import contextlib |
| from functools import wraps |
| from itertools import product |
| from copy import deepcopy |
| from numbers import Number |
| |
| import __main__ |
| import errno |
| |
| import torch |
| import torch.cuda |
| from torch._six import string_classes |
| import torch.backends.cudnn |
| import torch.backends.mkl |
| |
| |
| torch.set_default_tensor_type('torch.DoubleTensor') |
| torch.backends.cudnn.disable_global_flags() |
| |
| |
| parser = argparse.ArgumentParser(add_help=False) |
| parser.add_argument('--seed', type=int, default=1234) |
| parser.add_argument('--accept', action='store_true') |
| args, remaining = parser.parse_known_args() |
| SEED = args.seed |
| ACCEPT = args.accept |
| UNITTEST_ARGS = [sys.argv[0]] + remaining |
| torch.manual_seed(SEED) |
| |
| |
def run_tests():
    """Hand control to unittest, using only the arguments left unparsed by
    this module's own --seed/--accept parser (plus the program name)."""
    cli_argv = UNITTEST_ARGS
    unittest.main(argv=cli_argv)
| |
# Interpreter and platform flags.
PY3 = sys.version_info > (3, 0)
PY34 = sys.version_info >= (3, 4)
IS_WINDOWS = sys.platform == "win32"

# Optional third-party modules: each TEST_* flag records whether the module
# could be imported (when it could, the module is bound at module level).
try:
    import numpy
except ImportError:
    TEST_NUMPY = False
else:
    TEST_NUMPY = True

try:
    import scipy
except ImportError:
    TEST_SCIPY = False
else:
    TEST_SCIPY = True

# Whether this torch build reports MKL support.
TEST_MKL = torch.backends.mkl.is_available()
| |
| |
def skipIfNoLapack(fn):
    """Decorator: convert a 'Lapack library not found' error into a skip.

    Any exception raised by ``fn`` whose text contains
    ``'Lapack library not found'`` is replaced by ``unittest.SkipTest``;
    every other exception propagates unchanged. ``fn``'s return value is
    passed through.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # str(e) instead of e.args[0]: exceptions raised without
            # arguments (or with a non-string first argument) must propagate
            # as themselves rather than as IndexError/TypeError from here.
            if 'Lapack library not found' in str(e):
                raise unittest.SkipTest('Compiled without Lapack')
            raise
    return wrapper
| |
| |
def suppress_warnings(fn):
    """Decorator: run ``fn`` with all warnings ignored.

    The previous filter configuration is restored when ``fn`` returns or
    raises (``warnings.catch_warnings``). The wrapper keeps ``fn``'s
    metadata (``__name__``/``__doc__``) and returns ``fn``'s result.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return fn(*args, **kwargs)
    return wrapper
| |
| |
def get_cpu_type(type_name):
    """Map a CUDA tensor type to the CPU tensor type with the same name.

    Args:
        type_name: a dotted name such as ``'torch.cuda.FloatTensor'`` or,
            like ``get_gpu_type``, the type object itself.

    Returns:
        The attribute of ``torch`` with the same short name, e.g.
        ``torch.FloatTensor``.

    Raises:
        AssertionError: if the type does not live in ``torch.cuda``.
    """
    if isinstance(type_name, type):
        type_name = '{}.{}'.format(type_name.__module__, type_name.__name__)
    module, name = type_name.rsplit('.', 1)
    assert module == 'torch.cuda'
    return getattr(torch, name)
| |
| |
def get_gpu_type(type_name):
    """Map a CPU tensor type to the ``torch.cuda`` type with the same name.

    ``type_name`` is either a dotted name such as ``'torch.FloatTensor'`` or
    the type object itself; the result is e.g. ``torch.cuda.FloatTensor``.
    An AssertionError is raised for types outside the ``torch`` module.
    """
    qualified = type_name
    if isinstance(qualified, type):
        qualified = '{}.{}'.format(qualified.__module__, qualified.__name__)
    module_name, short_name = qualified.rsplit('.', 1)
    assert module_name == 'torch'
    return getattr(torch.cuda, short_name)
| |
| |
def to_gpu(obj, type_map=None):
    """Recursively copy ``obj`` for the CUDA side of a CPU-vs-CUDA test.

    Args:
        obj: a leaf tensor, a storage, a (possibly nested) list/tuple of
            those, or any other object (which is deep-copied).
        type_map: optional dict from a tensor's ``type()`` string to the
            type passed to ``Tensor.type()`` for the copy. Tensor types not
            listed are converted to the same-named ``torch.cuda`` type via
            ``get_gpu_type``.

    Returns:
        A copy of ``obj`` with the same container structure; tensor copies
        keep ``obj``'s ``requires_grad`` flag.
    """
    if type_map is None:
        type_map = {}
    if isinstance(obj, torch.Tensor):
        assert obj.is_leaf
        src_type = obj.type()
        # Consult type_map first and only then fall back to get_gpu_type:
        # get_gpu_type asserts on types outside `torch.`, so evaluating it
        # eagerly (as a dict.get default) could fail even when type_map
        # already covers this tensor's type.
        if src_type in type_map:
            t = type_map[src_type]
        else:
            t = get_gpu_type(src_type)
        with torch.no_grad():
            res = obj.clone().type(t)
        res.requires_grad = obj.requires_grad
        return res
    elif torch.is_storage(obj):
        # NOTE(review): obj.new() builds a storage of obj's own type, so this
        # returns a same-type copy rather than a CUDA storage - confirm intent.
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
| |
| |
def get_function_arglist(func):
    """Return the names of ``func``'s positional parameters, in order.

    ``inspect.getargspec`` was deprecated throughout Python 3 and removed in
    3.11, so ``getfullargspec`` is used whenever it exists; ``getargspec``
    is only the fallback for Python 2.
    """
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    return getspec(func).args
| |
| |
def set_rng_seed(seed):
    """Seed torch's RNG, Python's ``random``, and numpy's RNG (when numpy
    was importable) with the same ``seed``."""
    for seeder in (torch.manual_seed, random.seed):
        seeder(seed)
    if not TEST_NUMPY:
        return
    numpy.random.seed(seed)
| |
| |
@contextlib.contextmanager
def freeze_rng_state():
    """Context manager that restores the torch RNG state(s) on exit.

    The CPU generator state (and the CUDA one, when CUDA is available) is
    captured on entry and written back on exit, so random numbers drawn
    inside the block do not affect what is drawn afterwards. The restore
    happens in a ``finally`` so it also runs when the block raises.
    """
    rng_state = torch.get_rng_state()
    # Evaluate availability once so capture and restore always agree.
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        if has_cuda:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
| |
| |
def iter_indices(tensor):
    """Return an iterable over every valid index into ``tensor``.

    * 0-dim (scalar) tensors: the single empty index ``()``.
    * 1-dim tensors: plain integers ``0 .. size(0) - 1``.
    * higher-dim tensors: index tuples in row-major order.

    Tensors with a zero-sized dimension yield no indices.
    """
    if tensor.dim() == 0:
        # A 0-dim tensor holds exactly one element, addressed as tensor[()].
        # Yielding nothing here made element-wise comparisons of scalars
        # vacuous.
        return [()]
    if tensor.dim() == 1:
        return range(tensor.size(0))
    return product(*(range(s) for s in tensor.size()))
| |
| |
def is_iterable(obj):
    """Return True when ``iter(obj)`` succeeds, False when it raises
    TypeError (any other exception propagates)."""
    try:
        iter(obj)
    except TypeError:
        return False
    return True
| |
| |
class TestCase(unittest.TestCase):
    """``unittest.TestCase`` with tolerance-aware, tensor-aware assertions.

    * :meth:`setUp` reseeds the RNGs with ``SEED`` before every test.
    * :meth:`assertEqual` / :meth:`assertNotEqual` take an absolute
      tolerance ``prec`` (default :attr:`precision`) and handle dense and
      sparse tensors, numbers, strings, sets and nested iterables.
    * :meth:`assertExpected` compares a string against a recorded
      ``.expect`` file next to the test script (refresh with ``--accept``).
    """

    # Default absolute tolerance for comparisons that are given no ``prec``.
    precision = 1e-5
    # Never truncate diffs in failure messages.
    maxDiff = None

    def setUp(self):
        set_rng_seed(SEED)

    def assertTensorsSlowEqual(self, x, y, prec=None, message=''):
        """Check ``x`` and ``y`` element by element, within ``prec``.

        Every element is indexed from Python, so this is much slower than
        :meth:`assertEqual`. ``prec`` defaults to :attr:`precision`.
        """
        # Without this default, assertLessEqual(max_err, None) raises TypeError.
        if prec is None:
            prec = self.precision
        max_err = 0
        self.assertEqual(x.size(), y.size())
        for index in iter_indices(x):
            max_err = max(max_err, abs(x[index] - y[index]))
        self.assertLessEqual(max_err, prec, message)

    def safeToDense(self, t):
        """Coalesce sparse ``t`` (cross-checking the result) and densify it."""
        return self.safeCoalesce(t).to_dense()

    def safeCoalesce(self, t):
        """Coalesce sparse tensor ``t`` and verify it against a reference.

        ``t.coalesce()`` is compared with a reference tensor built in Python
        by summing the values of duplicate indices. The reference is
        returned, except when ``t`` has no non-zeros, in which case the
        result of ``t.coalesce()`` is returned.
        """
        tc = t.coalesce()
        self.assertEqual(tc.to_dense(), t.to_dense())
        self.assertTrue(tc.is_coalesced())

        # Our code below doesn't work when nnz is 0, because
        # then it's a 0D tensor, not a 2D tensor.
        if t._nnz() == 0:
            self.assertEqual(t._indices(), tc._indices())
            self.assertEqual(t._values(), tc._values())
            return tc

        # Sum values that share an index. The first occurrence is cloned so
        # the in-place += below does not write into t's own values.
        value_map = {}
        for idx, val in zip(t._indices().t(), t._values()):
            idx_tup = tuple(idx.tolist())
            if idx_tup in value_map:
                value_map[idx_tup] += val
            else:
                value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val

        new_indices = sorted(value_map)
        new_values = [value_map[idx] for idx in new_indices]
        if t._values().ndimension() < 2:
            new_values = t._values().new(new_values)
        else:
            new_values = torch.stack(new_values)

        new_indices = t._indices().new(new_indices).t()
        tg = t.new(new_indices, new_values, t.size())

        self.assertEqual(tc._indices(), tg._indices())
        self.assertEqual(tc._values(), tg._values())

        if t.is_coalesced():
            self.assertEqual(tc._indices(), t._indices())
            self.assertEqual(tc._values(), t._values())

        return tg

    def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
        """Assert ``x`` and ``y`` are equal up to an absolute tolerance.

        Args:
            x, y: tensors, numbers, strings, sets or (nested) iterables.
                A tensor compared with a number is unwrapped with ``.item()``.
            prec: absolute tolerance; defaults to :attr:`precision`. A string
                passed here while ``message`` is empty is used as the message.
            message: failure message.
            allow_inf: when true, infinite numbers are compared exactly
                instead of failing as non-finite.

        Tensors must have equal sizes and NaNs in the same positions; the
        maximum absolute difference of the remaining entries must not exceed
        ``prec``. Sparse tensors are coalesced first and their indices and
        values compared.
        """
        if isinstance(prec, str) and message == '':
            message = prec
            prec = None
        if prec is None:
            prec = self.precision

        if isinstance(x, torch.Tensor) and isinstance(y, Number):
            self.assertEqual(x.item(), y, prec, message, allow_inf)
        elif isinstance(y, torch.Tensor) and isinstance(x, Number):
            self.assertEqual(x, y.item(), prec, message, allow_inf)
        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            def assertTensorsEqual(a, b):
                super(TestCase, self).assertEqual(a.size(), b.size(), message)
                if a.numel() > 0:
                    # Compare in a's type and on a's device.
                    b = b.type_as(a)
                    b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu()
                    # check that NaNs are in the same locations
                    nan_mask = a != a
                    self.assertTrue(torch.equal(nan_mask, b != b), message)
                    diff = a - b
                    diff[nan_mask] = 0
                    # TODO: implement abs on CharTensor
                    if diff.is_signed() and 'CharTensor' not in diff.type():
                        diff = diff.abs()
                    max_err = diff.max()
                    self.assertLessEqual(max_err, prec, message)
            super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
            if x.is_sparse:
                x = self.safeCoalesce(x)
                y = self.safeCoalesce(y)
                assertTensorsEqual(x._indices(), y._indices())
                assertTensorsEqual(x._values(), y._values())
            else:
                assertTensorsEqual(x, y)
        elif isinstance(x, string_classes) and isinstance(y, string_classes):
            super(TestCase, self).assertEqual(x, y, message)
        elif isinstance(x, set) and isinstance(y, set):
            # Sets are unordered, so they must not fall through to the
            # element-by-element iterable comparison below.
            super(TestCase, self).assertEqual(x, y, message)
        elif is_iterable(x) and is_iterable(y):
            super(TestCase, self).assertEqual(len(x), len(y), message)
            for x_, y_ in zip(x, y):
                # Forward allow_inf so nested numbers obey the caller's choice.
                self.assertEqual(x_, y_, prec, message, allow_inf)
        elif isinstance(x, bool) and isinstance(y, bool):
            super(TestCase, self).assertEqual(x, y, message)
        elif isinstance(x, Number) and isinstance(y, Number):
            if abs(x) == float('inf') or abs(y) == float('inf'):
                if allow_inf:
                    super(TestCase, self).assertEqual(x, y, message)
                else:
                    self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
                return
            super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
        else:
            super(TestCase, self).assertEqual(x, y, message)

    def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None, allow_inf=None):
        """``unittest``-compatible spelling of :meth:`assertEqual`.

        ``places`` (when given, including 0) sets the tolerance to
        ``10 ** -places`` and takes precedence over ``delta``; with neither,
        :attr:`precision` is used.
        """
        prec = delta
        # `is not None` so that places=0 is honoured rather than ignored.
        if places is not None:
            prec = 10 ** (-places)
        self.assertEqual(x, y, prec, msg, allow_inf)

    def assertNotEqual(self, x, y, prec=None, message=''):
        """Assert ``x`` and ``y`` differ by at least ``prec``.

        Tensors of different sizes are unequal. Same-sized tensors are
        unequal if their NaN positions differ, or if the largest absolute
        difference among the other entries is at least ``prec``.
        """
        if isinstance(prec, str) and message == '':
            message = prec
            prec = None
        if prec is None:
            prec = self.precision

        if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            if x.size() != y.size():
                # Differently-sized tensors are unequal; stop here instead of
                # subtracting them, which would raise or silently broadcast.
                super(TestCase, self).assertNotEqual(x.size(), y.size())
                return
            self.assertGreater(x.numel(), 0)
            y = y.type_as(x)
            y = y.cuda(device=x.get_device()) if x.is_cuda else y.cpu()
            nan_mask = x != x
            if torch.equal(nan_mask, y != y):
                diff = x - y
                # Same CharTensor guard as assertEqual (no abs on CharTensor).
                if diff.is_signed() and 'CharTensor' not in diff.type():
                    diff = diff.abs()
                diff[nan_mask] = 0
                max_err = diff.max()
                self.assertGreaterEqual(max_err, prec, message)
        elif type(x) == str and type(y) == str:
            super(TestCase, self).assertNotEqual(x, y)
        elif is_iterable(x) and is_iterable(y):
            super(TestCase, self).assertNotEqual(x, y)
        else:
            try:
                self.assertGreaterEqual(abs(x - y), prec, message)
                return
            except (TypeError, AssertionError):
                pass
            super(TestCase, self).assertNotEqual(x, y, message)

    def assertObjectIn(self, obj, iterable):
        """Assert that ``obj`` itself (by identity, not equality) is in ``iterable``."""
        if not any(elem is obj for elem in iterable):
            raise AssertionError("object not found in iterable")

    # TODO: Support context manager interface
    # NB: The kwargs forwarding to callable robs the 'subname' parameter.
    # If you need it, manually apply your callable in a lambda instead.
    def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
        """Assert ``callable(*args, **kwargs)`` raises ``exc_type`` and that
        the exception text matches the expect file (see :meth:`assertExpected`)."""
        subname = kwargs.pop('subname', None)
        try:
            callable(*args, **kwargs)
        except exc_type as e:
            self.assertExpected(str(e), subname)
            return
        # Kept outside the try block: if exc_type is AssertionError, the
        # except clause above would otherwise swallow this failure.
        self.fail(msg="Did not raise when expected to")

    # NB: this and assertWarnsRegex shadow the unittest methods of the same
    # names with a different calling convention (callable first, no warning
    # category argument).
    def assertWarns(self, callable, msg=''):
        r"""
        Test if :attr:`callable` raises a warning.
        """
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            callable()
            self.assertTrue(len(ws) > 0, msg)

    def assertWarnsRegex(self, callable, regex, msg=''):
        r"""
        Test if :attr:`callable` raises any warning with message that contains
        the regex pattern :attr:`regex`.
        """
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            callable()
            self.assertTrue(len(ws) > 0, msg)
            found = any(re.search(regex, str(w.message)) is not None for w in ws)
            self.assertTrue(found, msg)

    def assertExpected(self, s, subname=None):
        r"""
        Test that a string matches the recorded contents of a file
        derived from the name of this test and subname. This file
        is placed in the 'expect' directory in the same directory
        as the test script. You can automatically update the recorded test
        output using --accept.

        If you call this multiple times in a single function, you must
        give a unique subname each time.
        """
        # `unicode` is only evaluated on Python 2 thanks to short-circuiting.
        if not (isinstance(s, str) or (sys.version_info[0] == 2 and isinstance(s, unicode))):
            raise TypeError("assertExpected is strings only")

        def remove_prefix(text, prefix):
            if text.startswith(prefix):
                return text[len(prefix):]
            return text
        # NB: we take __file__ from the module that defined the test
        # class, so we place the expect directory where the test script
        # lives, NOT where test/common.py lives. This doesn't matter in
        # PyTorch where all test scripts are in the same directory as
        # test/common.py, but it matters in onnx-pytorch
        module_id = self.__class__.__module__
        munged_id = remove_prefix(self.id(), module_id + ".")
        test_file = os.path.realpath(sys.modules[module_id].__file__)
        expected_file = os.path.join(os.path.dirname(test_file),
                                     "expect",
                                     munged_id)
        if subname:
            expected_file += "-" + subname
        expected_file += ".expect"
        expected = None

        def accept_output(update_type):
            print("Accepting {} for {}:\n\n{}".format(update_type, munged_id, s))
            with open(expected_file, 'w') as f:
                f.write(s)

        try:
            with open(expected_file) as f:
                expected = f.read()
        except IOError as e:
            if e.errno != errno.ENOENT:
                raise
            elif ACCEPT:
                return accept_output("output")
            else:
                raise RuntimeError(
                    ("I got this output for {}:\n\n{}\n\n"
                     "No expect file exists; to accept the current output, run:\n"
                     "python {} {} --accept").format(munged_id, s, __main__.__file__, munged_id))

        # a hack for JIT tests
        if IS_WINDOWS:
            expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
            s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)

        if ACCEPT:
            if expected != s:
                return accept_output("updated output")
        else:
            if hasattr(self, "assertMultiLineEqual"):
                # assertMultiLineEqual exists from Python 2.7 / 3.1 onwards.
                # NB: Python considers lhs "old" and rhs "new".
                self.assertMultiLineEqual(expected, s)
            else:
                self.assertEqual(s, expected)

    if sys.version_info < (3, 2):
        # assertRegexpMatches renamed to assertRegex in 3.2
        assertRegex = unittest.TestCase.assertRegexpMatches
        # assertRaisesRegexp renamed to assertRaisesRegex in 3.2
        assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
| |
| |
def download_file(url, binary=True):
    """Download ``url`` into the ``data`` directory beside this module.

    The local file name is the last component of the URL's path. If that
    file already exists it is returned without any network access.

    Args:
        url: address to fetch (fetched with a 15 second timeout).
        binary: when true the payload bytes are written unchanged; when
            false they are decoded as UTF-8 (on Python 3) and written in
            text mode.

    Returns:
        Path of the local copy.

    Raises:
        unittest.SkipTest: if the download fails or times out; a
            ``RuntimeWarning`` with the same message is emitted first.
    """
    import socket
    if sys.version_info < (3,):
        from urlparse import urlsplit
        import urllib2
        request = urllib2
        error = urllib2
    else:
        from urllib.parse import urlsplit
        from urllib import request, error

    filename = os.path.basename(urlsplit(url)[2])
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    path = os.path.join(data_dir, filename)

    if os.path.exists(path):
        return path
    try:
        # closing(): make sure the connection is released.
        with contextlib.closing(request.urlopen(url, timeout=15)) as response:
            data = response.read()
    except (error.URLError, socket.timeout):
        # A timeout during read() surfaces as socket.timeout, not URLError.
        msg = "could not download test file '{}'".format(url)
        warnings.warn(msg, RuntimeWarning)
        raise unittest.SkipTest(msg)

    # Create the data directory on first use (tolerating a concurrent mkdir).
    try:
        os.makedirs(data_dir)
    except OSError:
        if not os.path.isdir(data_dir):
            raise

    if binary:
        with open(path, 'wb') as f:
            f.write(data)
    else:
        # read() returns bytes; a text-mode file on Python 3 only accepts str.
        if not isinstance(data, str):
            data = data.decode('utf-8')
        with open(path, 'w') as f:
            f.write(data)
    return path