| # Owner(s): ["oncall: jit"] |
| |
| import cmath |
| import os |
| import sys |
| from itertools import product |
| from textwrap import dedent |
| from typing import Dict, List |
| |
| import torch |
| from torch.testing._internal.common_utils import IS_MACOS |
| from torch.testing._internal.jit_utils import execWrapper, JitTestCase |
| |
| |
| # Make the helper files in test/ importable: take this file's resolved path, |
| # go up two directory levels, and append that directory to sys.path. |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| |
| |
| class TestComplex(JitTestCase): |
def test_script(self):
    """Pass an identity function over a complex scalar to ``self.checkScript``."""

    def passthrough(a: complex):
        return a

    sample = 3 + 5j
    self.checkScript(passthrough, (sample,))
| |
def test_complexlist(self):
    """Pass indexing into a ``List[complex]`` to ``self.checkScript``."""

    def element_at(a: List[complex], idx: int):
        return a[idx]

    # Mix of purely imaginary, plain int and general complex entries.
    elements = [1j, 2, 3 + 4j, -5, -7j]
    position = 2
    self.checkScript(element_at, (elements, position))
| |
def test_complexdict(self):
    """Pass a lookup in a ``Dict[complex, complex]`` to ``self.checkScript``."""

    def lookup(a: Dict[complex, complex], key: complex) -> complex:
        return a[key]

    mapping = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
    wanted_key = -4.3 - 2j  # one of the keys present in ``mapping``
    self.checkScript(lookup, (mapping, wanted_key))
| |
def test_pickle(self):
    """Check complex attributes and output of a ScriptModule after ``getExportImportCopy``.

    The module stores a complex scalar, a list of complex values and a
    complex-keyed dict; the copy returned by ``self.getExportImportCopy`` is
    expected to expose equal attributes and to return ``b + 2j`` when called.
    """

    class ComplexModule(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.a = 3 + 5j
            self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
            self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}

        @torch.jit.script_method
        def forward(self, b: int):
            return b + 2j

    loaded = self.getExportImportCopy(ComplexModule())

    # (attribute name, expected value) pairs, checked in this order.
    expected_attrs = (
        ("a", 3 + 5j),
        ("b", [2 + 3j, 3 + 4j, -3j, -4]),
        ("c", {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}),
    )
    for attr_name, expected in expected_attrs:
        self.assertEqual(getattr(loaded, attr_name), expected)

    self.assertEqual(loaded(2), 2 + 2j)
| |
def test_complex_parse(self):
    """Pass a function mixing complex literals and a tensor element write to ``self.checkScript``."""

    def fn(a: int, b: torch.Tensor, dim: int):
        # verifies `emitValueToTensor()`'s behavior
        b[dim] = 2.4 + 0.5j
        return (3 * 2j) + a + 5j - 7.4j - 4

    # ``a`` is annotated ``int`` but is fed a 0-dim tensor here.
    scalar_arg = torch.tensor(1)
    # A complex entry makes the whole tensor complex, so the write above fits.
    complex_tensor = torch.tensor([0.4, 1.4j, 2.35])

    self.checkScript(fn, (scalar_arg, complex_tensor, 2))
