| import numpy as np | 
 | import torch | 
 | import torch.nn.functional as F | 
 | from torch import nn | 
 | import unittest | 
 |  | 
 | from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests | 
 |  | 
 | from torch.testing._internal.jit_utils import JitTestCase | 
 |  | 
 |  | 
 | class BaseTestClass(JitTestCase): | 
 |     def setUp(self): | 
 |         self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) | 
 |         self.old_profiling_mode = torch._C._jit_set_profiling_mode(True) | 
 |  | 
 |         self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() | 
 |         self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu() | 
 |         torch._C._jit_override_can_fuse_on_cpu(True) | 
 |         torch._C._jit_override_can_fuse_on_gpu(True) | 
 |         self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() | 
 |         torch._C._jit_set_texpr_fuser_enabled(True) | 
 |         self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() | 
 |         torch._C._debug_set_fusion_group_inlining(False) | 
 |         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() | 
 |         torch._C._jit_set_te_must_use_llvm_cpu(False) | 
 |  | 
 |         self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] | 
 |  | 
 |     def tearDown(self): | 
 |         torch._C._jit_set_profiling_executor(self.old_profiling_executor) | 
 |         torch._C._jit_set_profiling_mode(self.old_profiling_mode) | 
 |  | 
 |         torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) | 
 |         torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) | 
 |         torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) | 
 |         torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) | 
 |         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) | 
 |  | 
 |     def assertLastGraphAllFused(self): | 
 |         self.assertAllFused(torch.jit.last_executed_optimized_graph()) | 
 |  | 
 |  | 
 | def warmup_and_run_forward(f, *args): | 
 |     for _ in range(torch._C._jit_get_num_profiled_runs() + 1): | 
 |         results = f(*args) | 
 |     return results | 
 |  | 
 |  | 
 | class TestTensorExprFuser(BaseTestClass): | 
 |     def test_easy(self): | 
 |         def easy(x, y): | 
 |             aaa = torch.add(x, y) | 
 |             return aaa | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024))) | 
 |  | 
 |         a = torch.rand(1024) | 
 |         b = torch.rand(1024) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) | 
 |  | 
 |     def test_three_arg(self): | 
 |         def easy(x, y, z): | 
 |             aaa = torch.add(x, y) | 
 |             bbb = torch.add(aaa, z) | 
 |             return bbb | 
 |  | 
 |         traced = torch.jit.trace( | 
 |             easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) | 
 |         ) | 
 |  | 
 |         a = torch.rand(1024) | 
 |         b = torch.rand(1024) | 
 |         c = torch.rand(1024) | 
 |         x = warmup_and_run_forward(traced, a, b, c) | 
 |         self.assertLastGraphAllFused() | 
 |         npr = a.numpy() + b.numpy() + c.numpy() | 
 |         np.testing.assert_allclose(npr, x.numpy()) | 
 |  | 
 |     def test_four_arg(self): | 
 |         def run_addcmul(x, y, z, w): | 
 |             c = torch.addcmul(torch.add(x, y), z, w) | 
 |             return c | 
 |  | 
 |         for dev in self.devices: | 
 |             rand_a = torch.rand(1024, dtype=torch.float, device=dev) | 
 |             rand_b = torch.rand(1024, dtype=torch.float, device=dev) | 
 |             rand_c = torch.rand(1024, dtype=torch.float, device=dev) | 
 |             rand_d = torch.rand(1024, dtype=torch.float, device=dev) | 
 |  | 
 |             traced = torch.jit.trace( | 
 |                 run_addcmul, | 
 |                 ( | 
 |                     torch.zeros(1024, dtype=torch.float, device=dev), | 
 |                     torch.zeros(1024, dtype=torch.float, device=dev), | 
 |                     torch.zeros(1024, dtype=torch.float, device=dev), | 
 |                     torch.zeros(1024, dtype=torch.float, device=dev), | 
 |                 ), | 
 |             ) | 
 |  | 
 |             x = warmup_and_run_forward(traced, rand_a, rand_b, rand_c, rand_d) | 
 |             self.assertLastGraphAllFused() | 
 |             y = run_addcmul(rand_a, rand_b, rand_c, rand_d) | 
 |             np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6) | 
 |  | 
 |     def test_three_arg2(self): | 
 |         for device in self.devices: | 
 |             def test(x, y, z): | 
 |                 aaa = torch.add(x, y) | 
 |                 bbb = torch.add(aaa, z) | 
 |                 return bbb | 
 |  | 
 |             M = 32 | 
 |             N = 32 | 
 |             traced = torch.jit.trace( | 
 |                 test, | 
 |                 ( | 
 |                     torch.rand(M, N, device=device), | 
 |                     torch.rand(M, N, device=device), | 
 |                     torch.rand(M, N, device=device), | 
 |                 ), | 
 |             ) | 
 |  | 
 |             a = torch.rand(M, N, device=device) | 
 |             b = torch.rand(M, N, device=device) | 
 |             c = torch.rand(M, N, device=device) | 
 |             x = traced(a, b, c) | 
 |             x = warmup_and_run_forward(traced, a, b, c) | 
 |             self.assertLastGraphAllFused() | 
 |             npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() | 
 |             np.testing.assert_allclose(npr, x.cpu().numpy()) | 
 |  | 
 |     def test_broadcast3(self): | 
 |         for device in self.devices: | 
 |             def test_body(M, N, L, K): | 
 |                 def test(x, y, z): | 
 |                     v1 = torch.add(x, y) | 
 |                     v2 = torch.add(v1, z) | 
 |                     return v2 | 
 |  | 
 |                 a_shape = [M, N] | 
 |                 b_shape = [L, M, 1] | 
 |                 c_shape = [K, L, 1, 1] | 
 |                 traced = torch.jit.trace( | 
 |                     test, | 
 |                     ( | 
 |                         torch.rand(*a_shape, device=device), | 
 |                         torch.rand(*b_shape, device=device), | 
 |                         torch.rand(*c_shape, device=device), | 
 |                     ), | 
 |                 ) | 
 |  | 
 |                 a = torch.rand(*a_shape, device=device) | 
 |                 b = torch.rand(*b_shape, device=device) | 
 |                 c = torch.rand(*c_shape, device=device) | 
 |                 x = warmup_and_run_forward(traced, a, b, c) | 
 |                 self.assertLastGraphAllFused() | 
 |                 npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() | 
 |                 np.testing.assert_allclose(npr, x.cpu().numpy()) | 
 |  | 
 |             test_configs = [[5, 2, 7, 3], [8, 8, 8, 8]] | 
 |             for test_config in test_configs: | 
 |                 test_body(*test_config) | 
 |  | 
 |     def test_all_combos(self): | 
 |         def easy(x, y, z): | 
 |             a = torch.add(x, y) | 
 |             b = torch.add(a, z) | 
 |             c = torch.add(x, b) | 
 |             d = torch.add(c, a) | 
 |             return d | 
 |  | 
 |         def np_easy(x, y, z): | 
 |             a = x + y | 
 |             b = a + z | 
 |             c = x + b | 
 |             d = c + a | 
 |             return d | 
 |  | 
 |         traced = torch.jit.trace( | 
 |             easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) | 
 |         ) | 
 |  | 
 |         a = torch.rand(1024) | 
 |         b = torch.rand(1024) | 
 |         c = torch.rand(1024) | 
 |         x = warmup_and_run_forward(traced, a, b, c) | 
 |         self.assertLastGraphAllFused() | 
 |         npr = np_easy(a.numpy(), b.numpy(), c.numpy()) | 
 |         np.testing.assert_allclose(npr, x.numpy()) | 
 |  | 
 |     def test_rank_two(self): | 
 |         def easy(x, y, z): | 
 |             a = torch.add(x, y) | 
 |             b = torch.add(a, z) | 
 |             c = torch.add(x, b) | 
 |             d = torch.add(c, a) | 
 |             return d | 
 |  | 
 |         def np_easy(x, y, z): | 
 |             a = x + y | 
 |             b = a + z | 
 |             c = x + b | 
 |             d = c + a | 
 |             return d | 
 |  | 
 |         shape = 32, 32 | 
 |         traced = torch.jit.trace( | 
 |             easy, (torch.rand(shape), torch.rand(shape), torch.rand(shape)) | 
 |         ) | 
 |  | 
 |         a = torch.rand(shape) | 
 |         b = torch.rand(shape) | 
 |         c = torch.rand(shape) | 
 |         x = warmup_and_run_forward(traced, a, b, c) | 
 |         self.assertLastGraphAllFused() | 
 |         npr = np_easy(a.numpy(), b.numpy(), c.numpy()) | 
 |         np.testing.assert_allclose(npr, x.numpy()) | 
 |  | 
 |     def test_broadcast(self): | 
 |         def easy(x, y, z): | 
 |             a = torch.add(x, y) | 
 |             b = torch.add(a, z) | 
 |             return b | 
 |  | 
 |         def np_easy(x, y, z): | 
 |             a = x + y | 
 |             b = a + z | 
 |             return b | 
 |  | 
 |         N = 32 | 
 |         traced = torch.jit.trace(easy, (torch.rand(N, N), torch.rand(N), torch.rand(N, N))) | 
 |  | 
 |         a = torch.rand(N, N) | 
 |         b = torch.rand(N) | 
 |         c = torch.rand(N, N) | 
 |         x = warmup_and_run_forward(traced, a, b, c) | 
 |         self.assertLastGraphAllFused() | 
 |         npr = np_easy(a.numpy(), b.numpy(), c.numpy()) | 
 |         np.testing.assert_allclose(npr, x.numpy()) | 
 |  | 
 |     def test_broadcast_2(self): | 
 |         zero = torch.tensor([0.0], dtype=torch.float) | 
 |  | 
 |         def foo(x, y, z): | 
 |             aaa = torch.add(x, y) | 
 |             bbb = torch.add(zero, aaa) | 
 |             return torch.add(bbb, z) | 
 |  | 
 |         def foo_np(x, y, z): | 
 |             a = x + y | 
 |             b = zero.numpy() + a | 
 |             return b + z | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         y = torch.ones(3, 1) | 
 |         z = torch.rand(4) | 
 |         traced = torch.jit.trace(foo, (x, y, z)) | 
 |  | 
 |         r = warmup_and_run_forward(traced, x, y, z) | 
 |         self.assertLastGraphAllFused() | 
 |  | 
 |         rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) | 
 |         np.testing.assert_allclose(r, rnp) | 
 |  | 
 |     def test_broadcast_big2(self): | 
 |         zero = torch.tensor([0.0], dtype=torch.float) | 
 |  | 
 |         def foo(x, y, z): | 
 |             aaa = torch.add(x, y) | 
 |             bbb = torch.add(zero, aaa) | 
 |             return torch.add(bbb, z) | 
 |  | 
 |         def foo_np(x, y, z): | 
 |             a = x + y | 
 |             b = zero.numpy() + a | 
 |             return b + z | 
 |  | 
 |         x = torch.rand(32, 1024) | 
 |         y = torch.ones(32, 1) | 
 |         z = torch.rand(1024) | 
 |         traced = torch.jit.trace(foo, (x, y, z)) | 
 |  | 
 |         r = warmup_and_run_forward(traced, x, y, z) | 
 |         self.assertLastGraphAllFused() | 
 |         rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) | 
 |         np.testing.assert_allclose(r, rnp) | 
 |  | 
 |     def test_alpha(self): | 
 |         def alpha(x): | 
 |             aaa = torch.add(x, x, alpha=2.0) | 
 |             return aaa | 
 |  | 
 |         traced = torch.jit.trace(alpha, (torch.tensor([1.0]))) | 
 |  | 
 |         a = torch.tensor([1.0]) | 
 |         x = traced(a) | 
 |         np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy()) | 
 |  | 
 |     @suppress_warnings | 
 |     def test_constant(self): | 
 |         def constant(x): | 
 |             bbb = torch.tensor([1.0]) | 
 |             aaa = torch.add(x, bbb) | 
 |             return aaa | 
 |  | 
 |         traced = torch.jit.trace(constant, (torch.tensor([1.0]))) | 
 |  | 
 |         a = torch.tensor([1.0]) | 
 |         x = warmup_and_run_forward(traced, a) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(a.numpy() + 1.0, x.numpy()) | 
 |  | 
 |     def test_add_sub(self): | 
 |         def easy(x, y, z): | 
 |             aaa = torch.add(x, y) | 
 |             bbb = torch.sub(aaa, z) | 
 |             return bbb | 
 |  | 
 |         traced = torch.jit.trace( | 
 |             easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) | 
 |         ) | 
 |  | 
 |         a = torch.rand(1024) | 
 |         b = torch.rand(1024) | 
 |         c = torch.rand(1024) | 
 |         x = warmup_and_run_forward(traced, a, b, c) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy()) | 
 |  | 
 |     def test_promotion(self): | 
 |         def easy(x, y): | 
 |             aaa = torch.add(x, y) | 
 |             return aaa | 
 |  | 
 |         traced = torch.jit.trace( | 
 |             easy, | 
 |             (torch.zeros(1024, dtype=torch.int32), torch.rand(1024, dtype=torch.float32)), | 
 |         ) | 
 |  | 
 |         a = torch.zeros(1024, dtype=torch.int32) | 
 |         b = torch.rand(1024, dtype=torch.float32) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) | 
 |  | 
 |     def test_double(self): | 
 |         TENSOR_LEN = 8 | 
 |  | 
 |         def easy(x, y): | 
 |             aaa = torch.add(x, y) | 
 |             bbb = torch.mul(aaa, y) | 
 |             return bbb | 
 |  | 
 |         traced = torch.jit.trace( | 
 |             easy, | 
 |             (torch.rand(TENSOR_LEN, dtype=torch.float64), torch.full((TENSOR_LEN,), 0.5, dtype=torch.float64)), | 
 |         ) | 
 |  | 
 |         a = torch.rand(TENSOR_LEN, dtype=torch.double) | 
 |         b = torch.full((TENSOR_LEN,), 0.5, dtype=torch.double) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) | 
 |  | 
 |     def test_short(self): | 
 |         TENSOR_LEN = 8 | 
 |  | 
 |         def easy(x, y): | 
 |             aaa = torch.add(x, y) | 
 |             bbb = torch.mul(aaa, y) | 
 |             return bbb | 
 |  | 
 |         traced = torch.jit.trace( | 
 |             easy, | 
 |             (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16), | 
 |              torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)), | 
 |         ) | 
 |  | 
 |         a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) | 
 |         b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) | 
 |  | 
 |     def test_char(self): | 
 |         TENSOR_LEN = 8 | 
 |  | 
 |         def easy(x, y): | 
 |             aaa = torch.add(x, y) | 
 |             bbb = torch.mul(aaa, y) | 
 |             return bbb | 
 |  | 
 |         traced = torch.jit.trace( | 
 |             easy, | 
 |             (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8), | 
 |              torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)), | 
 |         ) | 
 |  | 
 |         a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) | 
 |         b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) | 
 |  | 
 |     def test_int64_promotion(self): | 
 |         TENSOR_LEN = 8 | 
 |  | 
 |         def easy(x, y): | 
 |             aaa = torch.add(x, y) | 
 |             bbb = torch.mul(aaa, y) | 
 |             return bbb | 
 |  | 
 |         traced = torch.jit.trace( | 
 |             easy, | 
 |             (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8), | 
 |              torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)), | 
 |         ) | 
 |  | 
 |         a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) | 
 |         b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) | 
 |  | 
 |     def test_eq(self): | 
 |         def easy(x, y): | 
 |             c = torch.eq(x, y) | 
 |             return c | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) | 
 |         a = torch.zeros(1024, dtype=torch.int32) | 
 |         b = torch.zeros(1024, dtype=torch.int32) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(np.ones(1024), x.numpy()) | 
 |  | 
 |     def test_ne(self): | 
 |         def easy(x, y): | 
 |             c = torch.ne(x, y) | 
 |             return c | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) | 
 |         a = torch.zeros(1024, dtype=torch.int32) | 
 |         b = torch.ones(1024, dtype=torch.int32) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(np.ones(1024), x.numpy()) | 
 |  | 
 |     def test_ge(self): | 
 |         def easy(x, y): | 
 |             c = torch.ge(x, y) | 
 |             return c | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) | 
 |         aa = np.empty([1024], dtype=np.int32) | 
 |         aa.fill(5) | 
 |         a = torch.from_numpy(aa) | 
 |         b = torch.zeros(1024, dtype=torch.int32) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(np.ones(1024), x.numpy()) | 
 |  | 
 |     def test_gt(self): | 
 |         def easy(x, y): | 
 |             c = torch.gt(x, y) | 
 |             return c | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) | 
 |         a = torch.ones(1024, dtype=torch.int32) | 
 |         b = torch.zeros(1024, dtype=torch.int32) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(np.ones(1024), x.numpy()) | 
 |  | 
 |     def test_le(self): | 
 |         def easy(x, y): | 
 |             c = torch.le(x, y) | 
 |             return c | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) | 
 |         aa = np.empty([1024], dtype=np.int32) | 
 |         aa.fill(5) | 
 |         a = torch.from_numpy(aa) | 
 |         b = torch.zeros(1024, dtype=torch.int32) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(np.zeros(1024), x.numpy()) | 
 |  | 
 |     def test_lt(self): | 
 |         def easy(x, y): | 
 |             c = torch.lt(x, y) | 
 |             return c | 
 |  | 
 |         for dev in self.devices: | 
 |             traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev))) | 
 |             a = torch.ones(1024, dtype=torch.int32, device=dev) | 
 |             b = torch.zeros(1024, dtype=torch.int32, device=dev) | 
 |             x = warmup_and_run_forward(traced, a, b) | 
 |             self.assertLastGraphAllFused() | 
 |             np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy()) | 
 |  | 
 |     @suppress_warnings | 
 |     def test_min_max(self): | 
 |         def test(x, y): | 
 |             return torch.max(torch.min(x, y), torch.tensor([4.0])) | 
 |  | 
 |         traced = torch.jit.trace(test, (torch.zeros(1024), torch.zeros(1024))) | 
 |         a = 8.0 * torch.rand(1024) | 
 |         b = 8.0 * torch.rand(1024) | 
 |         np.testing.assert_allclose( | 
 |             warmup_and_run_forward(traced, a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0]) | 
 |         ) | 
 |         self.assertLastGraphAllFused() | 
 |  | 
 |     def test_min_max_reduction(self): | 
 |         def test(x): | 
 |             return torch.min(x) + torch.max(x) | 
 |  | 
 |         traced = torch.jit.trace(test, (torch.zeros(1024))) | 
 |         a = 8.0 * torch.rand(1024) | 
 |         np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy())) | 
 |         self.assertLastGraphAllFused() | 
 |  | 
 |     def test_min_max_reduction2(self): | 
 |         def test(x): | 
 |             return x.min() + x.max() | 
 |  | 
 |         traced = torch.jit.trace(test, (torch.zeros(1024))) | 
 |         a = 8.0 * torch.rand(1024) | 
 |         np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy())) | 
 |         self.assertLastGraphAllFused() | 
 |  | 
 |     def test_min_max_reduction_dim1(self): | 
 |         def test(x): | 
 |             return torch.min(x, 1)[0] + torch.max(x, 1)[0] | 
 |  | 
 |         traced = torch.jit.trace(test, (torch.zeros(16, 16))) | 
 |         a = 8.0 * torch.rand(16, 16) | 
 |         np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin( | 
 |             a.numpy(), axis=1) + np.amax(a.numpy(), axis=1)) | 
 |         self.assertLastGraphAllFused() | 
 |  | 
 |     def test_min_max_reduction_dim1_2(self): | 
 |         def test(x): | 
 |             return torch.min(x * x, 1) | 
 |  | 
 |         traced = torch.jit.trace(test, (torch.zeros(16, 16))) | 
 |         a = 8.0 * torch.rand(16, 16) | 
 |         np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin((a * a).numpy(), axis=1)) | 
 |         self.assertLastGraphAllFused() | 
 |  | 
 |     def test_clamp(self): | 
 |         def test(x): | 
 |             return torch.clamp(x + 3.0, 0.0, 6.0) | 
 |  | 
 |         for dev in self.devices: | 
 |             traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) | 
 |             a = 20.0 * torch.rand(1024, device=dev) - 10.0 | 
 |             an = a.cpu().numpy() | 
 |             np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip(an + 3.0, 0.0, 6.0)) | 
 |             self.assertLastGraphAllFused() | 
 |  | 
 |     def test_relu(self): | 
 |         def test(x): | 
 |             return torch.clamp(F.relu(x), 0, 0.5) | 
 |  | 
 |         for dev in self.devices: | 
 |             traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) | 
 |             a = 20.0 * torch.rand(1024, device=dev) - 10.0 | 
 |             an = a.cpu().numpy() | 
 |             np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5)) | 
 |             self.assertLastGraphAllFused() | 
 |  | 
 |     def test_reps(self): | 
 |         def easy(x, y): | 
 |             c = torch.add(x, y) | 
 |             return c | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024))) | 
 |  | 
 |         for _ in range(32): | 
 |             a = torch.ones(1024) | 
 |             b = torch.zeros(1024) | 
 |             x = warmup_and_run_forward(traced, a, b) | 
 |             np.testing.assert_allclose(np.ones(1024), x.numpy()) | 
 |  | 
 |     def test_add_const_rhs(self): | 
 |         def test(x): | 
 |             return x + 3.0 | 
 |  | 
 |         traced = torch.jit.trace(test, torch.rand(4)) | 
 |         x = torch.rand(4) | 
 |         y = warmup_and_run_forward(traced, x) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(x.numpy() + 3.0, y.numpy()) | 
 |  | 
 |     def test_int_output(self): | 
 |         def test(x, y, z): | 
 |             return x * y * z | 
 |  | 
 |         xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for i in range(3)] | 
 |         x, y, z = xs | 
 |         xn, yn, zn = [t.numpy() for t in xs] | 
 |         traced = torch.jit.trace(test, (x, y, z)) | 
 |         res = warmup_and_run_forward(traced, x, y, z) | 
 |         self.assertLastGraphAllFused() | 
 |         np.testing.assert_allclose(xn * yn * zn, res.numpy()) | 
 |  | 
 |     def test_binary_ops(self): | 
 |         def test_atan2(x, y): | 
 |             c = torch.atan2(torch.add(x, y), y) | 
 |             return c | 
 |  | 
 |         def test_gt(x, y): | 
 |             c = torch.gt(torch.add(x, y), y) | 
 |             return c | 
 |  | 
 |         def test_ge(x, y): | 
 |             c = torch.ge(torch.add(x, y), y) | 
 |             return c | 
 |  | 
 |         def test_lt(x, y): | 
 |             c = torch.lt(torch.add(x, y), y) | 
 |             return c | 
 |  | 
 |         def test_le(x, y): | 
 |             c = torch.le(torch.add(x, y), y) | 
 |             return c | 
 |  | 
 |         def test_lerp(x, y): | 
 |             c = torch.lerp(torch.add(x, 1), x, 2.0) | 
 |             return c | 
 |  | 
 |         def test_mul(x, y): | 
 |             c = torch.mul(torch.add(x, y), y) | 
 |             return c | 
 |  | 
 |         def test_ne(x, y): | 
 |             c = torch.ne(torch.add(x, y), y) | 
 |             return c | 
 |  | 
 |         def test_div(x, y): | 
 |             c = torch.div(torch.add(x, y), 2) | 
 |             return c | 
 |  | 
 |         def test_eq(x, y): | 
 |             c = torch.eq(torch.add(x, y), y) | 
 |             return c | 
 |  | 
 |         def test_fmod(x, y): | 
 |             c = torch.fmod(torch.add(x, y), 2) | 
 |             return c | 
 |  | 
 |         def test_sub(x, y): | 
 |             c = torch.sub(torch.add(x, y), x) | 
 |             return c | 
 |  | 
 |         def test_remainder(x, y): | 
 |             c = torch.remainder(torch.add(x, y), 3.0) | 
 |             return c | 
 |  | 
 |         def test_pow(x, y): | 
 |             c = torch.pow(torch.add(x, y), 2.0) | 
 |             return c | 
 |  | 
 |         def test_type_as(x, y): | 
 |             return x.type_as(torch.add(x, y)) | 
 |  | 
 |         fns = { | 
 |             test_atan2, | 
 |             test_gt, | 
 |             test_ge, | 
 |             test_lt, | 
 |             test_le, | 
 |             test_lerp, | 
 |             test_mul, | 
 |             test_ne, | 
 |             test_div, | 
 |             test_eq, | 
 |             test_fmod, | 
 |             test_sub, | 
 |             test_remainder, | 
 |             test_pow, | 
 |             test_type_as, | 
 |         } | 
 |         for torch_fn in fns: | 
 |             for dev in self.devices: | 
 |                 rand_a = torch.rand(1024, device=dev) | 
 |                 rand_b = torch.rand(1024, device=dev) | 
 |                 in1 = 20 * torch.rand(1024, device=dev) | 
 |                 in2 = 20 * torch.rand(1024, device=dev) | 
 |                 traced = torch.jit.trace(torch_fn, (in1, in2)) | 
 |                 x = warmup_and_run_forward(traced, rand_a, rand_b) | 
 |                 self.assertLastGraphAllFused() | 
 |                 y = torch_fn(rand_a, rand_b) | 
 |                 np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3) | 
 |  | 
 |     def test_unary_ops(self): | 
 |         def test_cast_float(x, y): | 
 |             c = torch.ops.aten._cast_Float(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_round(x, y): | 
 |             c = torch.round(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_sin(x, y): | 
 |             c = torch.sin(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_asin(x, y): | 
 |             c = torch.asin(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_sinh(x, y): | 
 |             c = torch.sinh(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_cos(x, y): | 
 |             c = torch.cos(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_acos(x, y): | 
 |             c = torch.acos(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_cosh(x, y): | 
 |             c = torch.cosh(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_tan(x, y): | 
 |             c = torch.tan(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_atan(x, y): | 
 |             c = torch.atan(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_tanh(x, y): | 
 |             c = torch.tanh(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_sqrt(x, y): | 
 |             c = torch.sqrt(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_rsqrt(x, y): | 
 |             c = torch.rsqrt(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_floor(x, y): | 
 |             c = torch.floor(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_ceil(x, y): | 
 |             c = torch.ceil(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_trunc(x, y): | 
 |             c = torch.trunc(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_abs(x, y): | 
 |             c = torch.abs(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_log(x, y): | 
 |             c = torch.log(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_log2(x, y): | 
 |             c = torch.log2(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_log10(x, y): | 
 |             c = torch.log10(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_log1p(x, y): | 
 |             c = torch.log1p(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_rqrt(x, y): | 
 |             c = torch.rsqrt(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_erf(x, y): | 
 |             c = torch.erf(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_exp(x, y): | 
 |             c = torch.exp(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_expm1(x, y): | 
 |             c = torch.expm1(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_erfc(x, y): | 
 |             c = torch.erfc(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_frac(x, y): | 
 |             c = torch.frac(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_lgamma(x, y): | 
 |             c = torch.lgamma(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_sigmoid(x, y): | 
 |             c = torch.sigmoid(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_reciprocal(x, y): | 
 |             c = torch.reciprocal(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_neg(x, y): | 
 |             c = torch.neg(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_relu(x, y): | 
 |             c = torch.relu(torch.add(x, y)) | 
 |             return c | 
 |  | 
 |         def test_hardtanh(x, y): | 
 |             c = F.hardtanh(torch.add(x, y), -1.0, 1.0) | 
 |             return c | 
 |  | 
 |         def test_threshold(x, y): | 
 |             c = F.threshold(torch.add(x, y), 0.5, 10) | 
 |             return c | 
 |  | 
 |         fns = { | 
 |             test_round, | 
 |             test_sin, | 
 |             test_asin, | 
 |             test_sinh, | 
 |             test_cos, | 
 |             test_acos, | 
 |             test_cosh, | 
 |             test_tan, | 
 |             test_atan, | 
 |             test_sqrt, | 
 |             test_floor, | 
 |             test_ceil, | 
 |             test_trunc, | 
 |             test_abs, | 
 |             test_log, | 
 |             test_log2, | 
 |             test_log10, | 
 |             test_log1p, | 
 |             test_rsqrt, | 
 |             test_exp, | 
 |             test_expm1, | 
 |             test_erf, | 
 |             test_erfc, | 
 |             test_frac, | 
 |             test_lgamma, | 
 |             test_reciprocal, | 
 |             test_neg, | 
 |             test_threshold, | 
 |             test_relu, | 
 |             test_tanh, | 
 |             test_hardtanh, | 
 |             test_sigmoid, | 
 |         } | 
 |  | 
 |         for torch_fn in fns: | 
 |             for dev in self.devices: | 
 |                 # print(torch_fn, dev) | 
 |                 rand_a = torch.rand(1024, device=dev) | 
 |                 rand_b = torch.rand(1024, device=dev) | 
 |                 ins = 20 * torch.rand(1024, device=dev) | 
 |                 cc = np.empty([1024], dtype=np.float32) | 
 |                 cc.fill(np.nan) | 
 |                 nans = torch.from_numpy(cc).to(dev) | 
 |                 traced = torch.jit.trace(torch_fn, (ins, ins)) | 
 |                 x = warmup_and_run_forward(traced, rand_a, rand_b) | 
 |                 self.assertLastGraphAllFused() | 
 |                 y = torch_fn(rand_a, rand_b) | 
 |                 np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3) | 
 |                 # nans | 
 |                 # TODO: reenable. Currently all of the tests fail | 
 |                 # traced = torch.jit.trace(torch_fn, (ins, ins)) | 
 |                 # x = warmup_and_run_forward(traced, rand_a, rand_b) | 
 |                 # y = torch_fn(nans, rand_b) | 
 |                 # try: | 
 |                 #     np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) | 
 |                 #     print("Succeeded on dev=", dev, "function=", torch_fn) | 
 |                 # except AssertionError: | 
 |                 #     # Print extra info before exiting: | 
 |                 #     print("Failed on dev=", dev, "function=", torch_fn) | 
 |                 #     # np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) | 
 |  | 
 |     def test_rand_like(self): | 
 |         N = 1 << 16 | 
 |  | 
 |         def run_rand_like(x, y): | 
 |             return torch.rand_like(torch.add(x, y)) | 
 |  | 
 |         for device in self.devices: | 
 |             x = torch.rand(N, device=device) | 
 |             traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False) | 
 |             x_v = warmup_and_run_forward(traced, x, x) | 
 |             self.assertLastGraphAllFused() | 
 |             x_np = x.cpu().numpy() | 
 |             x1_mean = np.mean(x_np) | 
 |             x2_mean = np.mean(x_np ** 2) | 
 |             x3_mean = np.mean(x_np ** 3) | 
 |             np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2) | 
 |             np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2) | 
 |             np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2) | 
 |  | 
 |     def test_nans(self): | 
 |         def test_max(x, y): | 
 |             return torch.max(2 * x, 2 * y) | 
 |  | 
 |         def test_min(x, y): | 
 |             return torch.min(2 * x, 2 * y) | 
 |  | 
 |         tmax = torch.jit.trace(test_max, (torch.rand(1), torch.rand(1))) | 
 |         tmin = torch.jit.trace(test_min, (torch.rand(1), torch.rand(1))) | 
 |  | 
 |         x = torch.tensor([np.nan]) | 
 |         y = torch.tensor([1.0]) | 
 |  | 
 |         assert np.isnan(warmup_and_run_forward(tmin, x, y).item()) | 
 |         assert np.isnan(warmup_and_run_forward(tmin, y, x).item()) | 
 |         self.assertLastGraphAllFused() | 
 |         assert np.isnan(warmup_and_run_forward(tmax, x, y).item()) | 
 |         assert np.isnan(warmup_and_run_forward(tmax, y, x).item()) | 
 |         self.assertLastGraphAllFused() | 
 |  | 
 |     def test_double_intrinsics(self): | 
 |         def do_pow(x): | 
 |             return torch.pow(x, 7) | 
 |  | 
 |         for device in self.devices: | 
 |             x = torch.rand(10, dtype=torch.double, device=device) | 
 |             traced = torch.jit.trace(do_pow, (x)) | 
 |             x = warmup_and_run_forward(traced, x) | 
 |             self.assertLastGraphAllFused() | 
 |  | 
 |     def test_remainder(self): | 
 |         def run_remainder(x, y): | 
 |             c = torch.remainder(torch.add(x, y), x) | 
 |             return c | 
 |  | 
 |         a = torch.rand(1024, dtype=float) | 
 |         b = torch.rand(1024, dtype=float) | 
 |         zeros = torch.zeros(1024, dtype=float) | 
 |         cc = np.array(1024, dtype=float) | 
 |         cc.fill(np.nan) | 
 |         nans = torch.from_numpy(cc) | 
 |  | 
 |         # random floats | 
 |         traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         y = run_remainder(a, b) | 
 |         np.testing.assert_allclose(x.numpy(), y.numpy()) | 
 |  | 
 |         # div by 0 | 
 |         traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) | 
 |         x = warmup_and_run_forward(traced, zeros, a) | 
 |         self.assertLastGraphAllFused() | 
 |         y = run_remainder(zeros, a) | 
 |         np.testing.assert_allclose(x.numpy(), y.numpy()) | 
 |  | 
 |         # numerators and denominatos are nan | 
 |         traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) | 
 |         x = warmup_and_run_forward(traced, nans, a) | 
 |         self.assertLastGraphAllFused() | 
 |         y = run_remainder(nans, a) | 
 |         np.testing.assert_allclose(x.numpy(), y.numpy()) | 
 |  | 
 |     def test_multioutput(self): | 
 |         def easy(x): | 
 |             b = x + 1 | 
 |             c = b + b | 
 |             return (b, c) | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.zeros(1024))) | 
 |  | 
 |         a = torch.zeros(1024) | 
 |         b, c = warmup_and_run_forward(traced, a) | 
 |         self.assertLastGraphAllFused() | 
 |         bp = a.numpy() + 1 | 
 |         cp = bp + bp | 
 |         np.testing.assert_allclose(b.numpy(), bp) | 
 |         np.testing.assert_allclose(c.numpy(), cp) | 
 |  | 
 |     def test_chunk(self): | 
 |         def easy(x): | 
 |             y = x + 1 | 
 |             aaa, bbb = torch.chunk(y, 2) | 
 |             return aaa + bbb | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.zeros(1024, 1024))) | 
 |  | 
 |         a = torch.zeros(32, 32) | 
 |         x = warmup_and_run_forward(traced, a) | 
 |         self.assertLastGraphAllFused() | 
 |         npr = a.numpy() | 
 |         npr2 = npr + 1 | 
 |         npr_a, npr_b = np.array_split(npr2, 2) | 
 |         np.testing.assert_allclose(npr_a + npr_b, x.numpy()) | 
 |  | 
 |     def test_cat(self): | 
 |         for device in self.devices: | 
 |             def foo(*args): | 
 |                 args_2 = [v + i for i, v in enumerate(args)] | 
 |                 v = torch.cat(args_2, dim=1) | 
 |                 return v * v | 
 |  | 
 |             M = 16 | 
 |             Ns = [128, 16, 1] | 
 |             values = [torch.zeros(M, N, device=device) for N in Ns] | 
 |             traced = torch.jit.trace(foo, values) | 
 |  | 
 |             x = warmup_and_run_forward(traced, *values) | 
 |             self.assertLastGraphAllFused() | 
 |             ref = foo(*values) | 
 |             np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) | 
 |  | 
 |     # This test checks that we correctly handle fusion group with just aten::cat in it. | 
 |     # Note that the test only makes sense with min_fusion_group=1, otherwise no | 
 |     # fusion groups would be formed at all. | 
 |     # TODO: Fix and re-enable the test. | 
 |     @unittest.skip("cat is broken with fusion group inlining disabled") | 
 |     def test_cat_only(self): | 
 |         for device in self.devices: | 
 |             def foo(*args): | 
 |                 args_2 = [v + i for i, v in enumerate(args)] | 
 |                 v = torch.cat(args_2, dim=1) | 
 |                 return v | 
 |  | 
 |             M = 16 | 
 |             Ns = [128, 16, 1] | 
 |             values = [torch.zeros(M, N, device=device) for N in Ns] | 
 |             traced = torch.jit.trace(foo, values) | 
 |  | 
 |             x = warmup_and_run_forward(traced, *values) | 
 |             self.assertLastGraphAllFused() | 
 |             ref = foo(*values) | 
 |             np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) | 
 |  | 
 |     def test_cat_negative_dim(self): | 
 |         for device in self.devices: | 
 |             def foo(*args): | 
 |                 v = torch.cat(args, dim=-1) | 
 |                 return v * v | 
 |  | 
 |             M = 16 | 
 |             Ns = [128, 16, 1] | 
 |             values = [torch.randn(M, N, device=device) for N in Ns] | 
 |             traced = torch.jit.trace(foo, values) | 
 |  | 
 |             x = warmup_and_run_forward(traced, *values) | 
 |             self.assertLastGraphAllFused() | 
 |             ref = foo(*values) | 
 |             np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) | 
 |  | 
 |     def test_cat_promote_inputs(self): | 
 |         for device in self.devices: | 
 |             def foo(*args): | 
 |                 v = torch.cat(args, dim=1) | 
 |                 return v * v | 
 |  | 
 |             M = 16 | 
 |             Ns = [128, 16, 1] | 
 |             dtypes = [torch.half, torch.float32, torch.double] | 
 |             values = [torch.randn(M, N, device=device, dtype=dt) for N, dt in zip(Ns, dtypes)] | 
 |             traced = torch.jit.trace(foo, values) | 
 |  | 
 |             x = warmup_and_run_forward(traced, *values) | 
 |             self.assertLastGraphAllFused() | 
 |             ref = foo(*values) | 
 |             np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) | 
 |  | 
 |     def test_cat_empty_tensors(self): | 
 |         for device in self.devices: | 
 |             def foo(*args): | 
 |                 v = torch.cat(args, dim=1) | 
 |                 return v * v | 
 |  | 
 |             M = 16 | 
 |             Ns = [128, 16, 1] | 
 |             empty = torch.tensor([], device=device, dtype=torch.double) | 
 |             values = [empty] + [torch.randn(M, N, device=device) for N in Ns] | 
 |             traced = torch.jit.trace(foo, values) | 
 |  | 
 |             x = warmup_and_run_forward(traced, *values) | 
 |             self.assertLastGraphAllFused() | 
 |             ref = foo(*values) | 
 |             np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) | 
 |  | 
 |             # now test with only empty tensors | 
 |             values = [empty for i in range(3)] | 
 |             traced = torch.jit.trace(foo, values) | 
 |             x = warmup_and_run_forward(traced, *values) | 
 |             self.assertLastGraphAllFused() | 
 |             ref = foo(*values) | 
 |             np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) | 
 |  | 
 |     def test_cat_with_constant_dim(self): | 
 |         for device in self.devices: | 
 |             def foo(*args): | 
 |                 v1 = torch.cat(args, dim=1) | 
 |                 v2 = torch.cat([v1], dim=1) | 
 |                 return v2 * v2 | 
 |  | 
 |             empty = torch.tensor([], device=device, dtype=torch.float32) | 
 |             inputs = [empty] + [torch.randn(1, 64, device=device), torch.randn(1, 64, device=device)] | 
 |             traced = torch.jit.trace(foo, inputs) | 
 |  | 
 |             x = warmup_and_run_forward(traced, *inputs) | 
 |             self.assertLastGraphAllFused() | 
 |             ref = foo(*inputs) | 
 |             np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy()) | 
 |  | 
 |     def test_scalar(self): | 
 |         @torch.jit.script | 
 |         def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor: | 
 |             return torch.add(torch.add(x, y, alpha=a), z, alpha=b) | 
 |  | 
 |         @torch.jit.script | 
 |         def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor: | 
 |             return torch.add(torch.add(x, y, alpha=a), z, alpha=b) | 
 |  | 
 |         for test in (test_float, test_int): | 
 |             x, y, z = [torch.rand(4) for i in range(3)] | 
 |             a, b = 1, 2 | 
 |             test(x, y, z, a, b) | 
 |             r = test(x, y, z, a, b) | 
 |             xn, yn, zn = [t.numpy() for t in (x, y, z)] | 
 |             np.testing.assert_allclose(r.numpy(), xn + yn * a + zn * b) | 
 |  | 
 |     def test_loop(self): | 
 |         @torch.jit.script | 
 |         def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor: | 
 |             b = y | 
 |             for i in range(0, z): | 
 |                 a = x + y | 
 |                 b = b + y | 
 |             return b | 
 |  | 
 |         x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4) | 
 |         test(x, y, z) | 
 |         r = test(x, y, z) | 
 |  | 
 |     def test_slice(self): | 
 |         def easy(x, y): | 
 |             a = x[0:512:2] | 
 |             b = y[0:512:2] | 
 |             return a + b | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024))) | 
 |  | 
 |         a = torch.ones(1024, 1024) | 
 |         x = traced(a, a) | 
 |         npr = a[0:512:2] | 
 |         npr = npr + npr | 
 |         np.testing.assert_allclose(npr.numpy(), x.numpy()) | 
 |  | 
 |     def test_unsqueeze(self, N=256): | 
 |         def easy(x, y): | 
 |             a = torch.unsqueeze(x, 0) | 
 |             b = torch.unsqueeze(y, 0) | 
 |             return a + b | 
 |  | 
 |         traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N))) | 
 |  | 
 |         a = torch.rand(N, N) | 
 |         x = traced(a, a) | 
 |         npr = np.expand_dims(a, 0) | 
 |         npr = npr + npr | 
 |         np.testing.assert_allclose(npr, x.numpy()) | 
 |  | 
 |     def _test_softmax(self, device): | 
 |         def test_softmax(x, y): | 
 |             a = F.softmax(x, dim=0, dtype=torch.float32) | 
 |             b = F.softmax(y, dim=0, dtype=torch.float32) | 
 |             c = F.softmax(x, dim=1, dtype=torch.float32) | 
 |             d = F.softmax(y, dim=1, dtype=torch.float32) | 
 |             return a + b + c + d | 
 |  | 
 |         def test_softmax_neg_index(x, y): | 
 |             a = F.softmax(x, dim=-2, dtype=torch.float32) | 
 |             b = F.softmax(y, dim=-2, dtype=torch.float32) | 
 |             c = F.softmax(x, dim=-1, dtype=torch.float32) | 
 |             d = F.softmax(y, dim=-1, dtype=torch.float32) | 
 |             return a + b + c + d | 
 |  | 
 |         def test_log_softmax(x, y): | 
 |             a = F.log_softmax(x, dim=0, dtype=torch.float32) | 
 |             b = F.log_softmax(y, dim=0, dtype=torch.float32) | 
 |             c = F.log_softmax(x, dim=1, dtype=torch.float32) | 
 |             d = F.log_softmax(y, dim=1, dtype=torch.float32) | 
 |             return a + b + c + d | 
 |  | 
 |         for test in (test_softmax, test_log_softmax, test_softmax_neg_index): | 
 |             old = torch._C._jit_set_texpr_reductions_enabled(True) | 
 |             traced = torch.jit.trace(test, (torch.randn(2, 3, device=device), torch.randn(2, 3, device=device))) | 
 |             inp = torch.randn(2, 3, device=device) | 
 |             res = traced(inp, inp) | 
 |             # Use eager mode as reference. | 
 |             ref = test(inp, inp) | 
 |             np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06) | 
 |             torch._C._jit_set_texpr_reductions_enabled(old) | 
 |  | 
 |     def test_softmax_cpu(self): | 
 |         self._test_softmax('cpu') | 
 |  | 
 |     @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") | 
 |     @unittest.skip("global allocs are not supported yet.") | 
 |     def test_softmax_cuda(self): | 
 |         self._test_softmax('cuda') | 
 |  | 
 |     def test_half_gelu(self): | 
 |         devices = ["cuda"] if torch.cuda.is_available() else [] | 
 |  | 
 |         @torch.jit.script | 
 |         def bias_gelu(bias, y): | 
 |             x = bias + y | 
 |             return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) | 
 |  | 
 |         for device in devices: | 
 |             a = torch.rand(1024, dtype=torch.half, device=device) | 
 |             b = torch.rand(1024, dtype=torch.half, device=device) | 
 |             traced = torch.jit.trace(bias_gelu, (a, b)) | 
 |             x = warmup_and_run_forward(traced, a, b) | 
 |             self.assertLastGraphAllFused() | 
 |  | 
 |     def test_half_bn_relu(self): | 
 |         devices = ["cuda"] if torch.cuda.is_available() else [] | 
 |  | 
 |         def foo(a, b, c): | 
 |             y = torch.nn.functional.batch_norm(a, b, c) | 
 |             z = y.relu() | 
 |             return z | 
 |  | 
 |         for device in devices: | 
 |             a = torch.rand(16, 16, dtype=torch.half, device=device) | 
 |             b = torch.rand(16, dtype=torch.half, device=device) | 
 |             c = torch.rand(16, dtype=torch.half, device=device) | 
 |             traced = torch.jit.trace(foo, (a, b, c)) | 
 |             print(traced.graph) | 
 |             x = warmup_and_run_forward(traced, a, b, c) | 
 |             self.assertLastGraphAllFused() | 
 |  | 
 |     def test_exp_pow(self): | 
 |         @torch.jit.script | 
 |         def do_exp(x, y, z): | 
 |             return ((x * y) * 2) * torch.pow(z, 2) | 
 |  | 
 |         for device in self.devices: | 
 |             x = torch.rand(10, dtype=torch.double, device=device) | 
 |             y = torch.rand(10, dtype=torch.double, device=device) | 
 |             z = torch.rand(10, dtype=torch.double, device=device) | 
 |             traced = torch.jit.trace(do_exp, (x, y, z)) | 
 |             x = warmup_and_run_forward(traced, x, y, z) | 
 |             self.assertLastGraphAllFused() | 
 |  | 
 |     def test_transpose(self): | 
 |         @torch.jit.script | 
 |         def test(x, y, z): | 
 |             return x.transpose(0, 1) + y + z | 
 |         x = torch.rand(4, 5, 2, 3) | 
 |         y = torch.rand(5, 4, 2, 3) | 
 |         z = torch.rand(5, 4, 2, 3) | 
 |         ref = test(x, y, z) | 
 |         res = test(x, y, z) | 
 |         np.testing.assert_allclose(ref.numpy(), res.numpy()) | 
 |  | 
 |     def test_sliced_stride(self): | 
 |         @torch.jit.script | 
 |         def test(x, y, z): | 
 |             return x + y + z | 
 |         x = torch.rand(16, 4, 2, 3)[::2] | 
 |         y = torch.rand(8, 4, 2, 3) | 
 |         z = torch.rand(8, 4, 2, 3) | 
 |         ref = test(x, y, z) | 
 |         res = test(x, y, z) | 
 |         np.testing.assert_allclose(ref.numpy(), res.numpy()) | 
 |  | 
 |     @unittest.skip("dynamic shapes are not quite there yet") | 
 |     @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") | 
 |     def test_dynamic_shape(self): | 
 |         with num_profiled_runs(2): | 
 |             @torch.jit.script | 
 |             def test(x, y, z): | 
 |                 return x * y * z | 
 |             x, y, z = [torch.rand(4, 8).cuda() for _ in range(3)] | 
 |             ref = test(x, y, z) | 
 |             _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)]) | 
 |             res = test(x, y, z) | 
 |             np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy()) | 
 |  | 
 |             # A wild broadcast appears. | 
 |             x = torch.rand(4, 8).cuda() | 
 |             y = torch.rand(1, 8).cuda() | 
 |             z = torch.rand(4, 1).cuda() | 
 |             res = test(x, y, z) | 
 |             xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)] | 
 |             np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn) | 
 |  | 
 |             # Mismatched shapes shouldn't reach codegen. | 
 |             x = torch.rand(4, 8).cuda() | 
 |             y = torch.rand(4, 8).cuda() | 
 |             z = torch.rand(5, 8).cuda() | 
 |             try: | 
 |                 res = test(x, y, z) | 
 |             except RuntimeError as e: | 
 |                 assert "The size of tensor a (4) must match" in e.args[0] | 
 |  | 
 |             # Changing a static dimension fails guards. | 
 |             # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)] | 
 |             # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)] | 
 |             # res = test(x, y, z) | 
 |             # print(test.graph_for(x, y, z)) | 
 |             # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn) | 
 |  | 
 |     @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") | 
 |     def test_guard_fails(self): | 
 |         @torch.jit.script | 
 |         def test(x, y, z): | 
 |             return x * y * z | 
 |         r1 = test(*[torch.rand(4).cuda() for _ in range(3)]) | 
 |         r2 = test(*[torch.rand(4).cuda() for _ in range(3)]) | 
 |         r3 = test(*[torch.rand(4).cuda() for _ in range(3)]) | 
 |         r4 = test(*[torch.rand(7).cuda() for _ in range(3)]) | 
 |  | 
 |     def test_bitwise_ops(self): | 
 |         def run_and(x, y): | 
 |             return x & (x & y) | 
 |  | 
 |         def run_or(x, y): | 
 |             return x & (x | y) | 
 |  | 
 |         def run_xor(x, y): | 
 |             return x ^ (x ^ y) | 
 |  | 
 |         def run_lshift(x, y): | 
 |             return x & (x << y) | 
 |  | 
 |         def run_rshift(x, y): | 
 |             return x & (x >> y) | 
 |  | 
 |         fns = {run_and, run_or, run_xor, run_lshift, run_rshift} | 
 |  | 
 |         for device in self.devices: | 
 |             for fn in fns: | 
 |                 a = torch.ones(128, dtype=torch.int32, device=device) | 
 |                 b = torch.zeros(128, dtype=torch.int32, device=device) | 
 |                 inp = torch.ones(128, dtype=torch.int32, device=device) | 
 |                 traced = torch.jit.trace(fn, (inp, inp)) | 
 |                 x = warmup_and_run_forward(traced, a, b) | 
 |                 self.assertLastGraphAllFused() | 
 |                 y = fn(a, b) | 
 |                 np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) | 
 |  | 
 |     def test_where(self): | 
 |         def run_where(x, y): | 
 |             return torch.where(torch.gt(x, y), x, y) | 
 |  | 
 |         a = torch.rand(1024, dtype=float) | 
 |         b = torch.rand(1024, dtype=float) | 
 |         traced = torch.jit.trace(run_where, (torch.zeros(1024), torch.zeros(1024))) | 
 |         x = warmup_and_run_forward(traced, a, b) | 
 |         self.assertLastGraphAllFused() | 
 |         y = run_where(a, b) | 
 |         np.testing.assert_allclose(x.numpy(), y.numpy()) | 
 |  | 
 |     def test_multi_rand(self): | 
 |         for device in self.devices: | 
 |             def test(x): | 
 |                 y = torch.rand_like(x) | 
 |                 return (x + y) - (y - x) | 
 |             a = torch.rand(4, device=device) | 
 |             scripted = torch.jit.script(test) | 
 |             out = warmup_and_run_forward(scripted, a) | 
 |             self.assertLastGraphAllFused() | 
 |             assert torch.allclose(out, 2 * a) | 
 |  | 
 |     def test_mask(self): | 
 |         def test(x): | 
 |             return x.unsqueeze(1) == 0 | 
 |  | 
 |         for d in self.devices: | 
 |             x = torch.rand(4, device=d) > 0.5 | 
 |             scripted = torch.jit.script(test) | 
 |             out = warmup_and_run_forward(scripted, x) | 
 |             self.assertLastGraphAllFused() | 
 |             assert torch.equal(out, test(x)) | 
 |  | 
 |     def test_simple_add(self): | 
 |         val = torch._C._jit_get_te_generate_block_code() | 
 |         torch._C._jit_set_te_generate_block_code(True) | 
 |         fall_bk = torch._C._jit_texpr_fallback_allowed() | 
 |         torch._C._jit_texpr_set_fallback_allowed(True) | 
 |  | 
 |         def simple(a, b): | 
 |             return torch.add(a, b) | 
 |  | 
 |         a = torch.ones(256, 256) | 
 |         b = torch.ones(256, 256) | 
 |         traced = torch.jit.trace(simple, | 
 |                                  (torch.ones(256, 256), torch.ones(256, 256))) | 
 |         f = traced(a, b) | 
 |         f_test = np.full((256, 256), 2, dtype=float) | 
 |         np.testing.assert_allclose(f.numpy(), f_test) | 
 |         torch._C._jit_set_te_generate_block_code(val) | 
 |         torch._C._jit_texpr_set_fallback_allowed(fall_bk) | 
 |  | 
 |     def test_strided_output_preserved(self): | 
 |         def foo(a, b): | 
 |             return a + b - a | 
 |  | 
 |         # smaller, easier to debug example | 
 |         x = torch.arange(6) | 
 |         x = torch.as_strided(x, (2, 3), (1, 2)) | 
 |         total = 0 | 
 |         for i in range(2): | 
 |             for j in range(3): | 
 |                 x[i, j] = total | 
 |                 total += 1 | 
 |         foo_script = torch.jit.script(foo) | 
 |         foo_script(x, x) | 
 |         foo_script(x, x) | 
 |         out_s = foo_script(x, x) | 
 |         out_eager = foo(x, x) | 
 |         self.assertEqual(out_s, out_eager) | 
 |         self.assertEqual(out_s.stride(), out_eager.stride()) | 
 |         self.assertLastGraphAllFused() | 
 |  | 
 |         # more dims | 
 |         N, C, H, W, = 2, 3, 4, 5 | 
 |         x = torch.rand(N, C, H, W).to(memory_format=torch.channels_last) | 
 |         foo_script = torch.jit.script(foo) | 
 |         foo_script(x, x) | 
 |         foo_script(x, x) | 
 |         out_s = foo_script(x, x) | 
 |         out_eager = foo(x, x) | 
 |         self.assertEqual(out_s, out_eager) | 
 |         self.assertEqual(out_s.stride(), out_eager.stride()) | 
 |         self.assertLastGraphAllFused() | 
 |  | 
 |     def test_alias_analysis_module(self): | 
 |         class AliasModule(nn.Module): | 
 |             def __init__(self): | 
 |                 super(AliasModule, self).__init__() | 
 |                 torch.manual_seed(1337) | 
 |                 self.a = torch.randn(128, 128) | 
 |                 self.b = torch.randn(128, 128) | 
 |                 self.c = torch.randn(128, 128) | 
 |  | 
 |             def forward(self, x, y, z): | 
 |                 z = z + self.a | 
 |                 self.b.add_(y) | 
 |                 w = z + self.a | 
 |                 z = w + x | 
 |                 return z | 
 |         x = torch.randn(128, 128) | 
 |  | 
 |         def getModule(script): | 
 |             am = AliasModule() | 
 |             if script: | 
 |                 return torch.jit.script(am) | 
 |             return am | 
 |  | 
 |         am = getModule(False) | 
 |         am_s = getModule(True) | 
 |         ref = am(x, x, x) | 
 |         test = am_s(x, x, x) | 
 |         torch.testing.assert_close(ref, test) | 
 |  | 
 |         # Now do the aliasing | 
 |         am.a = am.b | 
 |         ref = am(x, x, x) | 
 |  | 
 |         am_s.a = am_s.b | 
 |         test = am_s(x, x, x) | 
 |  | 
 |         torch.testing.assert_close(ref, test) | 
 |  | 
 |     def test_alias_analysis_inputs(self): | 
 |         class AliasModule(nn.Module): | 
 |             def __init__(self): | 
 |                 super(AliasModule, self).__init__() | 
 |                 torch.manual_seed(1337) | 
 |                 self.a = torch.randn(128, 128) | 
 |                 self.b = torch.randn(128, 128) | 
 |                 self.c = torch.randn(128, 128) | 
 |  | 
 |             def forward(self, x, y, z): | 
 |                 x.add_(y) | 
 |                 w = z + self.a | 
 |                 z = w + x | 
 |                 return z | 
 |  | 
 |         def getModule(script): | 
 |             am = AliasModule() | 
 |             if script: | 
 |                 return torch.jit.script(am) | 
 |             return am | 
 |         am = getModule(False) | 
 |         am_s = getModule(True) | 
 |  | 
 |         torch.manual_seed(1337) | 
 |         x = torch.randn(128, 128) | 
 |         ref = am(x, x, x) | 
 |  | 
 |         torch.manual_seed(1337) | 
 |         x = torch.randn(128, 128) | 
 |         test = am_s(x, x, x) | 
 |  | 
 |         torch.testing.assert_close(ref, test) | 
 |  | 
 |     def test_alias_analysis_input_and_module(self): | 
 |         class AliasModule(nn.Module): | 
 |             def __init__(self): | 
 |                 super(AliasModule, self).__init__() | 
 |                 torch.manual_seed(1337) | 
 |                 self.a = torch.randn(128, 128) | 
 |                 self.b = torch.randn(128, 128) | 
 |                 self.c = torch.randn(128, 128) | 
 |  | 
 |             def forward(self, x, y, z): | 
 |                 x.add_(y) | 
 |                 w = z + self.b | 
 |                 z = w + x | 
 |                 return z | 
 |  | 
 |         def getModule(script): | 
 |             am = AliasModule() | 
 |             if script: | 
 |                 return torch.jit.script(am) | 
 |             return am | 
 |         am = getModule(False) | 
 |         am_s = getModule(True) | 
 |  | 
 |         torch.manual_seed(1337) | 
 |         x = torch.randn(128, 128) | 
 |         am.b = x | 
 |         ref = am(x, x, x) | 
 |  | 
 |         torch.manual_seed(1337) | 
 |         x = torch.randn(128, 128) | 
 |         am_s.b = x | 
 |         test = am_s(x, x, x) | 
 |  | 
 |         torch.testing.assert_close(ref, test) | 
 |  | 
 |     def test_multiple_outputs(self): | 
 |         for device in self.devices: | 
 |             # A bug reported internally similar to the one reported in #48533 | 
 |             def foo(a, b, c): | 
 |                 t_next = c + 1 | 
 |                 t5 = t_next * b | 
 |                 t6 = torch.unsqueeze(t_next, 1) | 
 |                 t7 = a * t6 | 
 |                 return (t7, t5, t_next) | 
 |  | 
 |             a = torch.rand(20, 20, dtype=torch.float32, device=device) | 
 |             b = torch.rand(20 * 29, dtype=torch.float32, device=device).as_strided([20], [29]) | 
 |             c = torch.ones(20, dtype=torch.int64, device=device) | 
 |             traced = torch.jit.trace(foo, (a, b, c)) | 
 |             ref = foo(a, b, c) | 
 |             exp = traced(a, b, c) | 
 |             exp = traced(a, b, c) | 
 |             self.assertEqual(ref, exp) | 
 |  | 
 | if __name__ == '__main__': | 
 |     run_tests() |