| import torch |
| import numpy as np |
| |
| import itertools |
| from itertools import product |
| import math |
| import random |
| import unittest |
| import warnings |
| import operator |
| |
| from torch._six import inf, nan |
| from torch.testing._internal.common_utils import ( |
| TestCase, iter_indices, TEST_WITH_ASAN, run_tests, |
| torch_to_numpy_dtype_dict, make_tensor) |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, |
| dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA, |
| skipCUDAIfRocm) |
| |
| # TODO: remove this |
| def _generate_input(shape, dtype, device, with_extremal): |
| if shape == (): |
| x = torch.tensor((), dtype=dtype, device=device) |
| else: |
| if dtype.is_floating_point or dtype.is_complex: |
| # work around torch.randn not being implemented for bfloat16 |
| if dtype == torch.bfloat16: |
| x = torch.randn(*shape, device=device) * random.randint(30, 100) |
| x = x.to(torch.bfloat16) |
| else: |
| x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) |
| x[torch.randn(*shape) > 0.5] = 0 |
| if with_extremal and dtype.is_floating_point: |
| # Use extremal values |
| x[torch.randn(*shape) > 0.5] = float('nan') |
| x[torch.randn(*shape) > 0.5] = float('inf') |
| x[torch.randn(*shape) > 0.5] = float('-inf') |
| elif with_extremal and dtype.is_complex: |
| x[torch.randn(*shape) > 0.5] = complex('nan') |
| x[torch.randn(*shape) > 0.5] = complex('inf') |
| x[torch.randn(*shape) > 0.5] = complex('-inf') |
| elif dtype == torch.bool: |
| x = torch.zeros(shape, dtype=dtype, device=device) |
| x[torch.randn(*shape) > 0.5] = True |
| else: |
| x = torch.randint(15, 100, shape, dtype=dtype, device=device) |
| |
| return x |
| |
| # TODO: refactor this out |
| # Converts half/bfloat16 dtype to float when device is cpu |
| def _convert_t(dtype, device): |
| if device == 'cpu' and dtype in {torch.half, torch.bfloat16}: |
| return torch.float |
| return dtype |
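| # Illustrative usage (a sketch; the dtypes/devices below are examples only, not asserted anywhere): |
| #   _convert_t(torch.half, 'cpu')    -> torch.float  (half is widened on CPU) |
| #   _convert_t(torch.half, 'cuda:0') -> torch.half   (unchanged off-CPU) |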
| |
| # TODO: revise the tests to use make_tensor in common_utils.py instead |
| # Returns a tensor of the requested shape, dtype, and device |
| # Requesting a half CPU tensor returns a float CPU tensor with |
| # values representable by a half. |
| # Initialization uses randint for non-float types and randn for float types. |
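| # Illustrative calls (shapes assumed for exposition only): |
| #   _make_tensor((3, 3), torch.half, 'cpu')  -> float32 tensor whose values are half-representable |
| #   _make_tensor((3, 3), torch.int64, 'cpu') -> int64 tensor with random values in [-5, 5) |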
| def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor: |
| # Returns a tensor filled with ones |
| if fill_ones: |
| return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device) |
| |
| # Returns a tensor with random integer values |
| if not (dtype.is_floating_point or dtype.is_complex): |
| t = torch.randint(0, 10, shape, device=device) |
| if dtype != torch.uint8: |
| t = t - 5 # generate negative values also |
| return t.to(_convert_t(dtype, device)) |
| |
| # Populates the CPU tensor with floats representable as half/bfloat16 |
| if dtype == torch.half and device == 'cpu': |
| return torch.randn(*shape, dtype=torch.float, device=device).half().float() |
| if dtype == torch.bfloat16 and device == 'cpu': |
| return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float() |
| |
| # Default: returns a tensor with random float values |
| return torch.randn(shape, dtype=dtype, device=device) |
| |
| # TODO: update to use opinfos consistently |
| class TestBinaryUfuncs(TestCase): |
| |
| def test_add_broadcast_empty(self, device): |
| # empty + empty |
| self.assertRaises(RuntimeError, lambda: torch.randn(5, 0, device=device) + torch.randn(0, 5, device=device)) |
| self.assertEqual(torch.randn(5, 0, device=device), torch.randn(0, device=device) + torch.randn(5, 0, device=device)) |
| self.assertEqual(torch.randn(5, 0, 0, device=device), torch.randn(0, device=device) + torch.randn(5, 0, 1, device=device)) |
| |
| # scalar + empty |
| self.assertEqual(torch.randn(5, 0, 6, device=device), torch.randn((), device=device) + torch.randn(5, 0, 6, device=device)) |
| |
| # non-empty, empty |
| self.assertEqual(torch.randn(0, device=device), torch.randn(0, device=device) + torch.randn(1, device=device)) |
| self.assertEqual(torch.randn(0, 7, 0, 6, 5, 0, 7, device=device), |
| torch.randn(0, 7, 0, 6, 5, 0, 1, device=device) + torch.randn(1, 1, 5, 1, 7, device=device)) |
| self.assertRaises(RuntimeError, lambda: torch.randn(7, 0, device=device) + torch.randn(2, 1, device=device)) |
| |
| def test_addcmul_scalars_as_floats(self, device): |
| # zero-dim variables that don't require grad should bind to scalar arguments |
| x = torch.tensor(2.) |
| y = torch.tensor(3., device=device) |
| # 3 + (3 * 3) * 2 |
| self.assertEqual(y.addcmul(y, y, value=x), 21) |
| |
| x = torch.tensor(2., requires_grad=True) |
| self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x)) |
| |
| # TODO: update to work on CUDA, too |
| @onlyCPU |
| def test_comparison_ops(self, device): |
| x = torch.randn(5, 5) |
| y = torch.randn(5, 5) |
| |
| eq = x == y |
| for idx in iter_indices(x): |
| self.assertEqual(x[idx] == y[idx], eq[idx] == 1) |
| |
| ne = x != y |
| for idx in iter_indices(x): |
| self.assertEqual(x[idx] != y[idx], ne[idx] == 1) |
| |
| lt = x < y |
| for idx in iter_indices(x): |
| self.assertEqual(x[idx] < y[idx], lt[idx] == 1) |
| |
| le = x <= y |
| for idx in iter_indices(x): |
| self.assertEqual(x[idx] <= y[idx], le[idx] == 1) |
| |
| gt = x > y |
| for idx in iter_indices(x): |
| self.assertEqual(x[idx] > y[idx], gt[idx] == 1) |
| |
| ge = x >= y |
| for idx in iter_indices(x): |
| self.assertEqual(x[idx] >= y[idx], ge[idx] == 1) |
| |
| # TODO: update to work on CUDA, too |
| @onlyCPU |
| def test_comparison_ops_must_take_bool_output(self, device): |
| for op in [torch.lt, torch.le, torch.gt, torch.ge, torch.eq, torch.ne, |
| torch.logical_and, torch.logical_or, torch.logical_xor]: |
| self.assertEqual(op(torch.tensor([True]), torch.tensor([False])).dtype, torch.bool) |
| |
| # TODO: update to work on CUDA, too |
| @onlyCPU |
| def test_inplace_comparison_ops_require_inputs_have_same_dtype(self, device): |
| with self.assertRaisesRegex(RuntimeError, 'Expected object of scalar type'): |
| for op in ['lt_', 'le_', 'gt_', 'ge_', 'eq_', 'ne_', 'logical_xor_', 'logical_and_', 'logical_or_']: |
| x = torch.tensor([1], dtype=torch.int) |
| y = torch.tensor([2], dtype=torch.long) |
| in_place_method = getattr(x, op) |
| in_place_method(y) |
| |
| # TODO: update to work on CUDA, too |
| @onlyCPU |
| def test_comparison_ops_check_for_scalar_overflow(self, device): |
| s = 1 << 20 |
| t = torch.tensor([1 << 5], dtype=torch.uint8) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t < s) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(s < t) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t <= s) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(s <= t) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t > s) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(s > t) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t >= s) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(s >= t) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t == s) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(s == t) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t != s) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(s != t) |
| |
| # TODO: update to work on CUDA, too |
| @onlyCPU |
| def test_comparison_ops_check_for_zerodim_tensor_overflow(self, device): |
| t1 = torch.tensor([1 << 5], dtype=torch.uint8) |
| t2 = torch.tensor([1 << 30], dtype=torch.int32) |
| ts1 = torch.tensor(1 << 20, dtype=torch.int32) |
| ts2 = torch.tensor(1 << 40, dtype=torch.int64) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t1 < ts1) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(ts2 < t2) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t1 <= ts1) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(ts2 <= t2) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t1 > ts1) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(ts2 > t2) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t1 >= ts1) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(ts2 >= t2) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t1 == ts1) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(ts2 == t2) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(t1 != ts1) |
| with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): |
| self.assertTrue(ts2 != t2) |
| |
| # TODO: update to work on CUDA, too |
| @onlyCPU |
| def test_bitwise_ops(self, device): |
| x = torch.randn(5, 5).gt(0) |
| y = torch.randn(5, 5).gt(0) |
| |
| and_result = x & y |
| for idx in iter_indices(x): |
| if and_result[idx]: |
| self.assertTrue(x[idx] and y[idx]) |
| else: |
| self.assertFalse(x[idx] and y[idx]) |
| |
| or_result = x | y |
| for idx in iter_indices(x): |
| if or_result[idx]: |
| self.assertTrue(x[idx] or y[idx]) |
| else: |
| self.assertFalse(x[idx] or y[idx]) |
| |
| xor_result = x ^ y |
| for idx in iter_indices(x): |
| if xor_result[idx]: |
| self.assertTrue(x[idx] ^ y[idx]) |
| else: |
| self.assertFalse(x[idx] ^ y[idx]) |
| |
| x_clone = x.clone() |
| x_clone &= y |
| self.assertEqual(x_clone, and_result) |
| |
| x_clone = x.clone() |
| x_clone |= y |
| self.assertEqual(x_clone, or_result) |
| |
| x_clone = x.clone() |
| x_clone ^= y |
| self.assertEqual(x_clone, xor_result) |
| |
| def test_inplace_division(self, device): |
| t = torch.rand(5, 5, device=device) |
| id_before = id(t) |
| t /= 2 |
| id_after = id(t) |
| self.assertEqual(id_before, id_after) |
| |
| # TODO: update to run on CUDA -- what is this test even testing? |
| @onlyCPU |
| def test_cast_binary_op(self, device): |
| # Scalar |
| a = torch.tensor(2) |
| b = torch.tensor(3) |
| a_copy = a.clone() |
| b_copy = b.clone() |
| |
| self.assertEqual(torch.tensor(6, dtype=torch.float), a.float() * b) |
| |
| self.assertEqualTypeString(a, a_copy) |
| self.assertEqualTypeString(b, b_copy) |
| |
| # Tests that trying to add, inplace, a CUDA tensor to a CPU tensor |
| # throws the correct error message |
| @onlyCUDA |
| def test_cross_device_inplace_error_msg(self, device): |
| a = torch.tensor(2.) |
| b = torch.tensor(2., device=device) |
| with self.assertRaisesRegex(RuntimeError, |
| "Expected all tensors to be on the same device"): |
| a += b |
| |
| # TODO: refactor this test into a more generic one, it's parked here currently |
| @onlyOnCPUAndCUDA |
| def test_out_resize_warning(self, device): |
| a = torch.tensor((1, 2, 3), device=device, dtype=torch.float32) |
| b = torch.tensor((4, 5, 6), device=device, dtype=torch.float32) |
| |
| unary_inputs = (a,) |
| binary_inputs = (a, b) |
| unary_ops = (torch.ceil, torch.exp) |
| binary_ops = (torch.add, torch.sub) |
| for op in (unary_ops + binary_ops): |
| with warnings.catch_warnings(record=True) as w: |
| warnings.simplefilter("always") |
| inputs = unary_inputs if op in unary_ops else binary_inputs |
| |
| # No warnings |
| op(*inputs, out=torch.empty(3, device=device)) |
| op(*inputs, out=torch.empty(0, device=device)) |
| self.assertEqual(len(w), 0) |
| |
| # Cases that throw warnings |
| op(*inputs, out=torch.empty(2, device=device)) |
| self.assertEqual(len(w), 1) |
| |
| # Verifies that the inplace dunders (like idiv) actually are in place |
| @onlyOnCPUAndCUDA |
| def test_inplace_dunders(self, device): |
| t = torch.randn((1,), device=device) |
| expected = t.data_ptr() |
| t += 1 |
| t -= 1 |
| t *= 1 |
| t /= 1 |
| t //= 1 |
| self.assertEqual(expected, t.data_ptr()) |
| |
| def check_internal_mem_overlap(self, inplace_op, num_inputs, |
| dtype, device, |
| expected_failure=False): |
| if isinstance(inplace_op, str): |
| inplace_op = getattr(torch.Tensor, inplace_op) |
| input = torch.randn(1, dtype=dtype, device=device).expand(3, 3) |
| inputs = [input] + [torch.randn_like(input) |
| for i in range(num_inputs - 1)] |
| if not expected_failure: |
| with self.assertRaisesRegex(RuntimeError, 'single memory location'): |
| inplace_op(*inputs) |
| else: |
| with self.assertRaises(AssertionError): |
| with self.assertRaisesRegex(RuntimeError, 'single memory location'): |
| inplace_op(*inputs) |
| |
| def unary_check_input_output_mem_overlap(self, data, sz, op, |
| expected_failure=False): |
| |
| def _test(op, output, input): |
| output_exp = torch.empty_like(output) |
| op(input, out=output_exp) |
| self.assertEqual(op(input, out=output), output_exp, msg=op.__name__) |
| |
| # output is identical to input: |
| _test(op, output=data[0:sz], input=data[0:sz]) |
| # output and input are independent: |
| _test(op, output=data[0:sz], input=data[sz:2 * sz]) |
| # output partially overlaps with input: |
| if not expected_failure: |
| with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): |
| _test(op, data[0:sz], data[1:sz + 1]) |
| else: |
| with self.assertRaises(AssertionError): |
| with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): |
| _test(op, data[0:sz], data[1:sz + 1]) |
| |
| def binary_check_input_output_mem_overlap(self, op, device, |
| expected_failure=False): |
| sz = 3 |
| data = torch.randn(2 * sz, device=device) |
| other = torch.randn(sz, device=device) |
| |
| self.unary_check_input_output_mem_overlap( |
| data, sz, lambda input, out: op(other, input, out=out), |
| expected_failure=expected_failure) |
| |
| self.unary_check_input_output_mem_overlap( |
| data, sz, lambda input, out: op(input, other, out=out), |
| expected_failure=expected_failure) |
| |
| @dtypes(torch.double) |
| def test_binary_op_mem_overlap(self, device, dtype): |
| ops = [ |
| ("add", True, True, 'cpu'), |
| ("add", True, True, 'cuda'), |
| ("mul", True, True, 'cpu'), |
| ("mul", True, True, 'cuda'), |
| ("sub", True, True, 'cpu'), |
| ("sub", True, True, 'cuda'), |
| ("div", True, True, 'cpu'), |
| ("div", True, True, 'cuda'), |
| ("pow", True, True, 'cpu'), |
| ("pow", True, True, 'cuda'), |
| ("fmod", True, True, 'cpu'), |
| ("fmod", True, True, 'cuda'), |
| ("atan2", True, True, 'cpu'), |
| ("atan2", True, True, 'cuda'), |
| ("hypot", True, True, 'cpu'), |
| ("hypot", True, True, 'cuda'), |
| ("igamma", True, True, 'cpu'), |
| ("igamma", True, True, 'cuda'), |
| ("igammac", True, True, 'cpu'), |
| ("igammac", True, True, 'cuda'), |
| ("nextafter", True, True, 'cpu'), |
| ("nextafter", True, True, 'cuda'), |
| ("le", True, True, 'cpu'), |
| ("le", True, True, 'cuda'), |
| ("lt", True, True, 'cpu'), |
| ("lt", True, True, 'cuda'), |
| ("ge", True, True, 'cpu'), |
| ("ge", True, True, 'cuda'), |
| ("gt", True, True, 'cpu'), |
| ("gt", True, True, 'cuda'), |
| ("eq", True, True, 'cpu'), |
| ("eq", True, True, 'cuda'), |
| ("ne", True, True, 'cpu'), |
| ("ne", True, True, 'cuda'), |
| ("logical_and", True, True, 'cpu'), |
| ("logical_and", True, True, 'cuda'), |
| ("logical_or", True, True, 'cpu'), |
| ("logical_or", True, True, 'cuda'), |
| ("logical_xor", True, True, 'cpu'), |
| ("logical_xor", True, True, 'cuda'), |
| ] |
| |
| for (fn, has_input_output_mem_overlap_check, |
| has_internal_mem_overlap_check, dev) in ops: |
| if dev != device: |
| continue |
| out_op = getattr(torch, fn) |
| inplace_op = getattr(torch.Tensor, fn + '_') |
| self.check_internal_mem_overlap( |
| inplace_op, 2, dtype, device, |
| expected_failure=not has_internal_mem_overlap_check) |
| |
| self.binary_check_input_output_mem_overlap(out_op, device, |
| expected_failure=not has_input_output_mem_overlap_check) |
| |
| def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol): |
| for num in exponents: |
| if isinstance(num, int) and num < 0 and not m1.is_floating_point() and not m1.is_complex(): |
| with self.assertRaisesRegex(RuntimeError, |
| r'Integers to negative integer powers are not allowed\.'): |
| torch.pow(m1[4], num) |
| else: |
| # base - tensor, exponent - number |
| # contiguous |
| res1 = torch.pow(m1[4], num) |
| res2 = res1.clone().zero_() |
| # `math.pow` has issues with complex exponentiation, so we fall back to the builtin `pow`. |
| for i in range(res2.size(0)): |
| res2[i] = pow_fn(m1[4][i], num) |
| rtol = 0 if atol is not None else None |
| self.assertEqual(res1, res2, atol=atol, rtol=rtol) |
| |
| # non-contiguous |
| res1 = torch.pow(m1[:, 4], num) |
| res2 = res1.clone().zero_() |
| for i in range(res2.size(0)): |
| res2[i] = pow_fn(m1[i, 4], num) |
| self.assertEqual(res1, res2, atol=atol, rtol=rtol) |
| |
| # scalar ** tensor to enforce correct handling of dtypes for __rpow__(). |
| expected_dtype = torch.result_type(num, m1) |
| res1 = num ** m1[4] |
| res2 = torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4] |
| self.assertEqual(res1, res2) |
| self.assertEqual(res1.dtype, expected_dtype) |
| |
| def test_pow(self, device): |
| # [res] torch.pow([res,] x) |
| |
| # pow has dedicated implementation for different exponents |
| for dtype in torch.testing.get_all_math_dtypes(device): |
| |
| # Skip torch.half for now: math.pow computes in double precision, so the reference |
| # result is much more accurate than the half-precision torch result. |
| if dtype == torch.half: |
| continue |
| |
| # deferring to https://github.com/pytorch/pytorch/pull/36793 |
| if dtype.is_complex: |
| continue |
| |
| m1 = torch.empty(0, dtype=dtype, device=device) |
| if m1.is_floating_point() or m1.is_complex(): |
| m1 = torch.rand(100, 100, dtype=dtype, device=device) + 0.5 |
| else: |
| # math.pow will overflow and throw exceptions for large integers |
| range_high = 4 if dtype in (torch.int8, torch.uint8) else 10 |
| m1 = torch.randint(1, range_high, (100, 100), dtype=dtype, device=device) |
| |
| exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3] |
| complex_exponents = [-2.5j, -1.0j, 0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j] |
| if m1.is_complex(): |
| self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4) |
| else: |
| self._do_pow_for_exponents(m1, exponents, math.pow, None) |
| self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4) |
| |
| # base - number, exponent - tensor |
| # contiguous |
| res1 = torch.pow(3, m1[4]) |
| res2 = res1.clone().zero_() |
| for i in range(res2.size(0)): |
| res2[i] = math.pow(3, m1[4, i]) |
| self.assertEqual(res1, res2) |
| |
| # non-contiguous |
| res1 = torch.pow(3, m1[:, 4]) |
| res2 = res1.clone().zero_() |
| for i in range(res2.size(0)): |
| res2[i] = math.pow(3, m1[i][4]) |
| self.assertEqual(res1, res2) |
| |
| # resize behavior for exp == 1 |
| out = torch.zeros(1, dtype=dtype, device=device) |
| torch.pow(m1, 1, out=out) |
| self.assertEqual(out, m1) |
| |
| # TODO: refactor all these tests using opinfos properly |
| def _test_pow(self, base, exponent, np_exponent=None): |
| if np_exponent is None: |
| np_exponent = exponent |
| |
| def to_np(value): |
| if isinstance(value, torch.Tensor): |
| return value.cpu().numpy() |
| return value |
| |
| try: |
| np_res = np.power(to_np(base), to_np(np_exponent)) |
| expected = torch.from_numpy(np_res) if isinstance(np_res, np.ndarray) else torch.tensor(np_res, dtype=base.dtype) |
| except ValueError as e: |
| err_msg = "Integers to negative integer powers are not allowed." |
| self.assertEqual(str(e), err_msg) |
| out = torch.empty_like(base) |
| test_cases = [ |
| lambda: base.pow(exponent), |
| lambda: base.pow_(exponent), |
| lambda: torch.pow(base, exponent), |
| lambda: torch.pow(base, exponent, out=out) |
| ] |
| for test_case in test_cases: |
| self.assertRaisesRegex(RuntimeError, err_msg, test_case) |
| else: |
| if isinstance(base, torch.Tensor): |
| actual = base.pow(exponent) |
| self.assertEqual(actual, expected.to(actual)) |
| actual = base.clone() |
| if torch.can_cast(torch.result_type(base, exponent), base.dtype): |
| actual2 = actual.pow_(exponent) |
| self.assertEqual(actual, expected) |
| self.assertEqual(actual2, expected) |
| else: |
| self.assertRaisesRegex(RuntimeError, "can't be cast", lambda: actual.pow_(exponent)) |
| |
| actual = torch.pow(base, exponent) |
| self.assertEqual(actual, expected.to(actual)) |
| |
| actual2 = torch.pow(base, exponent, out=actual) |
| self.assertEqual(actual, expected.to(actual)) |
| self.assertEqual(actual2, expected.to(actual)) |
| |
| def test_int_pow(self, device): |
| |
| def _test_integral_pow(dt, range, dev): |
| tensor = torch.tensor((3, 3), dtype=dt, device=dev).random_(*range) |
| exps = [0, 1, 2, 4, |
| torch.tensor((3, 3), dtype=dt, device=dev).random_(0, 5)] |
| for exp in exps: |
| self._test_pow(tensor, exp) |
| |
| _test_integral_pow(torch.int8, (-3, 4), device) |
| _test_integral_pow(torch.uint8, (0, 4), device) |
| _test_integral_pow(torch.int16, (-5, 5), device) |
| _test_integral_pow(torch.int64, (-10, 10), device) |
| _test_integral_pow(torch.int32, (-10, 10), device) |
| |
| def test_int_tensor_pow_neg_ints(self, device): |
| ints = [torch.iinfo(torch.int32).min, |
| -3, -2, -1, 0, 1, 2, 3, |
| torch.iinfo(torch.int32).max] |
| neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1] |
| tensor = torch.tensor(ints, dtype=torch.int32, device=device) |
| for pow in neg_ints: |
| self._test_pow(tensor, pow) |
| |
| def test_long_tensor_pow_floats(self, device): |
| ints = [0, 1, 23, 4567] |
| floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] |
| tensor = torch.tensor(ints, dtype=torch.int64, device=device) |
| for pow in floats: |
| self._test_pow(tensor, pow) |
| |
| def test_float_scalar_pow_float_tensor(self, device): |
| floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, |
| 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] |
| tensor = torch.tensor(floats, dtype=torch.float32, device=device) |
| for base in floats: |
| self._test_pow(base, tensor) |
| |
| @onlyCUDA |
| def test_cuda_tensor_pow_scalar_tensor(self, device): |
| cuda_tensors = [torch.randn((3, 3), device=device), torch.tensor(3.0, device=device)] |
| scalar_tensors = [torch.tensor(5.0, device='cpu'), torch.tensor(-3), torch.tensor(1)] |
| for base, exp in product(cuda_tensors, scalar_tensors): |
| self._test_pow(base, exp) |
| |
| @onlyCUDA |
| def test_cpu_tensor_pow_cuda_scalar_tensor(self, device): |
| cpu_tensors = [torch.randn((3, 3), device='cpu'), torch.tensor(3.0, device='cpu')] |
| cuda_tensors = [torch.tensor(5.0, device='cuda'), torch.tensor(-3, device='cuda')] |
| for base, exp in product(cpu_tensors, cuda_tensors): |
| regex = 'Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!' |
| self.assertRaisesRegex(RuntimeError, regex, torch.pow, base, exp) |
| |
| @onlyOnCPUAndCUDA |
| @dtypes(*(torch.testing.get_all_dtypes(include_bool=False, include_bfloat16=False))) |
| def test_complex_scalar_pow_tensor(self, device, dtype): |
| complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j, 1 + 0j] |
| exp = make_tensor((100,), device, dtype, low=-2, high=2) |
| exp[0] = exp[10] = exp[20] = 0 |
| for base in complexes: |
| self._test_pow(base, exp) |
| |
| def test_tensor_pow_tensor(self, dev): |
| def rotate(l, n): |
| return l[-n:] + l[:-n] |
| |
| def test_tensor_pow_tensor(values, torch_type, numpy_type): |
| vals_tensor = torch.tensor(values, dtype=torch_type, device=dev) |
| for i in range(len(values)): |
| pows = rotate(values, i) |
| pows_tensor = torch.tensor(pows, dtype=torch_type, device=dev) |
| self._test_pow(vals_tensor, pows_tensor) |
| |
| ints = [0, 1, 2, 3] |
| test_tensor_pow_tensor(ints, torch.int32, np.int32) |
| test_tensor_pow_tensor(ints, torch.int64, np.int64) |
| |
| floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, |
| 0.0, |
| 1 / 3, 1 / 2, 1.0, 2.0, 3.0] |
| test_tensor_pow_tensor(floats, torch.float32, np.float32) |
| test_tensor_pow_tensor(floats, torch.float64, np.float64) |
| |
| def test_logical_xor_with_nontrivial_alignment(self, device): |
| # test tensors that are not aligned to a multiple of 16 bytes |
| size = 128 |
| a = (torch.randn(size, device=device) > 0) |
| b = (torch.randn(size, device=device) > 0) |
| c = (torch.randn(size, device=device) > 0) |
| non_trivial_alignment = [1, 2, 4, 8, 15] |
| for i in non_trivial_alignment: |
| for j in non_trivial_alignment: |
| for k in non_trivial_alignment: |
| a_ = a[i: 100 + i] |
| b_ = b[j: 100 + j] |
| c_ = c[k: 100 + k] |
| torch.logical_xor(a_, b_, out=c_) |
| for x, y, z in zip(a_.tolist(), b_.tolist(), c_.tolist()): |
| self.assertEqual(x ^ y, z) |
| |
| @dtypes(torch.float) |
| def test_add_with_tail(self, device, dtype): |
| # test tensors whose size has a tail that is not a multiple |
| # of the GPU warp size |
| for tail_size in [1, 63, 67, 130]: |
| size = 4096 + tail_size |
| a = torch.randn(size, device=device, dtype=dtype) |
| b = torch.randn(size, device=device, dtype=dtype) |
| c = a + b |
| for x, y, z in zip(a.tolist(), b.tolist(), c.tolist()): |
| self.assertEqual(x + y, z) |
| |
| # Tests that CUDA tensors on different devices cannot be used in the same |
| # binary operation, and that CUDA "scalars" cannot be used in the same |
| # binary operation as non-scalar CPU tensors. |
| @deviceCountAtLeast(2) |
| @onlyCUDA |
| def test_cross_device_binary_ops(self, devices): |
| vals = (1., (2.,)) |
| cpu_tensor = torch.randn(2, 2) |
| for op in (operator.add, torch.add, |
| operator.sub, torch.sub, |
| operator.mul, torch.mul, |
| operator.truediv, torch.true_divide, |
| operator.floordiv, torch.floor_divide): |
| for a, b in product(vals, vals): |
| a = torch.tensor(a, device=devices[0]) |
| b = torch.tensor(b, device=devices[1]) |
| |
| with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): |
| op(a, b) |
| with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): |
| op(b, a) |
| with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): |
| op(a, cpu_tensor) |
| with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): |
| op(cpu_tensor, a) |
| |
| # This test ensures that a scalar Tensor can be safely used |
| # in a binary operation in conjunction with a Tensor on all |
| # available CUDA devices |
| @deviceCountAtLeast(2) |
| @onlyCUDA |
| def test_binary_op_scalar_device_unspecified(self, devices): |
| scalar_val = torch.tensor(1.) |
| for default_device in devices: |
| with torch.cuda.device(default_device): |
| for device in devices: |
| device_obj = torch.device(device) |
| x = torch.rand(3, device=device) |
| y0 = x * scalar_val |
| self.assertEqual(y0.device, device_obj) |
| y1 = scalar_val * x |
| self.assertEqual(y1.device, device_obj) |
| self.assertEqual(y0, y1) |
| |
| def test_div_and_floordiv_vs_python(self, device): |
| # Tests torch division ops which can handle both arguments being |
| # scalars. |
| # NOTE: torch.floor_divide currently truncates instead of flooring |
| # the quotient. See https://github.com/pytorch/pytorch/issues/43874. |
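| # For example (values chosen for illustration): with a, b = -7., 2. |
| #   a / b = -3.5, math.floor(a / b) = -4, but math.trunc(a / b) = -3, |
| #   which is the behavior the floor_divide reference below matches. |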
| def _scalar_helper(python_op, torch_op): |
| for a, b in product(range(-10, 10), range(-10, 10)): |
| for op in (lambda x: x * .5, lambda x: math.floor(x)): |
| a = op(a) |
| b = op(b) |
| |
| # Skips zero divisors |
| if b == 0: |
| continue |
| |
| expected = python_op(a, b) |
| |
| for op in (operator.truediv, torch.true_divide): |
| actual_scalar = torch_op(a, b) |
| |
| a_t = torch.tensor(a, device=device) |
| b_t = torch.tensor(b, device=device) |
| |
| actual_tensor = torch_op(a_t, b_t) |
| actual_first_tensor = torch_op(a_t, b) |
| actual_second_tensor = torch_op(a, b_t) |
| |
| self.assertEqual(actual_scalar, expected) |
| self.assertEqual(actual_tensor.item(), expected) |
| self.assertEqual(actual_first_tensor, actual_tensor) |
| self.assertEqual(actual_second_tensor, actual_tensor) |
| |
| _scalar_helper(operator.truediv, operator.truediv) |
| _scalar_helper(operator.truediv, torch.true_divide) |
| _scalar_helper(lambda a, b: math.trunc(a / b), operator.floordiv) |
| _scalar_helper(lambda a, b: math.trunc(a / b), torch.floor_divide) |
| |
| # NOTE: torch.floor_divide currently truncates instead of flooring. |
| # See https://github.com/pytorch/pytorch/issues/43874. |
| @onlyOnCPUAndCUDA |
| def test_div_and_floordiv_script_vs_python(self, device): |
| # Creates jitted functions of two tensors |
| def _wrapped_div(a, b): |
| return a / b |
| |
| def _wrapped_floordiv(a, b): |
| return a // b |
| |
| scripted_div = torch.jit.script(_wrapped_div) |
| scripted_floordiv = torch.jit.script(_wrapped_floordiv) |
| for a, b in product(range(-10, 10), range(-10, 10)): |
| for op in (lambda x: x * .5, lambda x: math.floor(x)): |
| a = op(a) |
| b = op(b) |
| |
| # Skips zero divisors |
| if b == 0: |
| continue |
| |
| expected_div = a / b |
| expected_truncdiv = math.trunc(a / b) |
| a_t = torch.tensor(a, device=device) |
| b_t = torch.tensor(b, device=device) |
| |
| self.assertEqual(scripted_div(a_t, b_t), expected_div) |
| self.assertEqual(scripted_floordiv(a_t, b_t), expected_truncdiv) |
| |
| # Creates jitted functions of one tensor |
| def _wrapped_div_scalar(a): |
| return a / 5 |
| |
| # NOTE: this will fail when given an integer input, since |
| # the JIT implements division as |
| # torch.reciprocal(a) * 5, and reciprocal is only |
| # implemented for float types. |
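| # For instance (illustrative, matching the checks further below): scripted_rdiv_scalar |
| # raises a RuntimeError for an integral tensor such as torch.tensor(2), while for |
| # torch.tensor(2.) it returns 2.5. |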
| def _wrapped_rdiv_scalar(a): |
| return 5 / a |
| |
| def _wrapped_floordiv_scalar(a): |
| return a // 5 |
| |
| # NOTE: this fails if the input is not an integer tensor |
| # See https://github.com/pytorch/pytorch/issues/45199 |
| def _wrapped_rfloordiv_scalar(a): |
| return 5 // a |
| |
| scripted_div_scalar = torch.jit.script(_wrapped_div_scalar) |
| scripted_rdiv_scalar = torch.jit.script(_wrapped_rdiv_scalar) |
| scripted_floordiv_scalar = torch.jit.script(_wrapped_floordiv_scalar) |
| scripted_rfloordiv_scalar = torch.jit.script(_wrapped_rfloordiv_scalar) |
| |
| for a in range(-10, 10): |
| for op in (lambda x: x * .5, lambda x: math.floor(x)): |
| a = op(a) |
| |
| a_t = torch.tensor(a, device=device) |
| |
| self.assertEqual(a / 5, scripted_div_scalar(a_t)) |
| self.assertEqual(math.trunc(a / 5), scripted_floordiv_scalar(a_t)) |
| |
| # Skips zero divisors |
| if a == 0: |
| continue |
| |
| if a_t.is_floating_point(): |
| self.assertEqual(5 / a, scripted_rdiv_scalar(a_t)) |
| else: |
| with self.assertRaises(RuntimeError): |
| scripted_rdiv_scalar(a_t) |
| |
| |
| # Handles Issue 45199 (see comment above) |
| if a_t.is_floating_point(): |
| with self.assertRaises(RuntimeError): |
| scripted_rfloordiv_scalar(a_t) |
| else: |
| self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t)) |
| |
| # NOTE: torch.floor_divide currently truncates instead of flooring |
| # the quotient. See https://github.com/pytorch/pytorch/issues/43874. |
| @onlyOnCPUAndCUDA |
| def test_idiv_and_ifloordiv_vs_python(self, device): |
| def _wrapped_idiv_tensor(a, b): |
| a /= b |
| return a |
| |
| def _wrapped_idiv_scalar(a): |
| a /= 5 |
| return a |
| |
| def _wrapped_true_divide__tensor(a, b): |
| a.true_divide_(b) |
| return a |
| |
| def _wrapped_true_divide__scalar(a): |
| a.true_divide_(5) |
| return a |
| |
| def _wrapped_floor_divide__tensor(a, b): |
| a.floor_divide_(b) |
| return a |
| |
| def _wrapped_floor_divide__scalar(a): |
| a.floor_divide_(5) |
| return a |
| |
| # The following functions are unsupported by the JIT |
| def _wrapped_ifloordiv_tensor(a, b): |
| a //= b |
| return a |
| |
| def _wrapped_ifloordiv_scalar(a): |
| a //= 5 |
| return a |
| |
| with self.assertRaises(torch.jit.frontend.NotSupportedError): |
| scripted_ifloordiv_tensor = torch.jit.script(_wrapped_ifloordiv_tensor) |
| |
| with self.assertRaises(torch.jit.frontend.NotSupportedError): |
| scripted_ifloordiv_scalar = torch.jit.script(_wrapped_ifloordiv_scalar) |
| |
| scripted_idiv_tensor = torch.jit.script(_wrapped_idiv_tensor) |
| scripted_idiv_scalar = torch.jit.script(_wrapped_idiv_scalar) |
| scripted_true_divide__tensor = torch.jit.script(_wrapped_true_divide__tensor) |
| scripted_true_divide__scalar = torch.jit.script(_wrapped_true_divide__scalar) |
| scripted_floor_divide__tensor = torch.jit.script(_wrapped_floor_divide__tensor) |
| scripted_floor_divide__scalar = torch.jit.script(_wrapped_floor_divide__scalar) |
| |
| for a, b in product(range(-10, 10), range(-10, 10)): |
| for op in (lambda x: x * .5, lambda x: math.floor(x)): |
| a = op(a) |
| b = op(b) |
| |
| # Skips zero divisors |
| if b == 0: |
| continue |
| |
| expected_idiv = a / b |
| expected_ifloordiv = a // b |
| expected_itruncdiv = math.trunc(a / b) |
| |
| a_t = torch.tensor(a, device=device) |
| b_t = torch.tensor(b, device=device) |
| |
| if a_t.is_floating_point(): |
| tmp0 = a_t.clone() |
| tmp0 /= b |
| |
| tmp1 = a_t.clone() |
| tmp1 /= b_t |
| |
| self.assertEqual(tmp0.item(), expected_idiv) |
| self.assertEqual(tmp1.item(), expected_idiv) |
| self.assertEqual(scripted_true_divide__tensor(a_t.clone(), b_t).item(), expected_idiv) |
| self.assertEqual(scripted_true_divide__scalar(a_t.clone()).item(), a / 5) |
| else: |
| tmp = a_t.clone() |
| with self.assertRaises(RuntimeError): |
| tmp /= b |
| with self.assertRaises(RuntimeError): |
| tmp /= b_t |
| with self.assertRaises(RuntimeError): |
| scripted_true_divide__tensor(tmp, b_t) |
| with self.assertRaises(RuntimeError): |
| scripted_true_divide__scalar(tmp) |
| |
| |
| if not a_t.is_floating_point() and b_t.is_floating_point(): |
| # Inplace modification fails because the result is a float tensor, which |
| # cannot be written back into the integral dividend tensor |
| with self.assertRaises(RuntimeError): |
| a_t.clone().floor_divide_(b_t) |
| with self.assertRaises(RuntimeError): |
| scripted_floor_divide__tensor(a_t.clone(), b_t) |
| tmp = a_t.clone() |
| with self.assertRaises(RuntimeError): |
| tmp //= b_t |
| else: |
| # Inplace modification is OK when both or neither tensor is |
| # a float tensor |
| self.assertEqual(a_t.clone().floor_divide_(b_t).item(), expected_itruncdiv) |
| self.assertEqual(scripted_floor_divide__tensor(a_t.clone(), b_t).item(), expected_itruncdiv) |
| tmp = a_t.clone() |
| tmp //= b_t |
| self.assertEqual(tmp.item(), expected_itruncdiv) |
| |
| self.assertEqual(scripted_floor_divide__scalar(a_t), math.trunc(a / 5)) |
| |
| # Tests binary op equivalence with Python builtin ops |
| # Also tests that reverse operations are equivalent to forward ops |
| # NOTE: division ops are tested separately above |
| def test_binary_ops_with_scalars(self, device): |
| for ops in ((operator.add, torch.add), |
| (operator.sub, torch.sub), |
| (operator.mul, torch.mul), |
| (operator.truediv, torch.div)): |
| python_op, torch_op = ops |
| |
| for a, b in product(range(-10, 10), range(-10, 10)): |
| for op in (lambda x: x * .5, lambda x: math.floor(x)): |
| a = op(a) |
| b = op(b) |
| |
| # Skips zero divisors |
| if b == 0 or a == 0: |
| continue |
| |
| a_tensor = torch.tensor(a, device=device) |
| b_tensor = torch.tensor(b, device=device) |
| a_tensor_cpu = a_tensor.cpu() |
| b_tensor_cpu = b_tensor.cpu() |
| vals = (a, b, a_tensor, b_tensor, a_tensor_cpu, b_tensor_cpu) |
| |
| for args in product(vals, vals): |
| first, second = args |
| |
| first_scalar = first if not isinstance(first, torch.Tensor) else first.item() |
| second_scalar = second if not isinstance(second, torch.Tensor) else second.item() |
| expected = python_op(first_scalar, second_scalar) |
| |
| self.assertEqual(expected, python_op(first, second)) |
| self.assertEqual(expected, torch_op(first, second)) |
| |
| @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False), torch.testing.get_all_dtypes(include_complex=False))) |
| def test_maximum_minimum_type_promotion(self, device, dtypes): |
| a = torch.tensor((0, 1), device=device, dtype=dtypes[0]) |
| b = torch.tensor((1, 0), device=device, dtype=dtypes[1]) |
| for op in (torch.maximum, torch.max, torch.minimum, torch.min): |
| result = op(a, b) |
| self.assertEqual(result.dtype, torch.result_type(a, b)) |
| |
| @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) |
| def test_maximum_minimum_int_and_bool(self, device, dtype): |
| ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) |
| rng = np.random.default_rng() |
| a_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) |
| b_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) |
| |
| for torch_op, alias, numpy_op in ops: |
| a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) |
| b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) |
| tensor_result = torch_op(a_tensor, b_tensor) |
| alias_result = alias(a_tensor, b_tensor) |
| |
| out = torch.empty_like(a_tensor) |
| torch_op(a_tensor, b_tensor, out=out) |
| |
| numpy_result = numpy_op(a_np, b_np) |
| |
| self.assertEqual(alias_result, tensor_result) |
| self.assertEqual(tensor_result, numpy_result) |
| self.assertEqual(out, numpy_result) |
| |
| @precisionOverride({torch.bfloat16: 1e-2}) |
| @dtypes(*(torch.testing.get_all_fp_dtypes())) |
| def test_maximum_minimum_float(self, device, dtype): |
| ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) |
| |
| if dtype == torch.bfloat16: |
| a_np = np.random.randn(10).astype(np.float64) |
| b_np = np.random.randn(10).astype(np.float64) |
| else: |
| a_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype]) |
| b_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype]) |
| |
| for torch_op, alias, numpy_op in ops: |
| numpy_result = numpy_op(a_np, b_np) |
| |
| a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) |
| b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) |
| tensor_result = torch_op(a_tensor, b_tensor) |
| alias_result = alias(a_tensor, b_tensor) |
| out = torch.empty_like(a_tensor) |
| torch_op(a_tensor, b_tensor, out=out) |
| |
| self.assertEqual(alias_result, tensor_result) |
| self.assertEqual(tensor_result, numpy_result) |
| self.assertEqual(out, numpy_result) |
| |
| @dtypes(*(torch.testing.get_all_fp_dtypes())) |
| def test_maximum_minimum_float_nan_and_inf(self, device, dtype): |
| # np.maximum and np.minimum compare input arrays element-wise: |
| # if one of the elements being compared is a NaN, that element is returned. |
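| # For example (illustrative): np.maximum(np.float32('nan'), np.float32(1.0)) is nan, |
| # and torch.maximum propagates NaN the same way, which is what this test relies on. |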
| ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum)) |
| a_vals = (float('inf'), -float('inf'), float('nan'), float('nan')) |
| b_vals = (-float('inf'), float('inf'), float('inf'), float('nan')) |
| if dtype == torch.bfloat16: |
| a_np = np.array(a_vals, dtype=np.float64) |
| b_np = np.array(b_vals, dtype=np.float64) |
| else: |
| a_np = np.array(a_vals, dtype=torch_to_numpy_dtype_dict[dtype]) |
| b_np = np.array(b_vals, dtype=torch_to_numpy_dtype_dict[dtype]) |
| |
| for torch_op, alias, numpy_op in ops: |
| numpy_result = numpy_op(a_np, b_np) |
| |
| a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) |
| b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) |
| tensor_result = torch_op(a_tensor, b_tensor) |
| alias_result = alias(a_tensor, b_tensor) |
| |
| out = torch.empty_like(a_tensor) |
| torch_op(a_tensor, b_tensor, out=out) |
| |
| self.assertEqual(alias_result, tensor_result) |
| if dtype == torch.bfloat16: |
| self.assertEqual(tensor_result, numpy_result, exact_dtype=False) |
| self.assertEqual(out, numpy_result, exact_dtype=False) |
| else: |
| self.assertEqual(tensor_result, numpy_result) |
| self.assertEqual(out, numpy_result) |
| |
| @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes())) |
| def test_maximum_minimum_complex(self, device, dtypes): |
| for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min): |
| with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'): |
| torch_op(torch.ones(1, device=device, dtype=dtypes[0]), |
| torch.ones(1, device=device, dtype=dtypes[1])) |
| |
| with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'): |
| torch_op(torch.ones(1, device=device, dtype=dtypes[1]), |
| torch.ones(1, device=device, dtype=dtypes[0])) |
| |
| @onlyCUDA |
| def test_maximum_minimum_cross_device(self, device): |
| a = torch.tensor((1, 2, -1)) |
| b = torch.tensor((3, 0, 4), device=device) |
| ops = (torch.maximum, torch.minimum) |
| |
| for torch_op in ops: |
| with self.assertRaisesRegex(RuntimeError, |
| "Expected all tensors to be on the same device"): |
| torch_op(a, b) |
| |
| with self.assertRaisesRegex(RuntimeError, |
| "Expected all tensors to be on the same device"): |
| torch_op(b, a) |
| |
| # test cuda tensor and cpu scalar |
| ops = ((torch.maximum, np.maximum), (torch.minimum, np.minimum)) |
| a_np = np.array(1) |
| b_np = np.array([3, 0, 4]) |
| |
| for torch_op, numpy_op in ops: |
| a_tensor = torch.from_numpy(a_np) |
| b_tensor = torch.from_numpy(b_np).to(device=device) |
| tensor_result_1 = torch_op(a_tensor, b_tensor) |
| numpy_result_1 = numpy_op(a_np, b_np) |
| tensor_result_2 = torch_op(b_tensor, a_tensor) |
| numpy_result_2 = numpy_op(b_np, a_np) |
| |
| self.assertEqual(tensor_result_1, numpy_result_1) |
| self.assertEqual(tensor_result_2, numpy_result_2) |
| |
| # TODO: tests like this should be generic |
| @dtypesIfCUDA(torch.half, torch.float, torch.double) |
| @dtypes(torch.float, torch.double) |
| def test_mul_intertype_scalar(self, device, dtype): |
| x = torch.tensor(1.5, dtype=dtype, device=device) |
| y = torch.tensor(3, dtype=torch.int32, device=device) |
| |
| self.assertEqual(x * y, 4.5) |
| self.assertEqual(y * x, 4.5) |
| |
| with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): |
| y *= x |
| x *= y |
| self.assertEqual(x, 4.5) |
| |
| @onlyCPU |
| @dtypes(*torch.testing.get_all_dtypes()) |
| def test_sub(self, device, dtype): |
| m1 = torch.tensor([2.34, 4.44], dtype=dtype, device=device) |
| m2 = torch.tensor([1.23, 2.33], dtype=dtype, device=device) |
| |
| if dtype == torch.bool: |
| self.assertRaises(RuntimeError, lambda: m1 - m2) |
| elif (dtype == torch.bfloat16 or dtype == torch.half): |
| # half and bfloat16 have lower precision, so use a looser tolerance for them |
| self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype), atol=0.01, rtol=0) |
| else: |
| self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype)) |
| |
| # TODO: what is this test testing? |
| @onlyCPU |
| @dtypes(torch.float) |
| def test_csub(self, device, dtype): |
| # with a tensor |
| a = torch.randn(100, 90, dtype=dtype, device=device) |
| b = a.clone().normal_() |
| |
| res_add = torch.add(a, b, alpha=-1) |
| res_csub = a.clone() |
| res_csub.sub_(b) |
| self.assertEqual(res_add, res_csub) |
| |
| # with a scalar |
| a = torch.randn(100, 100, dtype=dtype, device=device) |
| |
| scalar = 123.5 |
| res_add = torch.add(a, -scalar) |
| res_csub = a.clone() |
| res_csub.sub_(scalar) |
| self.assertEqual(res_add, res_csub) |
| |
| # TODO: reconcile with minimum/maximum tests |
| @dtypesIfCUDA(torch.half, torch.float, torch.double) |
| @dtypes(torch.float, torch.double) |
| def test_min_max_binary_op_nan(self, device, dtype): |
| a = torch.rand(1000, dtype=dtype, device=device) |
| b = torch.rand(1000, dtype=dtype, device=device) |
| |
| # 0:250: a -- nan, b -- not nan |
| a[:250] = float('nan') |
| # 250:500: a -- not nan, b -- nan |
| b[250:500] = float('nan') |
| # 500:750: a and b both nan |
| a[500:750] = float('nan') |
| b[500:750] = float('nan') |
| # 750:1000: neither nan |
| |
| ma = torch.max(a, b) |
| mi = torch.min(a, b) |
| |
| for i in range(750): |
| self.assertTrue(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i])) |
| self.assertTrue(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i])) |
| |
| for i in range(750, 1000): |
| self.assertFalse(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i])) |
| self.assertFalse(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i])) |
| |
| @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False), |
| torch.testing.get_all_dtypes(include_complex=False))) |
| def test_copysign(self, device, dtypes): |
| def _test_copysign_numpy(a, b): |
| torch_result = torch.copysign(a, b) |
| |
| if a.dtype == torch.bfloat16: |
| np_a = a.to(torch.float).cpu().numpy() |
| else: |
| np_a = a.cpu().numpy() |
| |
| if b.dtype == torch.bfloat16: |
| np_b = b.to(torch.float).cpu().numpy() |
| else: |
| np_b = b.cpu().numpy() |
| expected = torch.from_numpy(np.copysign(np_a, np_b)) |
| # Handle inconsistencies of type promotion between PyTorch and NumPy, |
| # which arise when either argument is a bool, integral, or bfloat16 type |
| types = [torch.bool, torch.bfloat16] + torch.testing.get_all_int_dtypes() |
| if a.dtype in types or b.dtype in types: |
| promoted_type = torch.promote_types(torch_result.dtype, expected.dtype) |
| torch_result = torch_result.to(promoted_type) |
| expected = expected.to(promoted_type) |
| |
| # Verify Value |
| self.assertEqual(torch_result, expected) |
| # Verify Sign |
| # Use a second copysign to verify the correctness of 0.0 and -0.0, since |
| # self.assertEqual treats 0.0 and -0.0 as equal. Using 1 as the magnitude |
| # compares only the signs of the torch and numpy results, elementwise. |
| # Special case: NaN conversion between FP32 and FP16 is not bitwise |
| # equivalent, so this assertion is skipped when either input is half. |
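| # For example (illustrative): torch.copysign(torch.tensor(1.0), torch.tensor(-0.0)) |
| # is -1.0 even though torch.tensor(-0.0) == torch.tensor(0.0) is True, so comparing |
| # the copysign-with-magnitude-1 results distinguishes -0.0 from 0.0. |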
| if a.dtype != torch.float16 and b.dtype != torch.float16: |
| self.assertEqual(torch.copysign(torch.tensor(1.0), torch_result), |
| torch.copysign(torch.tensor(1.0), expected)) |
| |
| # Compare Result with NumPy |
| # Type promotion |
| a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) |
| b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) |
| _test_copysign_numpy(a, b) |
| |
| # Broadcast |
| a = make_tensor((10, 1, 10), device=device, dtype=dtypes[0], low=-9, high=9) |
| b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) |
| _test_copysign_numpy(a, b) |
| |
| a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) |
| b = make_tensor((10, 1, 10), device=device, dtype=dtypes[1], low=-9, high=9) |
| _test_copysign_numpy(a, b) |
| |
| # 0.0/-0.0/inf/-inf/nan |
| cases = [0.0, -0.0, float('inf'), float('-inf'), float('nan')] |
| # torch.bfloat16 cannot represent '-nan' |
| # torch.half cannot represent '-nan' on CUDA |
| types = [torch.float32, torch.float64] |
| if device == 'cpu': |
| types.append(torch.float16) |
| if dtypes[0] in types: |
| b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) |
| for case in cases: |
| _test_copysign_numpy(torch.tensor([case], device=device, dtype=dtypes[0]), b) |
| |
| if dtypes[1] in torch.testing.get_all_fp_dtypes(): |
| a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) |
| for case in cases: |
| _test_copysign_numpy(a, torch.tensor([case], device=device, dtype=dtypes[1])) |
| |
| @dtypes(torch.bfloat16, torch.float) |
| def test_div(self, device, dtype): |
| for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_), |
| (torch.true_divide, torch.Tensor.true_divide, |
| torch.Tensor.true_divide_)): |
| m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype) |
| res1 = m1.clone() |
| inplace(res1[:, 3], 2) |
| res2 = m1.clone() |
| for i in range(m1.size(0)): |
| res2[i, 3] = res2[i, 3] / 2 |
| self.assertEqual(res1, res2) |
| |
| if dtype == torch.bfloat16: |
| a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) |
| a2 = torch.tensor([2., 2.], dtype=dtype, device=device) |
| self.assertEqual(op(a1, a2), |
| torch.tensor([2.1, 3.1], dtype=dtype, device=device), |
| atol=0.01, rtol=0) |
| self.assertEqual(method(a1, a2), op(a1, a2)) |
| |
| @dtypes(torch.bfloat16, torch.float) |
| def test_true_divide_out(self, device, dtype): |
| a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) |
| a2 = torch.tensor([2., 2.], dtype=dtype, device=device) |
| res = torch.empty_like(a1) |
| self.assertEqual(torch.true_divide(a1, a2, out=res), |
| torch.tensor([2.1, 3.1], dtype=dtype, device=device), |
| atol=0.01, rtol=0) |
| |
| @onlyCUDA |
| @dtypes(torch.half) |
| def test_divmul_scalar(self, device, dtype): |
| x = torch.tensor(100., device=device, dtype=dtype) |
| x_ref = x.float() |
| scale = 1e5 |
| res = x.div(scale) |
| expected = x_ref.div(scale) |
| self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) |
| x = torch.tensor(1e-5, device=device, dtype=dtype) |
| x_ref = x.float() |
| res = x.mul(scale) |
| expected = x_ref.mul(scale) |
| self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) |
| res = scale * x |
| self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) |
| |
| @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) |
| @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) |
| def test_floor_divide_tensor(self, device, dtype): |
| x = torch.randn(10, device=device).mul(30).to(dtype) |
| y = torch.arange(1, 11, dtype=dtype, device=device) |
| |
| z = x // y |
| z_alt = torch.trunc(x.double() / y.double()).to(dtype) |
| |
| self.assertEqual(z.dtype, x.dtype) |
| self.assertEqual(z, z_alt) |
| |
| @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) |
| @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) |
| def test_floor_divide_scalar(self, device, dtype): |
| x = torch.randn(100, device=device).mul(10).to(dtype) |
| |
| z = x // 3 |
| z_alt = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=x.dtype, device=device) |
| |
| self.assertEqual(z.dtype, x.dtype) |
| self.assertEqual(z, z_alt) |
| |
| # Note: this test fails on XLA |
| @onlyOnCPUAndCUDA |
| @dtypes(torch.float, torch.long) |
| def test_floor_divide_out(self, device, dtype): |
| x = torch.randn(10, device=device).mul(10).to(dtype) |
| y = torch.arange(1, 11, dtype=dtype, device=device) |
| o = torch.empty(10, dtype=dtype, device=device) |
| |
| torch.floor_divide(x, y, out=o) |
| self.assertEqual(o, x // y) |
| |
| # Tests scalar with out |
| torch.floor_divide(x, 2, out=o) |
| self.assertEqual(o, x // 2) |
| |
| if dtype == torch.int: |
| o = torch.empty(10, dtype=torch.float, device=device) |
| torch.floor_divide(x, y, out=o) |
| self.assertEqual(o, torch.floor_divide(x.float(), y.float())) |
| |
| @onlyCPU |
| @dtypes(*torch.testing.get_all_math_dtypes('cpu')) |
| def test_rdiv(self, device, dtype): |
| if dtype is torch.float16: |
| return |
| elif dtype.is_complex: |
| x = torch.rand(100, dtype=dtype, device=device).add(1).mul(4) |
| else: |
| x = torch.rand(100, device=device).add(1).mul(4).to(dtype) |
| y = 30 / x |
| z = torch.tensor([30 / v.item() for v in x], device=device) |
| self.assertEqual(y, z, exact_dtype=False) |
| |
| @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=False)) |
| def test_fmod_by_zero_float(self, device, dtype): |
| # check that fmod of a floating-point tensor by zero is NaN on both CPU and GPU |
| x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) |
| zero = torch.zeros_like(x) |
| |
| self.assertTrue(torch.all(x.fmod(0.0).isnan())) |
| self.assertTrue(torch.all(x.fmod(zero).isnan())) |
| # out |
| out = torch.empty(0, device=device, dtype=dtype) |
| torch.fmod(x, zero, out=out) |
| self.assertEqual(out.size(), torch.Size([10, 10])) |
| self.assertTrue(torch.all(out.isnan())) |
| # in-place |
| x.fmod_(zero) |
| self.assertTrue(torch.all(x.isnan())) |
| |
| @onlyOnCPUAndCUDA # Check Issue https://github.com/pytorch/pytorch/issues/48130 |
| @skipCUDAIfRocm # Error happens on both ROCM and XLA |
| @dtypes(*torch.testing.get_all_int_dtypes()) |
| def test_fmod_by_zero_integral(self, device, dtype): |
| # check integral tensor fmod to zero |
| x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) |
| zero = torch.zeros_like(x) |
| # out |
| out = torch.empty(0, device=device, dtype=dtype) |
| # In-place |
| x_ = x.clone() |
| # RuntimeError on CPU |
| if device == 'cpu': |
| with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): |
| x.fmod(zero) |
| with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): |
| torch.fmod(x, zero, out=out) |
| with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): |
| x.fmod_(zero) |
| # Different value for different dtype on GPU |
| else: |
| if dtype == torch.int64: |
| self.assertEqual(x.fmod(zero) == 4294967295, x >= 0) |
| self.assertEqual(x.fmod(zero) == -1, x < 0) |
| # out |
| torch.fmod(x, zero, out=out) |
| self.assertEqual(out == 4294967295, x >= 0) |
| self.assertEqual(out == -1, x < 0) |
| self.assertEqual(out.size(), torch.Size([10, 10])) |
| # in-place |
| x_.fmod_(zero) |
| self.assertEqual(x_ == 4294967295, x >= 0) |
| self.assertEqual(x_ == -1, x < 0) |
| else: |
| value = 255 if dtype == torch.uint8 else -1 |
| self.assertTrue(torch.all(x.fmod(zero) == value)) |
| # out |
| torch.fmod(x, zero, out=out) |
| self.assertTrue(torch.all(out == value)) |
| self.assertEqual(out.size(), torch.Size([10, 10])) |
| # in-place |
| x_.fmod_(zero) |
| self.assertTrue(torch.all(x_ == value)) |
| |
| @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False)) |
| def test_fmod(self, device, dtype): |
| # Use numpy as reference |
| def _reference_implementation(x, mod): |
| np_x = x.cpu().numpy() |
| np_mod = 0 |
| # No type promotion |
| # Issue #47779: https://github.com/pytorch/pytorch/issues/47779 |
| if torch.is_tensor(mod): |
| np_mod = mod.cpu().numpy() |
| else: |
| np_mod = mod |
| # On non-XLA platforms the scalar modulus must be cast to int |
| if dtype in torch.testing.get_all_int_dtypes() and self.device_type in ['cpu', 'cuda']: |
| np_mod = int(np_mod) |
| exp = np.fmod(np_x, np_mod) |
| exp = torch.from_numpy(exp) |
| |
| res = torch.fmod(x, mod) |
| res = res.to(exp.dtype) |
| self.assertEqual(res, exp) |
| # out |
| out = torch.empty(0, device=device, dtype=dtype) |
| torch.fmod(x, mod, out=out) |
| self.assertEqual(out.to(exp.dtype), exp) |
| self.assertEqual(out.size(), torch.Size([10, 10])) |
| # in-place |
| x.fmod_(mod) |
| self.assertEqual(x.to(exp.dtype), exp) |
| |
| x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) |
| # Exclude 0 |
| # mod with same dtype as x |
| mod = make_tensor((10, 10), device=device, dtype=dtype, low=1, high=9) |
| # mod with floating-point dtype |
| mod_float = make_tensor((10, 10), device=device, |
| dtype=torch.float if dtype in torch.testing.get_all_int_dtypes() else dtype, |
| low=1, high=9) |
| # non-contiguous |
| x_nc = x.t() |
| mod_nc = mod.t() |
| |
| # Mods: Integer, Float, Tensor, Non-contiguous Tensor |
| mods = [3, 2.3, mod, mod_nc] |
| for m in mods: |
| _reference_implementation(x, m) |
| _reference_implementation(x_nc, m) |
| |
| # Integral Tensor fmod to floating-point Tensor |
| # Cannot cast the floating-point result back to the original integral Tensor without type promotion |
| if dtype in torch.testing.get_all_int_dtypes(): |
| res = torch.fmod(x, mod_float) |
| exp = np.fmod(x.cpu().numpy(), mod_float.cpu().numpy()) |
| exp = torch.from_numpy(exp) |
| res = res.to(exp.dtype) |
| self.assertEqual(res, exp) |
| with self.assertRaisesRegex(RuntimeError, "result type (Half|Float|Double) " |
| "can't be cast to the desired " |
| "output type (Byte|Char|Short|Int|Long)"): |
| out = torch.empty(0, device=device, dtype=dtype) |
| torch.fmod(x, mod_float, out=out) |
| with self.assertRaisesRegex(RuntimeError, "result type (Half|Float|Double) " |
| "can't be cast to the desired " |
| "output type (Byte|Char|Short|Int|Long)"): |
| x.fmod_(mod_float) |
| else: |
| _reference_implementation(x, mod_float) |
| |
| @onlyCPU |
| @dtypes(torch.float, torch.long) |
| def test_remainder(self, device, dtype): |
| for use_item in [True, False]: |
| if dtype == torch.float: |
| m1 = torch.Tensor(10, 10).uniform_(-10., 10.).to(dtype=dtype, device=device) |
| res1 = m1.clone() |
| res2 = m1.clone() |
| qs = torch.arange(-5.1, 4.1, dtype=dtype, device=device) |
| # Check the case where the divisor is a simple float |
| for col_idx, q in enumerate(qs): |
| # Reference |
| for i in range(m1.size(0)): |
| res2[i, col_idx] = res2[i, col_idx] % q |
| # To test |
| res1[:, col_idx].remainder_(q if not use_item else q.item()) |
| self.assertEqual(res1, res2) |
| # Check the case where the divisor is a tensor |
| res1 = m1.clone() |
| res1.remainder_(qs.unsqueeze(0).expand_as(res1)) |
| self.assertEqual(res1, res2) |
| elif dtype == torch.long: |
| long_m1 = torch.LongTensor(10, 10).random_(-10, 10) |
| long_res1 = long_m1.clone() |
| long_res2 = long_m1.clone() |
| long_qs = torch.arange(-5, 5, dtype=dtype, device=device) |
| long_qs[5] = 5 # Can't handle the divisor=0 case |
| for col_idx, long_q in enumerate(long_qs): |
| # Reference |
| for i in range(long_m1.size(0)): |
| long_res2[i, col_idx] = long_res2[i, col_idx] % long_q |
| # To test |
| long_res1[:, col_idx].remainder_(long_q if not use_item else long_q.item()) |
| self.assertEqual(long_res1, long_res2) |
| # Divisor is a tensor case |
| long_res1 = long_m1.clone() |
                long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1))
                self.assertEqual(long_res1, long_res2)
| |
| @dtypes(torch.float, torch.double) |
| def test_remainder_fmod_large_dividend(self, device, dtype): |
| alarge = 1e9 |
| pi = 3.14159265358979 |
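        # remainder uses floored division (a - b * floor(a / b)) and nonzero results take the
        # divisor's sign, while fmod uses truncated division (a - b * trunc(a / b)) and takes
        # the dividend's sign; when the signs of a and b differ, the two differ by one divisor.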
| for avalue in [alarge, -alarge]: |
| for bvalue in [pi, -pi]: |
| a = torch.tensor([avalue], dtype=dtype, device=device) |
| b = torch.tensor([bvalue], dtype=dtype, device=device) |
| c = torch.remainder(a, b) |
| d = torch.fmod(a, b) |
| self.assertTrue((b[0] > 0) == (c[0] > 0)) # remainder has same sign as divisor |
| self.assertTrue((a[0] > 0) == (d[0] > 0)) # fmod has same sign as dividend |
| self.assertTrue(abs(c[0]) < abs(b[0])) # remainder is within range of divisor |
| self.assertTrue(abs(d[0]) < abs(b[0])) # fmod is within range of divisor |
| if ((a[0] > 0) == (b[0] > 0)): |
| self.assertTrue(c[0] == d[0]) # remainder is same as fmod |
| else: |
| self.assertTrue(abs(c[0] - d[0]) == abs(b[0])) # differ by one divisor |
| |
| @dtypesIfCPU(torch.bfloat16, torch.float32, torch.float64) |
| @dtypes(torch.float32, torch.float64) |
| def test_hypot(self, device, dtype): |
| inputs = [ |
| (torch.randn(10, device=device).to(dtype), torch.randn(10, device=device).to(dtype)), |
| (torch.randn((3, 3, 3), device=device).to(dtype), torch.randn((3, 3, 3), device=device).to(dtype)), |
| (torch.randn((10, 1), device=device).to(dtype), torch.randn((10, 1), device=device).to(dtype).transpose(0, 1)), |
| (torch.randint(100, (10, ), device=device, dtype=torch.long), torch.randn(10, device=device).to(dtype)) |
| ] |
| for input in inputs: |
| actual = torch.hypot(input[0], input[1]) |
| if dtype == torch.bfloat16: |
| expected = torch.sqrt(input[0] * input[0] + input[1] * input[1]) |
| else: |
| expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy()) |
| self.assertEqual(actual, expected) |
| |
| @dtypes(torch.int64, torch.float64) |
| def test_remainder_edge_cases(self, device, dtype): |
| # Test variations of negative values used as input |
| a = torch.tensor([6, -6, -6, 6, 27, -27, -27, 27], dtype=dtype, device=device) |
| b = torch.tensor([-3, 3, -3, 3, -5, 5, -5, 5], dtype=dtype, device=device) |
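        # remainder follows Python's floored-modulo convention, so nonzero results take the
        # divisor's sign, e.g. 27 % -5 == -3 and -27 % 5 == 3.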
| r = a.remainder(b) |
| r_expected = torch.tensor([0, 0, 0, 0, -3, 3, -2, 2], dtype=dtype, device=device) |
| self.assertEqual(r, r_expected) |
| |
| if dtype == torch.float64: |
| # Test cases where result should be nan |
| a = torch.tensor([-34, 0, 34], dtype=dtype, device=device) |
| b = torch.zeros(3, dtype=dtype, device=device) |
| self.assertTrue(torch.isnan(a.remainder(b)).all()) |
| |
            # Use a fairly large float tensor on CPU so the vectorized (Vec256)
            # implementation is exercised
| if device == 'cpu': |
| a = torch.tensor([6, -6, -6, 6, 27, -27, -27, 27] * 10000, dtype=dtype, device=device) |
| b = torch.tensor([-3, 3, -3, 3, -5, 5, -5, 5] * 10000, dtype=dtype, device=device) |
| r = a.remainder(b) |
| r_expected = torch.tensor([0, 0, 0, 0, -3, 3, -2, 2] * 10000, dtype=dtype, device=device) |
| self.assertEqual(r, r_expected) |
| # Test nan cases |
| |
| a = torch.tensor([-34, 0, 34] * 20000, dtype=dtype, device=device) |
| b = torch.zeros(3 * 20000, dtype=dtype, device=device) |
| self.assertTrue(torch.isnan(a.remainder(b)).all()) |
| |
| elif dtype == torch.int64: |
| if device == 'cpu': |
| # Test int divide by zero causes an exception |
| a = torch.ones(1000, dtype=dtype, device=device) |
| b = torch.ones(1000, dtype=dtype, device=device) |
| b[500] = 0 |
| self.assertRaises(RuntimeError, lambda: a.remainder(b)) |
| |
        # The result dtype follows the tensor, even when the Python scalar divisor is of a different numeric kind
| a = torch.ones(1, dtype=dtype, device=device) |
| b = 1.0 if dtype == torch.int64 else 1 |
| r = a.remainder(b) |
| self.assertEqual(r.dtype, a.dtype) |
| |
| @onlyOnCPUAndCUDA |
| @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) |
| def test_gcd(self, device, dtype): |
| # Tests gcd(0, 0), gcd(0, a) cases |
| t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device) |
| t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device) |
| actual = torch.gcd(t1, t2) |
| expected = np.gcd([0, 10, 0], [0, 0, 10]) |
| self.assertEqual(actual, expected) |
| |
| if dtype == torch.uint8: |
| # Test unsigned integers with potential sign issues (i.e., uint8 with value >= 128) |
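            # (190 and 210 would read as negative values under a signed 8-bit interpretation)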
| a = torch.tensor([190, 210], device=device, dtype=dtype) |
| b = torch.tensor([190, 220], device=device, dtype=dtype) |
| actual = torch.gcd(a, b) |
| expected = torch.tensor([190, 10], device=device, dtype=dtype) |
| else: |
| # Compares with NumPy |
| a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) |
| b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) |
| actual = torch.gcd(a, b) |
| expected = np.gcd(a.cpu().numpy(), b.cpu().numpy()) |
| self.assertEqual(actual, expected) |
| |
| @onlyOnCPUAndCUDA |
| @dtypes(torch.int16, torch.int32, torch.int64) |
| def test_lcm(self, device, dtype): |
| # Tests lcm(0, 0), lcm(0, a) cases |
| t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device) |
| t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device) |
| actual = torch.lcm(t1, t2) |
| expected = np.lcm([0, 10, 0], [0, 0, 10]) |
| self.assertEqual(actual, expected) |
| |
| # Compares with NumPy |
| a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) |
| b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) |
| actual = torch.lcm(a, b) |
| expected = np.lcm(a.cpu().numpy(), b.cpu().numpy()) |
| self.assertEqual(actual, expected) |
| |
| @onlyOnCPUAndCUDA |
| @dtypes(torch.float32, torch.float64) |
| def test_nextafter(self, device, dtype): |
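        # nextafter(a, b) returns the next representable floating-point value after a in the
        # direction of b, matching np.nextafter.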
| # Test special cases |
| t1 = torch.tensor([0, 0, 10], device=device, dtype=dtype) |
| t2 = torch.tensor([inf, -inf, 10], device=device, dtype=dtype) |
| actual = torch.nextafter(t1, t2) |
| expected = np.nextafter(t1.cpu().numpy(), t2.cpu().numpy()) |
| self.assertEqual(actual, expected, atol=0, rtol=0) |
| |
| actual = torch.nextafter(t2, t1) |
| expected = np.nextafter(t2.cpu().numpy(), t1.cpu().numpy()) |
| self.assertEqual(actual, expected, atol=0, rtol=0) |
| |
| t1 = torch.tensor([0, nan], device=device, dtype=dtype) |
| t2 = torch.tensor([nan, 0], device=device, dtype=dtype) |
| self.assertTrue(torch.nextafter(t1, t2).isnan().all()) |
| |
| a = torch.randn(100, device=device, dtype=dtype) |
| b = torch.randn(100, device=device, dtype=dtype) |
| actual = torch.nextafter(a, b) |
| expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy()) |
| self.assertEqual(actual, expected, atol=0, rtol=0) |
| |
| def _test_cop(self, torchfn, mathfn, dtype, device): |
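        # Compares a torch binary op (torchfn) against a scalar Python reference (mathfn)
        # applied elementwise, pairing a 2-d slice with a flat companion slice, for both
        # contiguous and non-contiguous layouts.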
| def reference_implementation(res2): |
| for i, j in iter_indices(sm1): |
| idx1d = i * sm1.size(0) + j |
| res2[i, j] = mathfn(sm1[i, j], sm2[idx1d]) |
| return res2 |
| |
| # contiguous |
| m1 = torch.randn(10, 10, 10, dtype=dtype, device=device) |
| m2 = torch.randn(10, 10 * 10, dtype=dtype, device=device) |
| sm1 = m1[4] |
| sm2 = m2[4] |
| |
| res1 = torchfn(sm1, sm2.view(10, 10)) |
| res2 = reference_implementation(res1.clone()) |
| self.assertEqual(res1, res2) |
| |
| # non-contiguous |
| m1 = torch.randn(10, 10, 10, dtype=dtype, device=device) |
| m2 = torch.randn(10 * 10, 10 * 10, dtype=dtype, device=device) |
| sm1 = m1[:, 4] |
| sm2 = m2[:, 4] |
| # view as sm1.size() |
| sm2.set_(sm2.storage(), sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0])) |
| res1 = torchfn(sm1, sm2) |
| # reference_implementation assumes 1-d sm2 |
| sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()) |
| res2 = reference_implementation(res1.clone()) |
| self.assertEqual(res1, res2) |
| |
| @onlyCPU |
| @dtypes(torch.float) |
| def test_cdiv(self, device, dtype): |
| self._test_cop(torch.div, lambda x, y: x / y, dtype, device) |
| |
| @onlyCPU |
| @dtypes(torch.float) |
| def test_cremainder(self, device, dtype): |
| self._test_cop(torch.remainder, lambda x, y: x % y, dtype, device) |
| |
| @onlyCPU |
| @dtypes(torch.float) |
| def test_cmul(self, device, dtype): |
| self._test_cop(torch.mul, lambda x, y: x * y, dtype, device) |
| |
| @onlyCPU |
| @dtypes(torch.float) |
| def test_cpow(self, device, dtype): |
| self._test_cop(torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device) |
| |
| @onlyCPU |
| @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) |
| def test_floor_divide_zero(self, device, dtype): |
| a = torch.tensor([0, 1], dtype=dtype, device=device) |
| b = torch.tensor([0, 1], dtype=dtype, device=device) |
| with self.assertRaisesRegex(RuntimeError, 'ZeroDivisionError'): |
| a // b |
| |
| @unittest.skipIf(TEST_WITH_ASAN, "Integer overflows are not allowed under ASAN") |
| @dtypes(*torch.testing.get_all_dtypes()) |
| def test_muldiv_scalar(self, device, dtype): |
| x = make_tensor((10, 3), device, dtype, low=None, high=None) |
| s = make_tensor((1,), 'cpu', dtype, low=None, high=None).item() |
| y = torch.full_like(x, s) |
| self.assertEqual(x * s, x * y) |
| self.assertEqual(s * x, y * x) |
| self.assertEqual(x / s, x / y) |
| self.assertEqual(s / x, y / x) |
| |
| @dtypes(*tuple(itertools.combinations_with_replacement(torch.testing.get_all_dtypes(), 2))) |
| def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes): |
| # issue #42660 |
| # testing all combinations of broadcasting and type promotion |
| # with a range of dtypes and input shapes, and with extremal values |
| def compare_with_numpy_bin_op(torch_fn, np_fn, x, y, out=None): |
| # working around the fact that numpy doesn't support bfloat16 |
| # by letting numpy treat them as float32's |
| x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32) |
| y_np = y.cpu().numpy() if y.dtype != torch.bfloat16 else y.to(torch.float32).cpu().numpy() |
| self.compare_with_numpy(lambda inp: torch_fn(inp, y, out=out) if out else torch_fn(inp, y), |
| lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np), |
| x_np) |
| |
| complex_op_denylist = [torch.lt, torch.le, torch.gt, torch.ge] # complex not supported |
| input_sizes = [ |
| (1,), |
| (10,), |
| (10, 1), |
| (1, 10), |
| (4, 10), |
| (64, 10), |
| (12, 3)] |
| op_pairs = [(torch.lt, np.less), |
| (torch.le, np.less_equal), |
| (torch.gt, np.greater), |
| (torch.ge, np.greater_equal), |
| (torch.eq, np.equal), |
| (torch.ne, np.not_equal), |
| (torch.logical_and, np.logical_and), |
| (torch.logical_or, np.logical_or), |
| (torch.logical_xor, np.logical_xor)] |
| |
| for size1 in input_sizes: |
| size2 = (2,) + size1 # perform broadcasting |
| for with_extremal in [False, True]: |
| a = _generate_input(size1, dtypes[0], device, with_extremal) |
| b = _generate_input(size2, dtypes[1], device, with_extremal) |
| for torch_op, numpy_op in op_pairs: |
| if (dtypes[0].is_complex or dtypes[1].is_complex) and torch_op in complex_op_denylist: |
| continue |
| # functional version of op |
| compare_with_numpy_bin_op(torch_op, numpy_op, a, b) |
| |
| # functional comparison ops always return bool tensors |
| self.assertEqual(torch_op(a, b).dtype, torch.bool) |
| |
| # out version of op |
| out = torch.zeros(1, dtype=torch.complex128) # all casts to complex128 are safe |
| compare_with_numpy_bin_op(torch_op, numpy_op, a, b, out=out) |
| |
| @onlyOnCPUAndCUDA |
| @dtypes(torch.int8, torch.int16, torch.int32, torch.int64) |
| def test_signed_shift(self, device, dtype): |
| "Ensure that signed integer bit shifting works as expected." |
| a = torch.tensor([-10, 10], device=device, dtype=dtype) # [11...1110110, 1010] |
| expected_l = torch.tensor([-40, 40], device=device, dtype=dtype) # [11...11011000, 101000] |
| self.assertEqual(a << 2, expected_l) |
| self.compare_with_numpy(lambda x: x << 2, lambda x: np.left_shift(x, 2), a) |
| expected_r = torch.tensor([-5, 5], device=device, dtype=dtype) # [1111...111011, 101] |
| self.assertEqual(a >> 1, expected_r) |
| self.compare_with_numpy(lambda x: x >> 1, lambda x: np.right_shift(x, 1), a) |
| |
| def test_bitwise_and(self, device): |
| for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): |
| a = torch.tensor([1, -2, 3], dtype=dtype, device=device) |
| b = torch.tensor([2, 1, 3], dtype=dtype, device=device) |
| expected_res = torch.tensor([0, 0, 3], dtype=dtype, device=device) |
| b_scalar = 2 |
| expected_res_scalar = torch.tensor([0, 2, 2], dtype=dtype, device=device) |
| |
| # standard version |
| self.assertEqual(torch.bitwise_and(a, b), expected_res) |
| self.assertEqual(torch.bitwise_and(a, b_scalar), expected_res_scalar) |
| |
| # out |
| c = torch.empty(0, dtype=dtype, device=device) |
| torch.bitwise_and(a, b, out=c) |
| self.assertEqual(c, expected_res) |
| torch.bitwise_and(a, b_scalar, out=c) |
| self.assertEqual(c, expected_res_scalar) |
| |
| # in-place |
| a1 = a.clone() |
| a1.bitwise_and_(b) |
| self.assertEqual(a1, expected_res) |
| a.bitwise_and_(b_scalar) |
| self.assertEqual(a, expected_res_scalar) |
| |
| self.assertEqual(torch.tensor([False, True, False], device=device), |
| torch.bitwise_and(torch.tensor([True, True, False], device=device), |
| torch.tensor([False, True, False], device=device))) |
| |
| def test_bitwise_or(self, device): |
| for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): |
| a = torch.tensor([1, -2, 3], dtype=dtype, device=device) |
| b = torch.tensor([2, 1, 3], dtype=dtype, device=device) |
| expected_res = torch.tensor([3, -1, 3], dtype=dtype, device=device) |
| b_scalar = 2 |
| expected_res_scalar = torch.tensor([3, -2, 3], dtype=dtype, device=device) |
| |
| # standard version |
| self.assertEqual(torch.bitwise_or(a, b), expected_res) |
| self.assertEqual(torch.bitwise_or(a, b_scalar), expected_res_scalar) |
| |
| # out |
| c = torch.empty(0, dtype=dtype, device=device) |
| torch.bitwise_or(a, b, out=c) |
| self.assertEqual(c, expected_res) |
| torch.bitwise_or(a, b_scalar, out=c) |
| self.assertEqual(c, expected_res_scalar) |
| |
| # in-place |
| a1 = a.clone() |
| a1.bitwise_or_(b) |
| self.assertEqual(a1, expected_res) |
| a.bitwise_or_(b_scalar) |
| self.assertEqual(a, expected_res_scalar) |
| |
| self.assertEqual(torch.tensor([True, True, False], device=device), |
| torch.bitwise_or(torch.tensor([True, True, False], device=device), |
| torch.tensor([False, True, False], device=device))) |
| |
| def test_bitwise_xor(self, device): |
| for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): |
| a = torch.tensor([1, -2, 3], dtype=dtype, device=device) |
| b = torch.tensor([2, 1, 3], dtype=dtype, device=device) |
| expected_res = torch.tensor([3, -1, 0], dtype=dtype, device=device) |
| b_scalar = 2 |
| expected_res_scalar = torch.tensor([3, -4, 1], dtype=dtype, device=device) |
| |
| # standard version |
| self.assertEqual(torch.bitwise_xor(a, b), expected_res) |
| self.assertEqual(torch.bitwise_xor(a, b_scalar), expected_res_scalar) |
| |
| # out |
| c = torch.empty(0, dtype=dtype, device=device) |
| torch.bitwise_xor(a, b, out=c) |
| self.assertEqual(c, expected_res) |
| torch.bitwise_xor(a, b_scalar, out=c) |
| self.assertEqual(c, expected_res_scalar) |
| |
| # in-place |
| a1 = a.clone() |
| a1.bitwise_xor_(b) |
| self.assertEqual(a1, expected_res) |
| a.bitwise_xor_(b_scalar) |
| self.assertEqual(a, expected_res_scalar) |
| |
| self.assertEqual(torch.tensor([True, False, False], device=device), |
| torch.bitwise_xor(torch.tensor([True, True, False], device=device), |
| torch.tensor([False, True, False], device=device))) |
| |
| @onlyOnCPUAndCUDA |
| @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False), |
| torch.testing.get_all_dtypes(include_complex=False)))) |
| def test_heaviside(self, device, dtypes): |
| input_dtype = dtypes[0] |
| values_dtype = dtypes[1] |
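        # heaviside(input, values): 0 where input < 0, values where input == 0, 1 where input > 0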
| |
| rng = np.random.default_rng() |
| input = np.array(rng.integers(-10, 10, size=10), |
| dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64]) |
| input[0] = input[3] = input[7] = 0 |
| values = np.array(rng.integers(-10, 10, size=10), |
| dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64]) |
| np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype) |
| |
| input = torch.from_numpy(input).to(device=device, dtype=input_dtype) |
| values = torch.from_numpy(values).to(device=device, dtype=values_dtype) |
| out = torch.empty_like(input) |
| |
| if input_dtype == values_dtype: |
| torch_result = torch.heaviside(input, values) |
| self.assertEqual(np_result, torch_result) |
| |
| torch_result = input.heaviside(values) |
| self.assertEqual(np_result, torch_result) |
| |
| torch.heaviside(input, values, out=out) |
| self.assertEqual(np_result, out) |
| |
| input.heaviside_(values) |
| self.assertEqual(np_result, input) |
| else: |
| with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): |
| torch.heaviside(input, values) |
| with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): |
| input.heaviside(values) |
| with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): |
| torch.heaviside(input, values, out=out) |
| with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): |
| input.heaviside_(values) |
| |
| @onlyCUDA |
| def test_heaviside_cross_device(self, device): |
| x = torch.tensor([-9, 5, 0, 6, -2, 2], device='cuda') |
| y = torch.tensor(0) |
| result = torch.heaviside(x, y) |
| expect = torch.tensor([0, 1, 0, 1, 0, 1], device='cuda') |
| self.assertEqual(result, expect) |
| |
| result = torch.heaviside(y, x) |
| expect = torch.tensor([-9, 5, 0, 6, -2, 2], device='cuda') |
| self.assertEqual(result, expect) |
| |
| x = torch.tensor([-9, 5, 0, 6, -2, 2]) |
| y = torch.tensor(0, device='cuda') |
| with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): |
| torch.heaviside(x, y) |
| |
| with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): |
| torch.heaviside(y, x) |
| |
| @dtypes(*list(product(torch.testing.get_all_complex_dtypes(), |
| torch.testing.get_all_complex_dtypes()))) |
| def test_heaviside_complex(self, device, dtypes): |
| input_dtype = dtypes[0] |
| values_dtype = dtypes[1] |
| |
| data = (complex(0, -6), complex(-1, 3), complex(1, 1)) |
| input = torch.tensor(data, device=device, dtype=input_dtype) |
| values = torch.tensor(data, device=device, dtype=values_dtype) |
| out = torch.empty_like(input) |
| real = input.real |
| |
| with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): |
| torch.heaviside(input, real) |
| with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): |
| real.heaviside(values) |
| with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): |
| input.heaviside_(values) |
| with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): |
| torch.heaviside(real, real, out=out) |
| |
| def _test_logical(self, device, dtypes, op, a_, b_, expected_res_): |
| expected_res = torch.tensor(expected_res_, dtype=dtypes[0], device=device) |
| a = torch.tensor(a_, dtype=dtypes[0], device=device) |
| b = torch.tensor(b_, dtype=dtypes[1], device=device) |
| |
| # new tensor |
| self.assertEqual(expected_res.bool(), getattr(a, op)(b)) |
| # out |
| c = torch.empty(0, dtype=torch.bool, device=device) |
| getattr(torch, op)(a, b, out=c) |
| self.assertEqual(expected_res.bool(), c) |
| |
| # in-place |
| # TODO: remove when different dtypes as operands are supported |
| if dtypes[0] != dtypes[1]: |
| with self.assertRaises(RuntimeError): |
| getattr(a, op + '_')(b) |
| return |
| |
| getattr(a, op + '_')(b) |
| self.assertEqual(expected_res, a) |
| |
| @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) |
| def test_logical_xor(self, device, dtypes): |
| self._test_logical(device, dtypes, 'logical_xor', [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1]) |
| |
| @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) |
| def test_logical_and(self, device, dtypes): |
| self._test_logical(device, dtypes, 'logical_and', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0]) |
| |
| @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) |
| def test_logical_or(self, device, dtypes): |
| self._test_logical(device, dtypes, 'logical_or', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1]) |
| |
| def test_remainder_overflow(self, device): |
| # Check Integer Overflows |
| x = torch.tensor(23500, dtype=torch.int64, device=device) |
| q = 392486996410368 |
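        # q does not fit in 32 bits, so a 32-bit intermediate would overflow here;
        # with floored modulo, -x % q == q - x and x % -q == x - q for 0 < x < q.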
| self.assertEqual(x % q, x) |
| self.assertEqual(-x % q, q - x) |
| self.assertEqual(x % -q, x - q) |
| self.assertEqual(-x % -q, -x) |
| |
| def test_rpow(self, device): |
| m = torch.randn(10, 10, device=device) |
| self.assertEqual(torch.pow(2, m), 2**m) |
| |
| # test with scalar |
| m = torch.randn(1, device=device).squeeze() |
| assert m.dim() == 0, "m is intentionally a scalar" |
| self.assertEqual(torch.pow(2, m), 2**m) |
| |
| @onlyCPU |
| def test_ldexp(self, device): |
| # random values |
| mantissas = torch.randn(64, device=device) |
| exponents = torch.randint(-31, 31, (64,), device=device, dtype=torch.int32) |
| |
| # basic test |
| np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy()) |
| pt_outcome_1 = torch.ldexp(mantissas, exponents) |
| pt_outcome_2 = mantissas.ldexp(exponents) |
| self.assertEqual(np_outcome, pt_outcome_1) |
| self.assertEqual(np_outcome, pt_outcome_2) |
| mantissas.ldexp_(exponents) |
| self.assertEqual(np_outcome, mantissas) |
| |
| # test bounds |
| mantissas = torch.tensor([float('inf'), float('-inf'), float('inf'), float('nan')], device=device) |
| exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32) |
| np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy()) |
| pt_outcome = torch.ldexp(mantissas, exponents) |
| self.assertEqual(np_outcome, pt_outcome) |
| |
| def test_lerp(self, device): |
| start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)] |
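        # lerp(start, end, weight) == start + weight * (end - start); every pairing of 0-d
        # through 3-d shapes is exercised with both tensor and scalar weights.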
| for shapes in product(start_end_shapes, start_end_shapes): |
| start = torch.randn(shapes[0], device=device) |
| end = torch.randn(shapes[1], device=device) |
| |
| # Tensor weights |
| for weight in [torch.randn(shapes[0], device=device), random.random()]: |
| actual = torch.lerp(start, end, weight) |
| actual_method = start.lerp(end, weight) |
| self.assertEqual(actual, actual_method) |
                actual_out = torch.empty(0, device=device)
| torch.lerp(start, end, weight, out=actual_out) |
| self.assertEqual(actual, actual_out) |
| expected = start + weight * (end - start) |
| self.assertEqual(expected, actual) |
| |
| def _test_logaddexp(self, device, dtype, base2): |
| if base2: |
| ref_func = np.logaddexp2 |
| our_func = torch.logaddexp2 |
| else: |
| ref_func = np.logaddexp |
| our_func = torch.logaddexp |
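        # logaddexp(a, b) == log(exp(a) + exp(b)) (base 2 for logaddexp2), evaluated in a
        # numerically stable way; the scaled inputs below check that it does not overflow.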
| |
| def _test_helper(a, b): |
| ref = ref_func(a.cpu().numpy(), b.cpu().numpy()) |
| v = our_func(a, b) |
| self.assertEqual(ref, v) |
| |
| # simple test |
| a = torch.randn(64, 2, dtype=dtype, device=device) - 0.5 |
| b = torch.randn(64, 2, dtype=dtype, device=device) - 0.5 |
| _test_helper(a, b) |
| _test_helper(a[:3], b[:3]) |
| |
| # large value test for numerical stability |
| a *= 10000 |
| b *= 10000 |
| _test_helper(a, b) |
| _test_helper(a[:3], b[:3]) |
| |
| a = torch.tensor([float('inf'), float('-inf'), float('inf'), float("nan")], dtype=dtype, device=device) |
| b = torch.tensor([float('inf'), float('-inf'), float('-inf'), float("nan")], dtype=dtype, device=device) |
| _test_helper(a, b) |
| |
| @dtypes(torch.float32, torch.float64) |
| def test_logaddexp(self, device, dtype): |
| self._test_logaddexp(device, dtype, base2=False) |
| |
| @dtypes(torch.float32, torch.float64) |
| def test_logaddexp2(self, device, dtype): |
| self._test_logaddexp(device, dtype, base2=True) |
| |
| def test_add(self, device): |
| dtypes = [torch.float, torch.double] + torch.testing.get_all_complex_dtypes() |
| for dtype in dtypes: |
| # [res] torch.add([res,] tensor1, tensor2) |
| m1 = torch.randn(100, 100, dtype=dtype, device=device) |
| v1 = torch.randn(100, dtype=dtype, device=device) |
| |
| # contiguous |
| res1 = torch.add(m1[4], v1) |
| res2 = res1.clone().zero_() |
| for i in range(m1.size(1)): |
| res2[i] = m1[4, i] + v1[i] |
| self.assertEqual(res1, res2) |
| |
| m1 = torch.randn(100, 100, device=device) |
| v1 = torch.randn(100, device=device) |
| |
| # non-contiguous |
| res1 = torch.add(m1[:, 4], v1) |
| res2 = res1.clone().zero_() |
| for i in range(m1.size(0)): |
| res2[i] = m1[i, 4] + v1[i] |
| self.assertEqual(res1, res2) |
| |
| # [res] torch.add([res,] tensor, value) |
| m1 = torch.randn(10, 10, device=device) |
| |
| # contiguous |
| res1 = m1.clone() |
| res1[3].add_(2) |
| res2 = m1.clone() |
| for i in range(m1.size(1)): |
| res2[3, i] = res2[3, i] + 2 |
| self.assertEqual(res1, res2) |
| |
| # non-contiguous |
| m1 = torch.randn(10, 10, device=device) |
| res1 = m1.clone() |
| res1[:, 3].add_(2) |
| res2 = m1.clone() |
| for i in range(m1.size(0)): |
| res2[i, 3] = res2[i, 3] + 2 |
| self.assertEqual(res1, res2) |
| |
| # inter-type |
| m1 = torch.randn(10, 10, dtype=dtype, device=device) |
| self.assertEqual(m1 + 3, m1 + torch.tensor(3)) |
| self.assertEqual(3 + m1, torch.tensor(3) + m1) |
| |
| # contiguous + non-contiguous |
| m1 = torch.randn(10, 10, dtype=dtype, device=device) |
| m2 = torch.randn(10, 10, dtype=dtype, device=device).t() |
| res = m1 + m2 |
| self.assertTrue(res.is_contiguous()) |
| self.assertEqual(res, m1 + m2.contiguous()) |
| |
| # 1d + empty |
| m1 = torch.tensor([1.0], dtype=dtype, device=device) |
| m2 = torch.tensor([], dtype=dtype, device=device) |
| self.assertEqual(m1 + m2, []) |
| |
            # inter-type uint8
| one = torch.tensor(1, dtype=torch.uint8, device=device) |
| self.assertEqual(torch.add(one, 1), 2) |
| self.assertEqual(torch.add(one, 1).dtype, torch.uint8) |
| |
| # bool |
| m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) |
| m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) |
| expected = torch.tensor([True, True, False, True, False, True], dtype=torch.bool, device=device) |
| self.assertEqual(m1 + m2, expected) |
| |
            # alpha with bool tensors (torch.add computes a + alpha * b)
| a = torch.zeros(2, 3, dtype=torch.bool, device=device) |
| res = torch.add(a, a, alpha=0) |
| expected = torch.zeros(2, 3, device=device).bool() |
| self.assertEqual(res, expected) |
| |
| # bfloat16 |
| m1 = torch.tensor([1., 2.], dtype=torch.bfloat16) |
| m2 = torch.tensor([3., 4.], dtype=torch.bfloat16) |
| self.assertEqual(m1 + m2, torch.tensor([4., 6.], dtype=torch.bfloat16)) |
| |
| # different alpha types |
| m1 = torch.tensor([2 + 3j, 4 + 5j], dtype=torch.complex64, device=device) |
| m2 = torch.tensor([4 + 5j, 2 + 3j], dtype=torch.complex64, device=device) |
| # add complex numbers with float alpha |
| res = torch.add(m1, m2, alpha=0.1) |
| expected = torch.tensor([2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device) |
| self.assertEqual(res, expected) |
| |
| # add complex numbers with complex alpha |
| res = torch.add(m1, m2, alpha=complex(0.1, 0.2)) |
| expected = torch.tensor([1.4000 + 4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device) |
| self.assertEqual(res, expected) |
| |
| # add complex numbers with integer alpha |
| res = torch.add(m1, m2, alpha=2) |
| expected = torch.tensor([10. + 13.j, 8. + 11.j], dtype=torch.complex64, device=device) |
| self.assertEqual(res, expected) |
| |
| # mismatched alpha |
| m1 = torch.tensor([1], dtype=torch.int8, device=device) |
| m2 = torch.tensor([2], dtype=torch.int8, device=device) |
| self.assertRaisesRegex(RuntimeError, |
| r"Boolean alpha only supported for Boolean results\.", |
| lambda: torch.add(m1, m2, alpha=True)) |
| self.assertRaisesRegex(RuntimeError, |
| r"For integral input tensors, argument alpha must not be a floating point number\.", |
| lambda: torch.add(m1, m2, alpha=1.0)) |
| |
| # mismatched alpha, float / double tensor and complex alpha |
| m1 = torch.tensor([3., 4.], device=device) |
| m2 = torch.tensor([4., 3.], device=device) |
| self.assertRaises(RuntimeError, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))) |
| |
| m1 = torch.tensor([3., 4.], dtype=torch.double, device=device) |
| m2 = torch.tensor([4., 3.], dtype=torch.double, device=device) |
| self.assertRaises(RuntimeError, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))) |
| |
| # complex |
| m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64) |
| m2 = torch.tensor(4., dtype=torch.float64) |
| self.assertRaisesRegex(RuntimeError, r"result type ComplexFloat can't be cast to the desired output type Double", |
| lambda: torch.add(m1, m1, out=m2)) |
| |
| |
| def test_sub_typing(self, device): |
| m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) |
| m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) |
| self.assertRaisesRegex(RuntimeError, |
| r"Subtraction, the `\-` operator, with two bool tensors is not supported. " |
| r"Use the `\^` or `logical_xor\(\)` operator instead.", |
| lambda: m1 - m2) |
| self.assertRaisesRegex(RuntimeError, |
| r"Subtraction, the `\-` operator, with a bool tensor is not supported. " |
| r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", |
| lambda: 1 - m1) |
| self.assertRaisesRegex(RuntimeError, |
| r"Subtraction, the `\-` operator, with a bool tensor is not supported. " |
| r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", |
| lambda: m2 - 1) |
| |
| # mismatched alpha |
| m1 = torch.tensor([1], dtype=torch.int8, device=device) |
| m2 = torch.tensor([2], dtype=torch.int8, device=device) |
| self.assertRaisesRegex(RuntimeError, |
| r"Boolean alpha only supported for Boolean results\.", |
| lambda: torch.sub(m1, m2, alpha=True)) |
| self.assertRaisesRegex(RuntimeError, |
| r"For integral input tensors, argument alpha must not be a floating point number\.", |
| lambda: torch.sub(m1, m2, alpha=1.0)) |
| |
| def test_mul(self, device): |
| m1 = torch.randn(10, 10, device=device) |
| res1 = m1.clone() |
| res1[:, 3].mul_(2) |
| res2 = m1.clone() |
| for i in range(res1.size(0)): |
| res2[i, 3] = res2[i, 3] * 2 |
| self.assertEqual(res1, res2) |
| |
| a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device) |
| a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) |
| self.assertEqual(a1 * a2, torch.tensor([True, False, False, False], dtype=torch.bool, device=device)) |
| |
| if device == 'cpu': |
| a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device) |
| a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device) |
| self.assertEqual(a1 * a2, torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device), atol=0.01, rtol=0) |
| self.assertEqual(a1.mul(a2), a1 * a2) |
| |
| def test_bool_tensor_comparison_ops(self, device): |
| a = torch.tensor([True, False, True, False, True, False], dtype=torch.bool, device=device) |
| b = torch.tensor([True, False, True, True, True, True], dtype=torch.bool, device=device) |
| self.assertEqual(a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) |
| self.assertEqual(a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) |
| self.assertEqual(a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) |
| self.assertEqual(a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device)) |
| self.assertEqual(a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) |
| self.assertEqual(a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device)) |
| self.assertEqual(a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) |
| self.assertEqual(a == torch.tensor(True, dtype=torch.bool, device=device), |
| torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) |
| self.assertEqual(a == torch.tensor(0, dtype=torch.bool, device=device), |
| torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device)) |
| self.assertFalse(a.equal(b)) |
| |
| @dtypes(*torch.testing.get_all_dtypes(include_complex=False)) |
| def test_logical(self, device, dtype): |
| if dtype != torch.bool: |
| x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype) |
| b = torch.tensor([2], device=device, dtype=dtype) |
| self.assertEqual(x.lt(2), torch.tensor([True, False, False, False])) |
| self.assertEqual(x.le(2), torch.tensor([True, True, False, False])) |
| self.assertEqual(x.ge(2), torch.tensor([False, True, True, True])) |
| self.assertEqual(x.gt(2), torch.tensor([False, False, True, True])) |
| self.assertEqual(x.eq(2), torch.tensor([False, True, False, False])) |
| self.assertEqual(x.ne(2), torch.tensor([True, False, True, True])) |
| |
| self.assertEqual(x.lt(b), torch.tensor([True, False, False, False])) |
| self.assertEqual(x.le(b), torch.tensor([True, True, False, False])) |
| self.assertEqual(x.ge(b), torch.tensor([False, True, True, True])) |
| self.assertEqual(x.gt(b), torch.tensor([False, False, True, True])) |
| self.assertEqual(x.eq(b), torch.tensor([False, True, False, False])) |
| self.assertEqual(x.ne(b), torch.tensor([True, False, True, True])) |
| else: |
| x = torch.tensor([True, False, True, False], device=device) |
| self.assertEqual(x.lt(True), torch.tensor([False, True, False, True])) |
| self.assertEqual(x.le(True), torch.tensor([True, True, True, True])) |
| self.assertEqual(x.ge(True), torch.tensor([True, False, True, False])) |
| self.assertEqual(x.gt(True), torch.tensor([False, False, False, False])) |
| self.assertEqual(x.eq(True), torch.tensor([True, False, True, False])) |
| self.assertEqual(x.ne(True), torch.tensor([False, True, False, True])) |
| |
| def test_atan2(self, device): |
| def _test_atan2_with_size(size, device): |
| a = torch.rand(size=size, device=device, dtype=torch.double) |
| b = torch.rand(size=size, device=device, dtype=torch.double) |
| actual = a.atan2(b) |
| x = a.view(-1) |
| y = b.view(-1) |
| expected = torch.tensor([math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())], |
| device=device, dtype=torch.double) |
| self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02) |
| |
| _test_atan2_with_size((2, 2), device) |
| _test_atan2_with_size((3, 3), device) |
| _test_atan2_with_size((5, 5), device) |
| |
| def test_atan2_edgecases(self, device): |
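        # The helper takes (x, y, expected) but calls torch.atan2(y, x), i.e. the angle of
        # the point (x, y) measured from the positive x-axis.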
| def _test_atan2(x, y, expected, device, dtype): |
| expected_tensor = torch.tensor([expected], dtype=dtype, device=device) |
| x_tensor = torch.tensor([x], dtype=dtype, device=device) |
| y_tensor = torch.tensor([y], dtype=dtype, device=device) |
| actual = torch.atan2(y_tensor, x_tensor) |
| self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02) |
| |
| for dtype in [torch.float, torch.double]: |
| _test_atan2(0, 0, 0, device, dtype) |
| _test_atan2(0, 1, math.pi / 2, device, dtype) |
| _test_atan2(0, -1, math.pi / -2, device, dtype) |
| _test_atan2(-1, 0, math.pi, device, dtype) |
| _test_atan2(1, 0, 0, device, dtype) |
            _test_atan2(-1, -1, math.pi * -3 / 4, device, dtype)
            _test_atan2(1, 1, math.pi / 4, device, dtype)
            _test_atan2(1, -1, math.pi / -4, device, dtype)
            _test_atan2(-1, 1, math.pi * 3 / 4, device, dtype)
| |
| def test_trapz(self, device): |
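        # torch.trapz integrates along dim with the trapezoidal rule, given either a uniform
        # spacing dx or explicit sample coordinates x; results are checked against np.trapz.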
| def test_dx(sizes, dim, dx, device): |
| t = torch.randn(sizes, device=device) |
| actual = torch.trapz(t, dx=dx, dim=dim) |
| expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim) |
| self.assertEqual(expected.shape, actual.shape) |
| self.assertEqual(expected, actual) |
| |
| def test_x(sizes, dim, x, device): |
| t = torch.randn(sizes, device=device) |
| actual = torch.trapz(t, x=torch.tensor(x, device=device), dim=dim) |
| expected = np.trapz(t.cpu().numpy(), x=x, axis=dim) |
| self.assertEqual(expected.shape, actual.shape) |
| self.assertEqual(expected, actual.cpu()) |
| |
| test_dx((2, 3, 4), 1, 1, device) |
| test_dx((10, 2), 0, 0.1, device) |
| test_dx((1, 10), 0, 2.3, device) |
| test_dx((0, 2), 0, 1.0, device) |
| test_dx((0, 2), 1, 1.0, device) |
| test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) |
| test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device) |
| test_x((1, 10), 0, [1.0], device) |
| test_x((0, 2), 0, [], device) |
| test_x((0, 2), 1, [1.0, 2.0], device) |
| with self.assertRaisesRegex( |
| IndexError, |
| 'Dimension out of range'): |
| test_x((2, 3), 2, [], device) |
| test_dx((2, 3), 2, 1.0, device) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| 'There must be one `x` value for each sample point'): |
| test_x((2, 3), 1, [1.0, 2.0], device) |
| test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device) |
| |
| @dtypes(torch.double) |
| def test_pow_scalar_overloads_mem_overlap(self, device, dtype): |
| sz = 3 |
| doubles = torch.randn(2 * sz, dtype=dtype, device=device) |
| self.check_internal_mem_overlap( |
| lambda t: t.pow_(42), 1, dtype, device) |
| self.unary_check_input_output_mem_overlap( |
| doubles, sz, lambda input, out: torch.pow(input, 42, out=out)) |
| self.unary_check_input_output_mem_overlap( |
| doubles, sz, lambda input, out: torch.pow(42, input, out=out)) |
| |
| @dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False), |
| torch.testing.get_all_dtypes(include_bool=False)))) |
| def test_float_power(self, device, dtypes): |
| def to_np(value): |
| if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16: |
| return value.to(torch.float).cpu().numpy() |
| return value.cpu().numpy() if isinstance(value, torch.Tensor) else value |
| |
| base_dtype = dtypes[0] |
| exp_dtype = dtypes[1] |
| out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64 |
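        # float_power always computes in double precision (complex128 when either operand is
        # complex), which is why out_dtype above is float64/complex128 regardless of inputs.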
| |
| base = make_tensor((30,), device, base_dtype, low=1, high=100) |
| # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0 |
| # Related: https://github.com/pytorch/pytorch/issues/48000 |
| # base[0] = base[3] = base[7] = 0 |
| exp = make_tensor((30,), device, exp_dtype, low=-2, high=2) |
| exp[0] = exp[4] = exp[6] = 0 |
| |
| expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp))) |
| |
| exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2] |
| complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j] |
| |
| for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_): |
| |
| # Case of Tensor x Tensor |
| if op is torch.Tensor.float_power_ and base_dtype != out_dtype: |
| with self.assertRaisesRegex(RuntimeError, "is not the desired type"): |
| op(base.clone(), exp) |
| else: |
| result = op(base.clone(), exp) |
| self.assertEqual(expected, result) |
| |
| if op is torch.float_power: |
| out = torch.empty_like(base).to(device=device, dtype=out_dtype) |
| op(base, exp, out=out) |
| self.assertEqual(expected, out) |
| |
| # Case of Tensor x Scalar |
| for i in complex_exponents if exp_dtype.is_complex else exponents: |
| out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64 |
| expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i)) |
| |
| if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp: |
| with self.assertRaisesRegex(RuntimeError, "is not the desired type"): |
| op(base.clone(), i) |
| else: |
| result = op(base.clone(), i) |
| self.assertEqual(expected_scalar_exp, result) |
| |
| if op is torch.float_power: |
| out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp) |
| op(base, i, out=out) |
| self.assertEqual(expected_scalar_exp, out) |
| |
| # Case of Scalar x Tensor |
| for i in complex_exponents if base_dtype.is_complex else exponents: |
| out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64 |
| expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp))) |
| |
| result = torch.float_power(i, exp) |
| self.assertEqual(expected_scalar_base, result) |
| |
| out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base) |
| torch.float_power(i, exp, out=out) |
| self.assertEqual(expected_scalar_base, out) |
| |
| def test_float_power_exceptions(self, device): |
| def _promo_helper(x, y): |
| for i in (x, y): |
| if type(i) == complex: |
| return torch.complex128 |
| elif type(i) == torch.Tensor and i.is_complex(): |
| return torch.complex128 |
| return torch.double |
| |
| test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25), |
| (torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.)) |
| for base, exp in test_cases: |
| for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble): |
| out = torch.empty(1, device=device, dtype=out_dtype) |
| required_dtype = _promo_helper(base, exp) |
| |
| if out.dtype == required_dtype: |
| torch.float_power(base, exp, out=out) |
| else: |
| with self.assertRaisesRegex(RuntimeError, "is not the desired output type"): |
| torch.float_power(base, exp, out=out) |
| |
| if base.dtype == required_dtype: |
| torch.Tensor.float_power_(base.clone(), exp) |
| else: |
| with self.assertRaisesRegex(RuntimeError, "is not the desired type"): |
| torch.Tensor.float_power_(base.clone(), exp) |
| |
| |
| tensor_binary_ops = [ |
| '__lt__', '__le__', |
| '__gt__', '__ge__', |
| '__eq__', '__ne__', |
| |
| '__add__', '__radd__', '__iadd__', |
| '__sub__', '__rsub__', '__isub__', |
| '__mul__', '__rmul__', '__imul__', |
| '__matmul__', '__rmatmul__', '__imatmul__', |
| '__truediv__', '__rtruediv__', '__itruediv__', |
| '__floordiv__', '__rfloordiv__', '__ifloordiv__', |
| '__mod__', '__rmod__', '__imod__', |
| '__divmod__', '__rdivmod__', '__idivmod__', |
| '__pow__', '__rpow__', '__ipow__', |
| '__lshift__', '__rlshift__', '__ilshift__', |
| '__rshift__', '__rrshift__', '__irshift__', |
| '__and__', '__rand__', '__iand__', |
| '__xor__', '__rxor__', '__ixor__', |
| '__or__', '__ror__', '__ior__', |
| ] |
| |
| # Test that binary math operations return NotImplemented for unknown types. |
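# Returning NotImplemented (rather than raising) lets Python fall back to the other
# operand's reflected implementation of the operator.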
| def generate_not_implemented_tests(cls): |
| class UnknownType: |
| pass |
| |
| # TODO: refactor to inline these |
| _types = [ |
| torch.half, torch.float, torch.double, |
| torch.int8, torch.short, torch.int, torch.long, |
| torch.uint8 |
| ] |
| |
| # TODO: refactor to use make_tensor |
    def _small_2d(dtype, device, fill_ones=False):
        return _make_tensor((5, 5), dtype, device, fill_ones=fill_ones)
| |
| for op in tensor_binary_ops: |
        # Bind op through a default argument so each generated test keeps its own
        # operator; a plain closure would capture only the loop's final value.
        @dtypes(*_types)
        def test(self, device, dtype, op=op):
| # Generate the inputs |
| tensor = _small_2d(dtype, device) |
| |
| # Runs the tensor op on the device |
| result = getattr(tensor, op)(UnknownType()) |
| self.assertEqual(result, NotImplemented) |
| |
| test_name = "test_{}_not_implemented".format(op) |
| assert not hasattr(cls, test_name), "{0} already in {1}".format( |
| test_name, cls.__name__) |
| |
| setattr(cls, test_name, test) |
| |
| |
| generate_not_implemented_tests(TestBinaryUfuncs) |
| instantiate_device_type_tests(TestBinaryUfuncs, globals()) |
| |
| if __name__ == '__main__': |
| run_tests() |