|  | """ | 
|  | The testing package contains testing-specific utilities. | 
|  | """ | 
|  |  | 
|  | import torch | 
|  | import random | 
|  |  | 
|  | __all__ = [ | 
|  | 'assert_allclose', 'make_non_contiguous', 'rand_like', 'randn_like' | 
|  | ] | 
|  |  | 
|  | rand_like = torch.rand_like | 
|  | randn_like = torch.randn_like | 
|  |  | 
|  |  | 
def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True):
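    """Assert that ``actual`` and ``expected`` are element-wise equal within
    a tolerance, converting non-tensor inputs to tensors first.

    If both ``rtol`` and ``atol`` are ``None``, per-dtype defaults from
    ``_default_tolerances`` are used; specifying only one of the two raises
    ``ValueError``. ``expected`` is broadcast to ``actual``'s shape when the
    shapes differ. On failure, the raised ``AssertionError`` reports the
    worst offending element and the overall mismatch rate.

    Example (illustrative)::

        >>> assert_allclose(torch.tensor([1.0, 2.0]), [1.0, 2.0])
    """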
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)
    if expected.shape != actual.shape:
        expected = expected.expand_as(actual)
    if rtol is None or atol is None:
        if rtol is not None or atol is not None:
            raise ValueError("rtol and atol must both be specified or both be unspecified")
        rtol, atol = _get_default_tolerance(actual, expected)

    close = torch.isclose(actual, expected, rtol, atol, equal_nan)
    if close.all():
        return

    # Find the worst offender
    error = (expected - actual).abs()
    expected_error = atol + rtol * expected.abs()
    delta = error - expected_error
    delta[close] = 0  # zero out already-close entries (incl. NaN/inf matches)
    _, index = delta.reshape(-1).max(0)

    # TODO: consider adding torch.unravel_index
    def _unravel_index(index, shape):
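        """Convert a flat index into a coordinate tuple for ``shape``
        (like ``numpy.unravel_index``)."""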
        res = []
        # Peel off one dimension at a time, innermost dimension first.
        for size in shape[::-1]:
            res.append(int(index % size))
            index = int(index // size)
        return tuple(res[::-1])

    index = _unravel_index(index.item(), actual.shape)

    # Count the offenders (as a Python int so the percentage below is exact)
    count = int((~close).long().sum())

    msg = ('Not within tolerance rtol={} atol={} at input{} ({} vs. {}) and {}'
           ' other locations ({:2.2f}%)')

    raise AssertionError(msg.format(
        rtol, atol, list(index), actual[index].item(), expected[index].item(),
        count - 1, 100.0 * count / actual.numel()))


def make_non_contiguous(tensor):
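    """Return a tensor with the same shape and values as ``tensor`` but with
    non-contiguous storage.

    Works by allocating a larger tensor, slicing a window of the original
    size out of it (which leaves gaps between the strides), and copying
    ``tensor`` into that window. Tensors with at most one element cannot be
    made non-contiguous and are returned as a (contiguous) clone.
    """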
    if tensor.numel() <= 1:  # can't make non-contiguous
        return tensor.clone()
    osize = list(tensor.size())

    # randomly inflate a few dimensions in osize
    for _ in range(2):
        dim = random.randint(0, len(osize) - 1)
        add = random.randint(4, 15)
        osize[dim] = osize[dim] + add

    # Narrowing only the 0-th dimension (which is all we could do for a
    # 1-dimensional tensor) does not produce a non-contiguous tensor, so
    # append a new right-most dimension and slice it off instead.
    input = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
    input = input.select(len(input.size()) - 1, random.randint(0, 1))
    # now extract a block of the correct size from 'input'
    for i in range(len(osize)):
        if input.size(i) != tensor.size(i):
            bounds = random.randint(1, input.size(i) - tensor.size(i))
            input = input.narrow(i, bounds, tensor.size(i))

    input.copy_(tensor)
    return input


def get_all_dtypes():
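    """Return the standard integer and floating-point dtypes."""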
    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
            torch.float16, torch.float32, torch.float64]


# 'dtype': (rtol, atol)
_default_tolerances = {
    'float64': (1e-5, 1e-8),  # NumPy default
    'float32': (1e-4, 1e-5),  # This may need to be changed
    'float16': (1e-3, 1e-3),  # This may need to be changed
}


def _get_default_tolerance(a, b=None):
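    """Return the default ``(rtol, atol)`` pair for ``a``'s dtype.

    If ``b`` is given, return the element-wise maximum of both tensors'
    defaults. Dtypes without an entry (e.g. the integer types) get
    ``(0, 0)``, i.e. exact comparison.
    """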
    if b is None:
        dtype = str(a.dtype).split('.')[-1]  # e.g. "float32"
        return _default_tolerances.get(dtype, (0, 0))
    a_tol = _get_default_tolerance(a)
    b_tol = _get_default_tolerance(b)
    return (max(a_tol[0], b_tol[0]), max(a_tol[1], b_tol[1]))
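

# A minimal, illustrative smoke test of the helpers above (a sketch, not part
# of the public API); it only runs when the module is executed directly.
if __name__ == '__main__':
    t = torch.randn(4, 4)

    # Within the default float32 tolerance (rtol=1e-4, atol=1e-5).
    assert_allclose(t, t + 1e-7)

    # make_non_contiguous preserves shape and values but breaks contiguity.
    nc = make_non_contiguous(t)
    assert not nc.is_contiguous()
    assert nc.eq(t).all()

    # A clear mismatch raises an AssertionError naming the worst offender.
    try:
        assert_allclose(torch.tensor([1.0]), torch.tensor([2.0]))
    except AssertionError as e:
        print('expected failure:', e)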