| import unittest |
| import os |
| |
| import torch |
| |
| from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR |
| from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed |
| |
| from test_jit import JitTestCase, RUN_CUDA |
| |
| from jit.test_fuser_common import TestFuserCommon # noqa: F401 |
| |
| import itertools |
| import numpy as np |
| |
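| # Configure the CUDA fuser before any test runs: these flags are expected to |
| # disable the eager fallback on fusion failure, FMA contraction, and kernel |
| # compilation optimizations, keeping the fused kernels numerically comparable |
| # to the eager reference used in the tests below. |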
| os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1' |
| os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1' |
| os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0' |
| |
| if GRAPH_EXECUTOR == ProfilingMode.PROFILING: |
| torch._C._jit_set_texpr_fuser_enabled(False) |
| torch._C._jit_set_profiling_executor(True) |
| torch._C._jit_set_profiling_mode(True) |
| |
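| # Node kinds inserted by the CUDA fuser: the fusion group node holds the fused |
| # subgraph, while the guard node checks the runtime inputs before dispatching |
| # into it. |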
| FUSION_GROUP = 'prim::CudaFusionGroup' |
| FUSION_GUARD = 'prim::CudaFusionGuard' |
| |
| class TestCudaFuser(JitTestCase): |
| |
|     def _getSubgraphInFusion(self, graph): |
|         # Walk the graph (including nested blocks) and return the subgraph of |
|         # the single CudaFusionGroup node; fail if there is not exactly one. |
|         num_node = 0 |
|         subgraph = None |
| |
|         def count(block, ret): |
|             for n in block.nodes(): |
|                 if n.kind() == FUSION_GROUP: |
|                     ret[0] = ret[0] + 1 |
|                     self.assertTrue(n.hasAttribute('Subgraph')) |
|                     ret[1] = n.g('Subgraph') |
|                 for inner_block in n.blocks(): |
|                     count(inner_block, ret) |
|         ret = [num_node, subgraph] |
|         count(graph, ret) |
|         self.assertEqual(ret[0], 1) |
|         return ret[1] |
| |
| def setUp(self): |
| super(TestCudaFuser, self).setUp() |
| self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() |
| self.old_gpu_fuse = torch._C._jit_can_fuse_on_gpu() |
| torch._C._jit_override_can_fuse_on_cpu(False) |
| torch._C._jit_override_can_fuse_on_gpu(False) |
| self.old_guard = torch._C._jit_set_nvfuser_guard_mode(False) |
| |
|         if RUN_CUDA: |
|             self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True) |
| |
| def tearDown(self): |
|         if RUN_CUDA: |
|             torch._C._jit_set_nvfuser_enabled(self.old_nvfuser) |
| torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse) |
| torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse) |
| torch._C._jit_set_nvfuser_guard_mode(self.old_guard) |
| super(TestCudaFuser, self).tearDown() |
| |
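|     # Run the scripted op twice: the first call lets the profiling executor |
|     # record tensor shapes, the second call executes the optimized (fused) |
|     # graph, which is then compared against the eager reference. The CUDA RNG |
|     # is re-seeded before each run so ops like rand_like produce identical |
|     # values in both executions. |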
| def _run_helper(self, jit_op, op, *args): |
| torch.cuda.manual_seed_all(123) |
| jit_o = jit_op(*args) |
| torch.cuda.manual_seed_all(123) |
| jit_o = jit_op(*args) |
| torch.cuda.manual_seed_all(123) |
| o = op(*args) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(jit_op.graph_for(*args), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_half(self): |
| def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float): |
| o_16 = torch.add(x, y) |
| o_32_a = torch.add(y, z, alpha=alpha) |
| o_32_b = torch.add(o_16, z) |
| return (o_16, o_32_a, o_32_b) |
| |
| t_jit = torch.jit.script(t) |
| alpha = 0.5 |
|         # stick to integers; this avoids numerical differences caused by our |
|         # type promotion |
| x = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda") |
| y = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda") |
| z = torch.randint(0, 256, (4, 8)).to(dtype=torch.float16, device="cuda") |
| jit_o = t_jit(x, y, z, alpha) |
| jit_o = t_jit(x, y, z, alpha) |
| o = t(x, y, z, alpha) |
| for oo, jit_oo in zip(o, jit_o): |
| self.assertEqual(oo.dtype, jit_oo.dtype) |
| self.assertEqual(oo, jit_oo) |
| self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_const(self): |
| def t(x, y): |
| o = x + y |
| o = o + 2.0 |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 8, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 8, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y) |
| jit_o = t_jit(x, y) |
| o = t(x, y) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_chunk(self): |
| def t(x, y, z, q): |
| o = x + q |
| x0, x1 = torch.chunk(o, 2) |
| o = x0 + x1 |
| o = o + y |
| o = o * z |
| o = torch.relu(o) |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 8, dtype=torch.float, device="cuda") |
| y = torch.randn(2, 8, dtype=torch.float, device="cuda") |
| z = torch.randn(2, 8, dtype=torch.float, device="cuda") |
| q = torch.randn(4, 8, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, z, q) |
| jit_o = t_jit(x, y, z, q) |
| o = t(x, y, z, q) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_scalar_input(self): |
| def t(x: torch.Tensor, y: torch.Tensor, z: float): |
| o = x + y |
| o = o + z |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda") |
| y = y.expand(4, 8, 32, 32) |
| jit_o = t_jit(x, y, 2.0) |
| jit_o = t_jit(x, y, 2.0) |
| o = t(x, y, 2.0) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_broadcasting_0(self): |
| |
| def t(x: torch.Tensor, y: torch.Tensor, z: float): |
| o = x + y |
| o = o + z |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, 2.0) |
| jit_o = t_jit(x, y, 2.0) |
| o = t(x, y, 2.0) |
| self.assertEqual(o, jit_o) |
| subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0)) |
| self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_broadcasting_1(self): |
| |
| def t(x: torch.Tensor, y: torch.Tensor, z: float): |
| o = x + y |
| o = o + z |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(1, 32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, 2.0) |
| jit_o = t_jit(x, y, 2.0) |
| o = t(x, y, 2.0) |
| self.assertEqual(o, jit_o) |
| subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0)) |
| self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_broadcasting_2(self): |
| |
| def t(x: torch.Tensor, y: torch.Tensor, z: float): |
| o = x + y |
| o = o + z |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 1, 32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(8, 32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, 2.0) |
| jit_o = t_jit(x, y, 2.0) |
| o = t(x, y, 2.0) |
| self.assertEqual(o, jit_o) |
| subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0)) |
| self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_broadcasting_3(self): |
| |
| def t(x: torch.Tensor, y: torch.Tensor, z: float): |
| o = x + y |
| o = o + z |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda") |
| y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, 2.0) |
| jit_o = t_jit(x, y, 2.0) |
| o = t(x, y, 2.0) |
| self.assertEqual(o, jit_o) |
| subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0)) |
| self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) |
| |
|     # test_broadcasting_partition_logic_X |
|     # Tests the partition logic that avoids creating unsupported broadcasting |
|     # semantics inside a CudaFusionGroup |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_broadcasting_partition_logic_0(self): |
| |
| def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): |
| x = x + 12.0 |
| o1 = x + y |
| o2 = x + z |
| o = o1 + o2 |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda") |
| y = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda") |
| z = torch.randn(6, 8, dtype=torch.float32, device="cuda") |
| jit_o = t_jit(x, y, z) |
| jit_o = t_jit(x, y, z) |
| o = t(x, y, z) |
| self.assertEqual(o, jit_o) |
| subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z)) |
| self.assertGraphContainsExactly(subgraph, 'aten::add', 4, consider_subgraphs=False) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_broadcasting_partition_logic_1(self): |
| |
| def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): |
| x = x + 12.0 |
| o1 = x + y |
| o2 = x + z |
| o = o1 + o2 |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(8, 6, 8, dtype=torch.float32, device="cuda") |
| y = torch.randn(4, 8, 6, 8, dtype=torch.float32, device="cuda") |
| z = torch.randn(4, 1, 6, 8, dtype=torch.float32, device="cuda") |
| jit_o = t_jit(x, y, z) |
| jit_o = t_jit(x, y, z) |
| o = t(x, y, z) |
| self.assertEqual(o, jit_o) |
| subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, z)) |
| self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) |
| |
| @unittest.skipIf(True, "Broadcast with different output not supported yet") |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_broadcasting_multiple_output_shape(self): |
| def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): |
| o = x + 12 |
| o1 = o + y |
| o2 = o + z |
| oo = o1.sum() + o2.sum() |
| return oo |
| t_jit = torch.jit.script(t) |
| x = torch.randn(32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda") |
| z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, z) |
| jit_o = t_jit(x, y, z) |
| o = t(x, y, z) |
| self.assertEqual(o, jit_o) |
| # Currently cannot fuse this |
| self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) |
| |
| @unittest.skipIf(True, "broadcast on branches can't be resolved yet") |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_broadcasting_multiple_output(self): |
| def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): |
| o = x + 12 |
| o1 = o + y |
| o2 = o + z |
| oo = o1.sum() + o2.sum() |
| return oo |
| t_jit = torch.jit.script(t) |
| x = torch.randn(32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 32, 32, dtype=torch.float, device="cuda") |
| z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, z) |
| jit_o = t_jit(x, y, z) |
| o = t(x, y, z) |
| self.assertEqual(o, jit_o) |
| # Currently cannot fuse this |
| self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) |
| |
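|     # Helpers that wrap a single binary/unary `operation` in a small pointwise |
|     # function, script it, and check both numerical parity with eager mode and |
|     # the presence of a fusion guard in the optimized graph. |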
| def _binary_test_helper(self, operation): |
| def t(x: torch.Tensor, y: torch.Tensor, z: float): |
| o = x + z |
| o = operation(o, y) |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, 2.0) |
| jit_o = t_jit(x, y, 2.0) |
| o = t(x, y, 2.0) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) |
| |
| def _unary_test_helper(self, operation): |
| def t(x: torch.Tensor, z: float): |
| o = x + z |
| o = operation(o) |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, 2.0) |
| jit_o = t_jit(x, 2.0) |
| o = t(x, 2.0) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, 2.0), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_unary_ops(self): |
| operations = [torch.neg, |
| torch.abs, |
| torch.log, |
| torch.log10, |
| torch.log1p, |
| torch.log2, |
| torch.lgamma, |
| torch.exp, |
| torch.expm1, |
| torch.erf, |
| torch.erfc, |
| torch.cos, |
| torch.acos, |
| torch.cosh, |
| torch.sin, |
| torch.asin, |
| torch.tan, |
| torch.atan, |
| torch.sqrt, |
| torch.rsqrt, |
| torch.ceil, |
| torch.floor, |
| torch.round, |
| torch.trunc, |
| torch.frac, |
| torch.reciprocal, |
| torch.relu, |
| torch.sigmoid, |
| torch.tanh, |
| torch.nn.functional.gelu] |
| for op in operations: |
| self._unary_test_helper(op) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_binary_ops(self): |
| operations = [torch.div, |
| torch.mul, |
| torch.atan2, |
| torch.max, |
| torch.min, |
| torch.pow, |
| torch.remainder, |
| torch.fmod, |
| torch.eq, |
| torch.ne, |
| torch.ge, |
| torch.gt, |
| torch.le, |
| torch.lt] |
| for op in operations: |
| self._binary_test_helper(op) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| # legacy fuser does not work for rand_like, see issue #34361 |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") |
| def test_ternary_ops(self): |
| x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda") |
| |
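|         # each sub-case below pairs the op under test with rand_like; |
|         # _run_helper re-seeds the CUDA RNG so the scripted and eager runs see |
|         # the same random values |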
| def add(x: torch.Tensor, other: torch.Tensor, alpha: float): |
| o = torch.relu(x) |
| o = torch.add(o, other=other, alpha=alpha) |
| return o |
| add_jit = torch.jit.script(add) |
| self._run_helper(add_jit, add, x, y, 2.0) |
| |
| def clamp0(x: torch.Tensor, f: float): |
| o = torch.rand_like(x) |
| o = o * torch.clamp(x, min=f) |
| return o |
| clamp0_jit = torch.jit.script(clamp0) |
| self._run_helper(clamp0_jit, clamp0, x, 0.5) |
| |
| def clamp1(x: torch.Tensor, f: float, ff: float): |
| o = torch.rand_like(x) |
| o = o * torch.clamp(x, min=f, max=ff) |
| return o |
| clamp1_jit = torch.jit.script(clamp1) |
| self._run_helper(clamp1_jit, clamp1, x, -0.2, 0.7) |
| |
| def threshold(x: torch.Tensor, th: float, val: float): |
| o = torch.rand_like(x) |
| o = x * torch.threshold(o, th, val) |
| return o |
| threshold_jit = torch.jit.script(threshold) |
| self._run_helper(threshold_jit, threshold, x, 0.2, 0.9) |
| |
| def where(x: torch.Tensor, y: torch.Tensor, cond: torch.Tensor): |
| o = torch.rand_like(x) |
| o = o * torch.where(cond, x, y) |
| return o |
| where_jit = torch.jit.script(where) |
| self._run_helper(where_jit, where, x, y, cond) |
| |
| def lerp(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): |
| o = torch.rand_like(x) |
| o = o * torch.lerp(x, y, z) |
| return o |
| lerp_jit = torch.jit.script(lerp) |
| self._run_helper(lerp_jit, lerp, x, y, z) |
| |
| def lerp_scale(x: torch.Tensor, y: torch.Tensor, z: float): |
| o = torch.rand_like(x) |
| o = o * torch.lerp(x, y, z) |
| return o |
| lerp_scale_jit = torch.jit.script(lerp_scale) |
| self._run_helper(lerp_scale_jit, lerp_scale, x, y, 0.5) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser") |
| def test_addcmul_ops(self): |
| x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| |
| def addcmul(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, value: float): |
| o = torch.add(x, 0.5) |
| o = torch.addcmul(o, y, z, value=value) |
| return o |
| addcmul_jit = torch.jit.script(addcmul) |
| self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0) |
| |
| def addcmul_no_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): |
| o = torch.add(x, 0.5) |
| o = torch.addcmul(o, y, z) |
| return o |
| addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha) |
| self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z) |
| |
| def addcmul_const_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): |
| o = torch.add(x, 0.5) |
| o = torch.addcmul(o, y, z, value=0.75) |
| return o |
| addcmul_const_alpha_jit = torch.jit.script(addcmul_const_alpha) |
| self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, x, y, z) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_dynamic_size(self): |
|         # enable the fusion guard and raise the bailout depth so that changing |
|         # input sizes can re-profile and re-compile instead of failing outright |
|         old_guard = torch._C._jit_set_nvfuser_guard_mode(True) |
|         torch._C._jit_set_bailout_depth(20) |
| |
| def t(x: torch.Tensor, y: torch.Tensor, z: float): |
| o = x + y |
| o = o + z |
| return o |
| t_jit = torch.jit.script(t) |
| x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda") |
| y = torch.randn(32, 32, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, 2.0) |
| jit_o = t_jit(x, y, 2.0) |
| o = t(x, y, 2.0) |
| self.assertEqual(o, jit_o) |
| subgraph = self._getSubgraphInFusion(t_jit.graph_for(x, y, 2.0)) |
| self.assertGraphContainsExactly(subgraph, 'aten::add', 2, consider_subgraphs=False) |
| |
|         # this test is not ideal: it relies on a bailout being taken, and we |
|         # have no way to inspect the bailout graph to verify that the proper |
|         # fusion took place. |
| x = torch.randn(8, 32, 16, 8, dtype=torch.float, device="cuda") |
| y = torch.randn(16, 8, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, 2.0) |
| jit_o = t_jit(x, y, 2.0) |
| o = t(x, y, 2.0) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) |
| x = torch.randn(8, 17, 8, dtype=torch.float, device="cuda") |
| y = torch.randn(8, 17, 1, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x, y, 2.0) |
| jit_o = t_jit(x, y, 2.0) |
| o = t(x, y, 2.0) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GUARD) |
| torch._C._jit_set_nvfuser_guard_mode(old_guard) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| def test_random_topo(self): |
| os.environ["PYTORCH_CUDA_FUSER_DISABLE_FALLBACK"] = "1" |
| self.assertTrue(runDefaultTestWithSeed(28449)) |
| |
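|     # numpy-based comparison using `error` as both the relative and absolute |
|     # tolerance; on mismatch, print the offending elements to ease debugging |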
| def _compare(self, desc, inp1, inp2, error): |
| a = inp1.clone().detach().cpu().numpy() |
| b = inp2.clone().detach().cpu().numpy() |
| close = np.allclose(a, b, error, error) |
| if not close: |
| print(desc, close) |
| z = a - b |
| index = (np.abs(z) >= error + error * np.abs(b)).nonzero() |
| print("dif : ", z[index]) |
| print("inp1 : ", a[index]) |
| print("inp2 : ", b[index]) |
| return close |
| |
|     # Permutation helper that applies a binary operation between two tensors: |
|     #   1. applies separate permutations `perm0` & `perm1` to the two inputs |
|     #   2. reduces dimension `broadcast_axis` of the second operand to size 1 |
|     # The purpose is to ensure permutation works well in complicated cases with |
|     # arbitrary stride orders and broadcast dimensions. |
| def _permutation_helper(self, sizes, broadcast_axis, dtype, device, perm0, perm1): |
| def t(x: torch.Tensor, y: torch.Tensor): |
| o = torch.add(x, y) |
| o = torch.relu(o) |
| return o |
| |
| x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute([perm0.index(i) for i in range(len(sizes))]) |
| if broadcast_axis >= 0: |
| sizes[broadcast_axis] = 1 |
| y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute([perm1.index(i) for i in range(len(sizes))]) |
| t_jit = torch.jit.script(t) |
| jit_o = t_jit(x, y) |
| jit_o = t_jit(x, y) |
| o = t(x, y) |
| self.assertEqual(o.dtype, jit_o.dtype) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) |
| |
|     # End-to-end test of permutation & contiguity handling in the integration. |
|     # We test inputs with every combination of permutation orders to ensure the |
|     # integration can generate functionally correct kernels. |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_binary_ops_permutation(self): |
|         # iterate over every broadcast axis (-1 means no broadcasting) and |
|         # every permutation of the two inputs |
| x = [7, 8, 12] |
| b_axes = range(-1, len(x)) |
| for b_axis in b_axes: |
| for perm0 in itertools.permutations(range(len(x))): |
| for perm1 in itertools.permutations(range(len(x))): |
|                     # reset sizes: _permutation_helper mutates the list when a |
|                     # broadcast axis is requested |
|                     x = [7, 8, 12] |
| self._permutation_helper(x, b_axis, torch.float32, "cuda", perm0, perm1) |
| |
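|     # Reduction helper: permute both inputs by `perm0`/`perm1`, add them and |
|     # reduce over `reduction_axis`; checks dtype, (tolerance-based) numerical |
|     # parity and that a fusion guard was inserted. |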
| def _reduction_helper(self, sizes, reduction_axis, dtype, device, perm0, perm1): |
| class MyReduction(torch.nn.Module): |
| __constants__ = ['reduction_axis'] |
| |
| def __init__(self): |
| super(MyReduction, self).__init__() |
| self.reduction_axis = reduction_axis |
| |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| o = torch.add(x, y) |
| o = torch.sum(o, dim=self.reduction_axis) |
| return o |
| |
| t = MyReduction() |
| |
| x = torch.randn([sizes[i] for i in perm0], dtype=dtype, device=device).permute([perm0.index(i) for i in range(len(sizes))]) |
| y = torch.randn([sizes[i] for i in perm1], dtype=dtype, device=device).permute([perm1.index(i) for i in range(len(sizes))]) |
| t_jit = torch.jit.script(t) |
| jit_o = t_jit(x, y) |
| jit_o = t_jit(x, y) |
| o = t(x, y) |
| self.assertEqual(o.dtype, jit_o.dtype) |
| # numerical issues here due to our scheduling. |
| # can't use `self.assertEqual(o, jit_o)` |
| self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) |
| self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_reduction(self): |
| for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]): |
|             # note that num_reduce_dim stops short of len(x), so we never |
|             # reduce down to a single element (codegen limitation at this moment) |
| for num_reduce_dim in range(1, len(x)): |
| for axes in itertools.combinations(range(len(x)), num_reduce_dim): |
| perm0 = range(len(x)) |
| perm1 = range(len(x)) |
| self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_reduction_permutation(self): |
| x = [7, 8, 12] |
|         # note that num_reduce_dim stops short of len(x), so we never reduce |
|         # down to a single element (codegen limitation at this moment) |
| for num_reduce_dim in range(1, len(x)): |
| for axes in itertools.combinations(range(len(x)), num_reduce_dim): |
| for perm0 in itertools.permutations(range(len(x))): |
| for perm1 in itertools.permutations(range(len(x))): |
| self._reduction_helper(x, axes, torch.float32, "cuda", perm0, perm1) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_reduction_multiple_output(self): |
| old_guard = torch._C._jit_set_nvfuser_guard_mode(True) |
| torch._C._jit_set_bailout_depth(20) |
| |
| def t(x: torch.Tensor, y: torch.Tensor, scale: float, z: torch.Tensor): |
| o = torch.mul(x, y) |
| o = torch.mul(o, scale) |
| out1 = torch.mul(o, z) |
| out2 = torch.sum(out1, dim=[2]) |
| return out1, out2 |
| |
| t_jit = torch.jit.script(t) |
| x = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda") |
| y = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda") |
| z = torch.randn(8, 4, 10, 16, dtype=torch.float, device="cuda") |
| scale = 0.5 |
| jit_o = t_jit(x, y, scale, z) |
| jit_o = t_jit(x, y, scale, z) |
| o = t(x, y, scale, z) |
| for oo, jit_oo in zip(o, jit_o): |
| self.assertEqual(oo.dtype, jit_oo.dtype) |
| self.assertEqual(oo, jit_oo) |
| self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD) |
| |
| x = x.to(memory_format=torch.channels_last) |
| y = y.to(memory_format=torch.channels_last) |
| z = z.to(memory_format=torch.channels_last) |
| jit_o = t_jit(x, y, scale, z) |
| jit_o = t_jit(x, y, scale, z) |
| o = t(x, y, scale, z) |
| for oo, jit_oo in zip(o, jit_o): |
| self.assertEqual(oo.dtype, jit_oo.dtype) |
| self.assertEqual(oo, jit_oo) |
| self.assertGraphContains(t_jit.graph_for(x, y, scale, z), FUSION_GUARD) |
| torch._C._jit_set_nvfuser_guard_mode(old_guard) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_reduction_dtype(self): |
| def t(x: torch.Tensor): |
| o = torch.mul(x, 1.0) |
| o = torch.sum(o, dim=[2], dtype=torch.float32) |
| return o |
| t_jit = torch.jit.script(t) |
| |
| x = torch.randn(8, 4, 16, dtype=torch.float, device="cuda") |
| jit_o = t_jit(x) |
| jit_o = t_jit(x) |
| o = t(x) |
| self.assertEqual(o.dtype, jit_o.dtype) |
| self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) |
| self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_reduction_half(self): |
| def t(x: torch.Tensor): |
| o = torch.mul(x, 1.0) |
| o = torch.sum(o, dim=[2]) |
| return o |
| |
| t_jit = torch.jit.script(t) |
| x = torch.randn(8, 4, 16, dtype=torch.float16, device="cuda") |
| jit_o = t_jit(x) |
| jit_o = t_jit(x) |
| o = t(x) |
| self.assertEqual(o.dtype, jit_o.dtype) |
| self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) |
| self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_pw_single_reduction_partition(self): |
| sizes = [8, 8, 8] |
| dtype = torch.float |
| device = "cuda" |
| x = torch.randn(sizes, dtype=dtype, device=device) |
| y = torch.randn(sizes, dtype=dtype, device=device) |
| z = torch.randn(sizes, dtype=dtype, device=device) |
| |
| def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): |
| o = torch.add(x, y) |
| o = torch.sum(o, dim=[0]) |
| o = torch.add(o, z) |
| return o |
| t_jit = torch.jit.script(t) |
| jit_o = t_jit(x, y, z) |
| jit_o = t_jit(x, y, z) |
| o = t(x, y, z) |
| self.assertEqual(o.dtype, jit_o.dtype) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_single_reduction_broadcast(self): |
| dtype = torch.float |
| device = "cuda" |
| x = torch.randn([7, 4, 8], dtype=dtype, device=device) |
| y = torch.randn([4, 8], dtype=dtype, device=device) |
| z = torch.randn([1, 4, 8], dtype=dtype, device=device) |
| |
| def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): |
| o = torch.add(x, y) |
| o = torch.add(o, z) |
| o = torch.sum(o, dim=[0]) |
| return o |
| t_jit = torch.jit.script(t) |
| jit_o = t_jit(x, y, z) |
| jit_o = t_jit(x, y, z) |
| o = t(x, y, z) |
| self.assertEqual(o.dtype, jit_o.dtype) |
| self.assertEqual(o, jit_o) |
| self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GUARD) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_profiling_node(self): |
| dtype = torch.float |
| device = "cuda" |
| x = torch.randn(4, 8, 8, 8, dtype=dtype, device=device) |
| |
| def repro(x: torch.Tensor, alpha: float): |
| o = torch.rand_like(x) |
| o = torch.add(o, alpha) |
| return o |
| repro_jit = torch.jit.script(repro) |
| self._run_helper(repro_jit, repro, x, 0.6) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_reduction_sizes_op(self): |
| dtype = torch.float |
| device = "cuda" |
| x = torch.randn(2, 3, 4, 5, dtype=dtype, device=device) |
| y = torch.randn(2, 3, 4, 5, dtype=dtype, device=device) |
| |
| def t(x: torch.Tensor, y: torch.Tensor): |
| o = x + y |
| o = torch.relu(o) |
| o = o.sum((1, 3)) |
| return o.size() |
| t_jit = torch.jit.script(t) |
| jit_o = t_jit(x, y) |
| jit_o = t_jit(x, y) |
| o = t(x, y) |
| self.assertEqual(o, jit_o) |
|         # since only the output's sizes are used (never its values), the fusion |
|         # group should have been optimized away |
| self.assertGraphContainsExactly(t_jit.graph_for(x, y), FUSION_GUARD, 0) |
| |
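| # Tests for toggling the CUDA fuser through the pass manager, i.e. the |
| # `torch.jit.fuser` context manager and the explicit registration API. |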
| class TestPassManagerCudaFuser(JitTestCase): |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, |
| "Requires fusion optimization pass to be effective") |
| def test_context_manager_test(self): |
| x = torch.randn(4, 8, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 8, dtype=torch.float, device="cuda") |
| with torch.jit.fuser('fuser2'): |
| with torch.jit.fuser('fuser2'): |
| |
| def t1(x, y): |
| o = x + y |
| o = o + 2.0 |
| return o |
| t_jit = torch.jit.script(t1) |
| t_jit(x, y) |
| t_jit(x, y) |
| self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD) |
| |
| def t2(x, y): |
| o = x + y |
| o = o + 3.0 |
| return o |
| t_jit_2 = torch.jit.script(t2) |
| t_jit_2(x, y) |
| t_jit_2(x, y) |
| self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GUARD) |
| |
| def t3(x, y): |
| o = x + y |
| o = o + 4.0 |
| return o |
| t_jit_3 = torch.jit.script(t3) |
| t_jit_3(x, y) |
| t_jit_3(x, y) |
| self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GUARD, 0) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| def test_register_fuser(self): |
| self.assertFalse(torch._C._jit_set_nvfuser_enabled(True)) |
| self.assertTrue(torch._C._jit_nvfuser_enabled()) |
| self.assertTrue(torch._C._jit_set_nvfuser_enabled(True)) |
| self.assertTrue(torch._C._jit_nvfuser_enabled()) |
| self.assertTrue(torch._C._jit_set_nvfuser_enabled(False)) |
| self.assertFalse(torch._C._jit_nvfuser_enabled()) |
| |
| |
| if __name__ == '__main__': |
| run_tests() |