| |
def test_complex_constants_and_ops(self):
    """Compare ``cmath`` functions/constants and complex builtins in Python vs. TorchScript.

    Covers, in order:
      * unary ``cmath`` functions, evaluated on a grid of complex inputs both
        by a plain Python function and by its ``torch.jit.CompilationUnit``
        counterpart;
      * ``abs`` and ``pow`` on complex values through ``self.checkScript``;
      * ``cmath.rect`` through ``self.checkScript`` (skipped when ``IS_MACOS``);
      * ``cmath`` constants, evaluated once each with no input.
    """
    vals = (
        [0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
        + [10.0**i for i in range(2)]
        + [-(10.0**i) for i in range(2)]
    )
    # Every (real, imag) combination of the scalar grid above.
    complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))

    funcs_template = dedent(
        """
        def func(a: complex):
            return cmath.{func_or_const}(a)
        """
    )

    def checkCmath(func_name, funcs_template=funcs_template, *, is_constant=False):
        """Build ``func`` from the template in Python and TorchScript and compare results.

        Args:
            func_name: ``cmath`` attribute substituted into the template.
            funcs_template: source template with a ``{func_or_const}`` slot.
            is_constant: when True the template defines a zero-argument
                ``func`` and it is called exactly once with no input.
                Without this, a zero-argument ``func`` would be called with
                one argument, raise on both sides, and be skipped unchecked.
        """
        funcs_str = funcs_template.format(func_or_const=func_name)
        scope = {}
        execWrapper(funcs_str, globals(), scope)
        cu = torch.jit.CompilationUnit(funcs_str)
        f_script = cu.func
        f = scope["func"]

        if is_constant:
            arg_sets = [()]
        elif func_name in ("isinf", "isnan", "isfinite"):
            # The classification predicates are additionally fed inf/nan parts.
            new_vals = vals + [float("inf"), float("nan"), -1 * float("inf")]
            arg_sets = [(complex(x, y),) for x, y in product(new_vals, new_vals)]
        else:
            arg_sets = [(c,) for c in complex_vals]

        for args in arg_sets:
            # An exception is captured as the "result" so both sides can be compared.
            try:
                res_python = f(*args)
            except Exception as err:
                res_python = err
            try:
                res_script = f_script(*args)
            except Exception as err:
                res_script = err

            # Inputs on which Python itself raises are skipped (two distinct
            # exception objects never compare equal, so they land here too).
            if res_python != res_script and not isinstance(res_python, Exception):
                a = args[0] if args else "<no input>"
                msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}"
                self.assertEqual(res_python, res_script, msg=msg)

    unary_ops = [
        "log",
        "log10",
        "sqrt",
        "exp",
        "sin",
        "cos",
        "asin",
        "acos",
        "atan",
        "sinh",
        "cosh",
        "tanh",
        "asinh",
        "acosh",
        "atanh",
        "phase",
        "isinf",
        "isnan",
        "isfinite",
    ]

    # --- Unary ops ---
    for op in unary_ops:
        checkCmath(op)

    def fn(x: complex):
        return abs(x)

    for val in complex_vals:
        self.checkScript(fn, (val,))

    def pow_complex_float(x: complex, y: float):
        return pow(x, y)

    def pow_float_complex(x: float, y: complex):
        return pow(x, y)

    self.checkScript(pow_float_complex, (2, 3j))
    self.checkScript(pow_complex_float, (3j, 2))

    def pow_complex_complex(x: complex, y: complex):
        return pow(x, y)

    # NOTE(review): zipping the grid with itself only ever checks pow(x, x);
    # `product` may have been intended -- confirm before widening coverage.
    for x, y in zip(complex_vals, complex_vals):
        # Reference: https://github.com/pytorch/pytorch/issues/54622
        if x == 0:
            continue
        self.checkScript(pow_complex_complex, (x, y))

    if not IS_MACOS:
        # --- Binary op ---
        def rect_fn(x: float, y: float):
            return cmath.rect(x, y)

        for x, y in product(vals, vals):
            self.checkScript(rect_fn, (x, y))

    func_constants_template = dedent(
        """
        def func():
            return cmath.{func_or_const}
        """
    )
    float_consts = ["pi", "e", "tau", "inf", "nan"]
    complex_consts = ["infj", "nanj"]
    for x in float_consts + complex_consts:
        checkCmath(x, funcs_template=func_constants_template, is_constant=True)
| |
def test_infj_nanj_pickle(self):
    """Check ``cmath.infj`` / ``cmath.nanj`` results of a ScriptModule after ``getExportImportCopy``."""

    class ComplexModule(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.a = 3 + 5j

        @torch.jit.script_method
        def forward(self, infj: int, nanj: int):
            if infj == 2:
                return infj + cmath.infj
            else:
                return nanj + cmath.nanj

    loaded = self.getExportImportCopy(ComplexModule())

    # (call arguments, expected result): the first hits the ``infj == 2``
    # branch, the second the ``else`` branch.
    checks = (
        ((2, 3), 2 + cmath.infj),
        ((3, 4), 4 + cmath.nanj),
    )
    for call_args, expected in checks:
        self.assertEqual(loaded(*call_args), expected)
| |
def test_complex_constructor(self):
    """Pass ``complex(real, img)`` to ``self.checkScript`` for int/float/bool argument pairings.

    Each helper pins one (real, img) annotation pair and is checked with
    four argument tuples, in the order listed in ``scenarios``.
    """

    # Test all scalar types
    def fn_int(real: int, img: int):
        return complex(real, img)

    def fn_float(real: float, img: float):
        return complex(real, img)

    def fn_bool(real: bool, img: bool):
        return complex(real, img)

    def fn_bool_int(real: bool, img: int):
        return complex(real, img)

    def fn_int_bool(real: int, img: bool):
        return complex(real, img)

    def fn_bool_float(real: bool, img: float):
        return complex(real, img)

    def fn_float_bool(real: float, img: bool):
        return complex(real, img)

    def fn_float_int(real: float, img: int):
        return complex(real, img)

    def fn_int_float(real: int, img: float):
        return complex(real, img)

    # (helper, argument tuples) -- iterated top to bottom, left to right.
    scenarios = (
        (fn_int, ((0, 0), (-1234, 0), (0, -1256), (-167, -1256))),
        (fn_float, ((0.0, 0.0), (-1234.78, 0), (0, 56.18), (-1.9, -19.8))),
        (fn_bool, ((True, True), (False, False), (False, True), (True, False))),
        (fn_bool_int, ((True, 0), (False, 0), (False, -1), (True, 3))),
        (fn_int_bool, ((0, True), (0, False), (-3, True), (6, False))),
        (fn_bool_float, ((True, 0.0), (False, 0.0), (False, -1.0), (True, 3.0))),
        (fn_float_bool, ((0.0, True), (0.0, False), (-3.0, True), (6.0, False))),
        (fn_float_int, ((0.0, 1), (0.0, -1), (1.8, -3), (2.7, 8))),
        (fn_int_float, ((1, 0.0), (-1, 1.7), (-3, 0.0), (2, -8.9))),
    )
    for helper, arg_tuples in scenarios:
        for call_args in arg_tuples:
            self.checkScript(helper, call_args)
| |
def test_torch_complex_constructor_with_tensor(self):
    """Pass ``complex(real, img)`` to ``self.checkScript`` when one or both arguments are tensors.

    The tensor inputs are one-element float, int and bool tensors; the
    un-annotated helper parameters are the ones that receive them.
    """
    tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])]

    def fn_tensor_float(real, img: float):
        return complex(real, img)

    def fn_tensor_int(real, img: int):
        return complex(real, img)

    def fn_tensor_bool(real, img: bool):
        return complex(real, img)

    def fn_float_tensor(real: float, img):
        return complex(real, img)

    def fn_int_tensor(real: int, img):
        return complex(real, img)

    def fn_bool_tensor(real: bool, img):
        return complex(real, img)

    def fn_tensor_tensor(real, img):
        return complex(real, img) + complex(2)

    # Tensor paired with a Python scalar, on either side.
    for tensor in tensors:
        mixed_cases = (
            (fn_tensor_float, (tensor, 1.2)),
            (fn_tensor_int, (tensor, 3)),
            (fn_tensor_bool, (tensor, True)),
            (fn_float_tensor, (1.2, tensor)),
            (fn_int_tensor, (3, tensor)),
            (fn_bool_tensor, (True, tensor)),
        )
        for helper, call_args in mixed_cases:
            self.checkScript(helper, call_args)

    # Tensor paired with tensor, over every ordered combination.
    for tensor_pair in product(tensors, tensors):
        self.checkScript(fn_tensor_tensor, tensor_pair)
