from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest

import torch

from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, skipIfRocm

from test_jit import JitTestCase, RUN_CUDA


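# The CUDA fuser tests rely on the profiling graph executor; when that executor
# is selected, make sure profiling is actually enabled before any test runs.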
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)

class TestCudaFuser(JitTestCase):

    def setUp(self):
        super(TestCudaFuser, self).setUp()
        # Remember the current fuser switches and turn the legacy CPU/GPU fusers
        # off, so any fusion observed in these tests comes from the CUDA fuser.
        self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
        self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)

        if RUN_CUDA:
            torch._C._jit_register_cuda_fuser()

    def tearDown(self):
        if RUN_CUDA:
            torch._C._jit_clear_cuda_fuser()
        # Restore the fuser switches saved in setUp.
        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
        super(TestCudaFuser, self).tearDown()

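    # Runs the scripted op twice: under the profiling executor the first call
    # records profiling information and only later calls can be fused. The CUDA
    # RNG is reseeded around every call so random ops (e.g. rand_like) produce
    # identical values, then both the numeric result and the fusion decision are
    # checked.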
    def _run_helper(self, jit_op, op, should_fuse, *args):
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        torch.cuda.manual_seed_all(123)
        jit_o = jit_op(*args)
        torch.cuda.manual_seed_all(123)
        o = op(*args)
        self.assertEqual(o, jit_o)
        self.assertEqual(self._has_cuda_fusion_group(jit_op.graph_for(*args)), should_fuse)

    def _has_cuda_fusion_group(self, graph):
        return any(node.kind() == 'prim::CudaFusionGroup' for node in graph.nodes())

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
    @skipIfRocm
    def test_const(self):
        def t(x, y):
            o = x + y
            o = o + 2.0
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, dtype=torch.float, device="cuda")
        # Call the scripted function twice: the first call gathers profiling
        # information, the second runs the (potentially fused) optimized graph.
        jit_o = t_jit(x, y)
        jit_o = t_jit(x, y)
        o = t(x, y)
        self.assertEqual(o, jit_o)
        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y)))

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
    @skipIfRocm
    def test_chunk(self):
        def t(x, y, z, q):
            o = x + q
            x0, x1 = torch.chunk(o, 2)
            o = x0 + x1
            o = o + y
            o = o * z
            o = torch.relu(o)
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, dtype=torch.float, device="cuda")
        y = torch.randn(2, 8, dtype=torch.float, device="cuda")
        z = torch.randn(2, 8, dtype=torch.float, device="cuda")
        q = torch.randn(4, 8, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, z, q)
        jit_o = t_jit(x, y, z, q)
        o = t(x, y, z, q)
        self.assertEqual(o, jit_o)
        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z, q)))

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
    @skipIfRocm
    def test_scalar_input(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = x + y
            o = o + z
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda")
        y = y.expand(4, 8, 32, 32)
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
    @skipIfRocm
    def test_broadcasting(self):
        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = x + y
            o = o + z
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))

| @unittest.skipIf(True, "real broadcast with different output not supported yet") |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser") |
| @skipIfRocm |
| def test_broadcasting_multiple_output_shape(self): |
| def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor): |
| o = x + 12 |
| o1 = o + y |
| o2 = o + z |
| oo = o1.sum() + o2.sum() |
| return oo |
| t_jit = torch.jit.script(t) |
| x = torch.randn(32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda") |
| z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, z) |
| jit_o = t_jit(x, y, z) |
| o = t(x, y, z) |
| self.assertEqual(o, jit_o) |
| # Currently cannot fuse this |
| self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z))) |
| |
| @unittest.skipIf(True, "temporary disable for buggy codegen") |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser") |
| @skipIfRocm |
| def test_broadcasting_multiple_output(self): |
| def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor): |
| o = x + 12 |
| o1 = o + y |
| o2 = o + z |
| oo = o1.sum() + o2.sum() |
| return oo |
| t_jit = torch.jit.script(t) |
| x = torch.randn(32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 32, 32, dtype=torch.float, device="cuda") |
| z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, z) |
| jit_o = t_jit(x, y, z) |
| o = t(x, y, z) |
| self.assertEqual(o, jit_o) |
| # Currently cannot fuse this |
| self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z))) |
| |
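    # Shared helper for the elementwise binary op tests: scripts a small
    # add-then-op function, runs it twice so the profiling executor can fuse,
    # and expects a CUDA fusion group in the resulting graph.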
    def _binary_test_helper(self, operation):
        def t(x: torch.Tensor, y: torch.Tensor, z: float):
            o = x + z
            o = operation(o, y)
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, y, 2.0)
        jit_o = t_jit(x, y, 2.0)
        o = t(x, y, 2.0)
        self.assertEqual(o, jit_o)
        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))

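    # Same pattern as _binary_test_helper, for single-tensor unary ops.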
    def _unary_test_helper(self, operation):
        def t(x: torch.Tensor, z: float):
            o = x + z
            o = operation(o)
            return o
        t_jit = torch.jit.script(t)
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        jit_o = t_jit(x, 2.0)
        jit_o = t_jit(x, 2.0)
        o = t(x, 2.0)
        self.assertEqual(o, jit_o)
        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, 2.0)))

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
    @skipIfRocm
    def test_unary_ops(self):
        operations = [torch.neg,
                      torch.abs,
                      torch.log,
                      torch.log10,
                      torch.log1p,
                      torch.log2,
                      torch.lgamma,
                      torch.exp,
                      torch.expm1,
                      torch.erf,
                      torch.erfc,
                      torch.cos,
                      torch.acos,
                      torch.cosh,
                      torch.sin,
                      torch.asin,
                      torch.tan,
                      torch.atan,
                      torch.sqrt,
                      torch.rsqrt,
                      torch.ceil,
                      torch.floor,
                      torch.round,
                      torch.trunc,
                      torch.frac,
                      torch.reciprocal,
                      torch.relu,
                      torch.sigmoid,
                      torch.tanh,
                      torch.nn.functional.gelu]
        for op in operations:
            self._unary_test_helper(op)

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
    @skipIfRocm
    def test_binary_ops(self):
        operations = [torch.div,
                      torch.mul,
                      torch.atan2,
                      torch.max,
                      torch.min,
                      torch.pow,
                      torch.remainder,
                      torch.fmod,
                      torch.eq,
                      torch.ne,
                      torch.ge,
                      torch.gt,
                      torch.le,
                      torch.lt]
        for op in operations:
            self._binary_test_helper(op)

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
    @skipIfRocm
    def test_ternary_ops(self):
        x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
        # Random boolean mask for torch.where, created on CPU and moved to CUDA.
        cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda")

        def add(x: torch.Tensor, other: torch.Tensor, alpha: float):
            o = torch.relu(x)
            o = torch.add(o, other=other, alpha=alpha)
            return o
        add_jit = torch.jit.script(add)
        self._run_helper(add_jit, add, True, x, y, 2.0)

        def clamp0(x: torch.Tensor, f: float):
            o = torch.rand_like(x)
            o = o * torch.clamp(x, min=f)
            return o
        clamp0_jit = torch.jit.script(clamp0)
        self._run_helper(clamp0_jit, clamp0, True, x, 0.5)

        def clamp1(x: torch.Tensor, f: float, ff: float):
            o = torch.rand_like(x)
            o = o * torch.clamp(x, min=f, max=ff)
            return o
        clamp1_jit = torch.jit.script(clamp1)
        self._run_helper(clamp1_jit, clamp1, True, x, -0.2, 0.7)

        def threshold(x: torch.Tensor, th: float, val: float):
            o = torch.rand_like(x)
            o = x * torch.threshold(o, th, val)
            return o
        threshold_jit = torch.jit.script(threshold)
        self._run_helper(threshold_jit, threshold, True, x, 0.2, 0.9)

        def where(x: torch.Tensor, y: torch.Tensor, cond: torch.Tensor):
            o = torch.rand_like(x)
            o = o * torch.where(cond, x, y)
            return o
        where_jit = torch.jit.script(where)
        self._run_helper(where_jit, where, True, x, y, cond)

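# run_tests() is PyTorch's shared test entry point from
# torch.testing._internal.common_utils; it wraps unittest's runner so this
# file can be invoked directly.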
if __name__ == '__main__':
    run_tests()