| |
def test_comparison_ops(self):
    """Pass ``==`` / ``!=`` on complex-complex and complex-float operands to ``self.checkScript``."""

    def eq_complex(a: complex, b: complex):
        return a == b

    def ne_complex(a: complex, b: complex):
        return a != b

    def eq_float(a: complex, b: float):
        return a == b

    def ne_float(a: complex, b: float):
        return a != b

    # complex vs. complex: compare a value with itself, then with a different one.
    x, y = 2 - 3j, 4j
    for compare in (eq_complex, ne_complex):
        for rhs in (x, y):
            self.checkScript(compare, (x, rhs))

    # complex vs. float: a complex with zero imaginary part against the matching float.
    x1, y1 = 1 + 0j, 1.0
    for compare in (eq_float, ne_float):
        self.checkScript(compare, (x1, y1))
| |
def test_div(self):
    """Pass division of two complex scalars to ``self.checkScript``."""

    def divide(a: complex, b: complex):
        return a / b

    numerator = 2 - 3j
    denominator = 4j
    self.checkScript(divide, (numerator, denominator))
| |
def test_complex_list_sum(self):
    """Pass ``sum`` over a ``List[complex]`` to ``self.checkScript``."""

    def total(x: List[complex]):
        return sum(x)

    # Four random complex values, converted from a cdouble tensor to a Python list.
    random_values = torch.randn(4, dtype=torch.cdouble).tolist()
    self.checkScript(total, (random_values,))
| |
def test_tensor_attributes(self):
    """Pass ``.real`` and ``.imag`` access on a complex tensor to ``self.checkScript``."""

    def tensor_real(x):
        return x.real

    def tensor_imag(x):
        return x.imag

    t = torch.randn(2, 3, dtype=torch.cdouble)
    for accessor in (tensor_real, tensor_imag):
        self.checkScript(accessor, (t,))
| |
def test_binary_op_complex_tensor(self):
    """Compare eager and ``torch.jit.script`` results of complex-scalar (op) tensor.

    Each helper applies one binary operator with a Python complex on the
    left and a complex tensor on the right; the eager and scripted outputs
    are compared for a one-element and a 2x2 tensor.
    """

    def mul(x: complex, y: torch.Tensor):
        return x * y

    def add(x: complex, y: torch.Tensor):
        return x + y

    def eq(x: complex, y: torch.Tensor):
        return x == y

    def ne(x: complex, y: torch.Tensor):
        return x != y

    def sub(x: complex, y: torch.Tensor):
        return x - y

    def div(x: complex, y: torch.Tensor):
        # This must be a true division: a copy of `sub` here would leave
        # scalar / tensor division completely unexercised.
        return x / y

    ops = [mul, add, eq, ne, sub, div]

    x = 0.71 + 0.71j
    for shape in [(1,), (2, 2)]:
        y = torch.randn(shape, dtype=torch.cfloat)
        for op in ops:
            eager_result = op(x, y)
            scripted = torch.jit.script(op)
            jit_result = scripted(x, y)
            self.assertEqual(eager_result, jit_result)