| # Owner(s): ["NNC"] |
| |
| import contextlib |
| import math |
| import operator |
| import os |
| import unittest |
| import warnings |
| from typing import List |
| |
| import torch |
| import torch.nn.functional as F |
| from torch.testing import FileCheck |
| |
| # These need to be set before `common_utils` |
| # infers `GRAPH_EXECUTOR`. |
| # This file **requires** these settings, |
| # and setting them after `GRAPH_EXECUTOR` is |
| # inferred erroneously runs or skips |
| # some tests. |
| torch._C._jit_set_profiling_executor(True) |
| torch._C._get_graph_executor_optimize(True) |
| |
| from itertools import combinations, permutations, product |
| |
| from textwrap import dedent |
| |
| from jit.test_fuser_common import TestFuserCommon # noqa: F401 |
| |
| from test_jit import ( |
| backward_graph, |
| get_lstm_inputs, |
| get_milstm_inputs, |
| LSTMCellC, |
| LSTMCellF, |
| LSTMCellS, |
| MiLSTMCell, |
| ) |
| |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| onlyCPU, |
| OpDTypes, |
| ops, |
| ) |
| from torch.testing._internal.common_jit import JitCommonTestCase |
| |
| from torch.testing._internal.common_methods_invocations import op_db |
| from torch.testing._internal.common_utils import ( |
| enable_profiling_mode_for_profiling_tests, |
| GRAPH_EXECUTOR, |
| IS_FBCODE, |
| ProfilingMode, |
| run_tests, |
| skipIfTorchDynamo, |
| slowTest, |
| TEST_WITH_ASAN, |
| TEST_WITH_ROCM, |
| ) |
| from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn |
| from torch.testing._internal.jit_utils import ( |
| clone_inputs, |
| get_traced_sample_variant_pairs, |
| JitTestCase, |
| NoTracerWarnContextManager, |
| RUN_CUDA, |
| RUN_CUDA_HALF, |
| RUN_CUDA_MULTI_GPU, |
| set_fusion_group_inlining, |
| TensorExprTestOptions, |
| warmup_backward, |
| ) |
| |
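| # Node kind the TE fuser produces; tests search optimized graphs for it. |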
| FUSION_GROUP = "prim::TensorExprGroup" |
| LLVM_ENABLED = torch._C._llvm_enabled() |
| |
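| # Node kinds emitted by autodiff's requires-grad guard checks; backward-graph |
| # fusion assertions tolerate these outside the fusion group. |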
| autograd_check_set = { |
| "aten::__is__", |
| "prim::AutogradAllNonZero", |
| "prim::AutogradAllZero", |
| "prim::ListConstruct", |
| } |
| |
| |
| def strip_profiling_nodes(nodes): |
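| """Drop profiling-executor bookkeeping (bailout) nodes from `nodes`.""" |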
| profiling_opcodes = {"prim::BailoutTemplate", "prim::BailOut"} |
| return [n for n in nodes if n.kind() not in profiling_opcodes] |
| |
| |
| def warmup_forward(f, *args, profiling_count=2): |
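| """Run `f` a few times so the profiling executor can specialize; return the last result.""" |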
| for _ in range(profiling_count): |
| results = f(*args) |
| |
| return results |
| |
| |
| @contextlib.contextmanager |
| def texpr_reductions_enabled(): |
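| """Context manager that temporarily lets the TE fuser fuse reductions.""" |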
| old = torch._C._jit_set_texpr_reductions_enabled(True) |
| try: |
| yield |
| finally: |
| torch._C._jit_set_texpr_reductions_enabled(old) |
| |
| |
| @contextlib.contextmanager |
| def texpr_enable_strategy(strategy): |
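| """Context manager that temporarily sets the JIT fusion strategy, e.g. [("STATIC", 20)].""" |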
| old = torch._C._jit_set_fusion_strategy(strategy) |
| try: |
| yield |
| finally: |
| torch._C._jit_set_fusion_strategy(old) |
| |
| |
| @contextlib.contextmanager |
| def inline_fusion_groups(): |
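| """Context manager that temporarily enables fusion-group inlining.""" |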
| old_inlining = torch._C._debug_get_fusion_group_inlining() |
| torch._C._debug_set_fusion_group_inlining(True) |
| try: |
| yield |
| finally: |
| torch._C._debug_set_fusion_group_inlining(old_inlining) |
| |
| |
| class TestTEFuser(JitTestCase): |
| def setUp(self): |
| super().setUp() |
| self.tensorexpr_options = TensorExprTestOptions() |
| |
| # note: `self.dynamic_shapes` is instantiated in the specialization |
| # of this class defined below |
| |
| fusion_strategy = [("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)] |
| self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(fusion_strategy) |
| |
| self.devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] |
| self.int_dtypes = [ |
| torch.int8, |
| torch.int16, |
| torch.int32, |
| torch.int64, |
| torch.bool, |
| ] |
| self.fp_dtypes = [ |
| torch.float16, |
| torch.float32, |
| torch.float64, |
| torch.bfloat16, |
| ] |
| self.dtypes = self.int_dtypes + self.fp_dtypes |
| |
| def tearDown(self): |
| self.tensorexpr_options.restore() |
| torch._C._jit_set_fusion_strategy(self.old_fusion_strategy) |
| super().tearDown() |
| |
| def assertAllFused(self, graph, except_for=None): |
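| """Assert that, apart from constants, `except_for` kinds, and a single |
| guard plus its guarded `prim::If`, nothing in `graph` is left unfused.""" |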
| except_for = except_for if except_for is not None else set() |
| # TODO - upstream |
| guards = ( |
| "prim::TypeCheck", |
| "prim::RequiresGradCheck", |
| "prim::TensorExprDynamicGuard", |
| ) |
| guard_found = False |
| |
| def autodiff_guard(node): |
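| """Return True if `node` is the `aten::all(prim::ListConstruct(...))` |
| pattern that autodiff emits as a requires-grad guard.""" |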
| if node.kind() != "aten::all": |
| return False |
| inps = list(node.inputs()) |
| if len(inps) != 1 or inps[0].node().kind() != "prim::ListConstruct": |
| return False |
| li_inps = list(inps[0].node().inputs()) |
| for li_inp in li_inps: |
| if li_inp.node().kind() in ( |
| "prim::AutogradAllNonZero", |
| "prim::AutogradAllZero", |
| ): |
| return True |
| return False |
| |
| def is_guard(node): |
| return node.kind() in guards or autodiff_guard(node) |
| |
| for node in graph.block().nodes(): |
| if node.kind() == "prim::Constant": |
| continue |
| if is_guard(node): |
| self.assertFalse(guard_found) |
| guard_found = True |
| continue |
| if node.kind() in except_for: |
| continue |
| if node.kind() == "prim::If": |
| self.assertTrue(is_guard(node.prev())) |
| continue |
| self.fail("Found unexpected node: " + node.kind()) |
| |
| self.assertTrue(guard_found) |
| |
| def assertLastGraphAllFused(self): |
| self.assertAllFused(torch.jit.last_executed_optimized_graph()) |
| |
| def findFusionGroups(self, graph): |
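| """Recursively collect the subgraphs of all TensorExprGroup nodes.""" |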
| result = [] |
| for n in graph.nodes(): |
| if n.kind() == FUSION_GROUP: |
| result.append(n.g("Subgraph")) |
| continue |
| for block in n.blocks(): |
| result += self.findFusionGroups(block) |
| return result |
| |
| def test_typecheck(self): |
| a = torch.ones(1) |
| |
| def fused_kernel(a, b): |
| return (a + b) * 2.0 |
| |
| scripted = self.checkScript(fused_kernel, (a, a)) |
| graph = scripted.graph_for(a, a) |
| # double check we fused |
| fusion_groups = self.findFusionGroups(graph) |
| self.assertEqual(len(fusion_groups), 1) |
| # Now use a bigger tensor (size 2). If the type check fails to |
| # trigger a recompilation, the cached size-1 kernel would run and |
| # we would silently compute the wrong result. |
| a = torch.ones(2) |
| self.assertEqual(scripted(a, a), fused_kernel(a, a)) |
| |
| def test_sum_simple(self): |
| def func(x): |
| x2 = x * x |
| return x2.sum() |
| |
| with texpr_reductions_enabled(): |
| a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") |
| a = a.reshape(5, 3) |
| scripted = self.checkScript(func, (a,)) |
| self.assertLastGraphAllFused() |
| |
| def test_nop(self): |
| pass |
| |
| def test_sum_dim(self): |
| def func(x): |
| return x.sum((0,)) * 2 |
| |
| def func_neg(x): |
| return x.sum((-2,)) * 2 |
| |
| with texpr_reductions_enabled(): |
| a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") |
| a = a.reshape(5, 3) |
| scripted = self.checkScript(func, (a,)) |
| self.assertLastGraphAllFused() |
| scripted = self.checkScript(func_neg, (a,)) |
| self.assertLastGraphAllFused() |
| |
| def test_sum_keepdim_cast(self): |
| def func(x): |
| return x.sum((0,), keepdim=True, dtype=torch.double) * 2 |
| |
| with texpr_reductions_enabled(): |
| a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu") |
| a = a.reshape(5, 3) |
| |
| self.checkScript(func, (a,)) |
| self.assertLastGraphAllFused() |
| |
| def test_abs(self): |
| for device in self.devices: |
| |
| def func(x): |
| return x.abs() * 2 |
| |
| a = torch.randn(5, device=device) |
| scripted = self.checkScript(func, (a,)) |
| self.assertLastGraphAllFused() |
| |
| def test_unsqueeze_size_calculation(self): |
| for device in self.devices: |
| |
| def foo(b, d): |
| x = d.unsqueeze(1) |
| y = x * 42.0 |
| z = b + y |
| r = z / 42.0 |
| return r |
| |
| inputs = ( |
| torch.rand(20, 28, device=device, requires_grad=True), |
| torch.rand(20, device=device), |
| ) |
| scripted = self.checkScript(foo, inputs) |
| self.assertAllFused(scripted.graph_for(*inputs)) |
| |
| def test_zero_element_tensors(self): |
| for device in self.devices: |
| |
| def decode(sin_t, cos_t): |
| theta = torch.atan2(sin_t.float(), cos_t.float()) |
| return theta |
| |
| sin = torch.zeros(0, device=device) |
| cos = torch.zeros(0, device=device) |
| inputs = [sin, cos] |
| ge = self.checkScript(decode, inputs) |
| |
| def test_arg_configurations_smoke(self): |
| if self.dynamic_shapes: |
| self.skipTest("TODO: chunk dynamic shapes") |
| |
| # A smoke test to make sure we won't use the same kernel for contiguous |
| # and non-contiguous arguments. |
| # TODO: add optionally enabled debug counters to the fuser to verify |
| # that we really can tell the difference between configurations |
| for device in self.devices: |
| |
| def f(x, y): |
| z1, z2 = (x + y).chunk(2, dim=1) |
| return z1 * z2 |
| |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| traced_f = torch.jit.trace( |
| f, |
| ( |
| x, |
| y, |
| ), |
| ) |
| self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) |
| |
| def test_broadcast(self): |
| for device in self.devices: |
| |
| def scaleshift(x, scale, shift): |
| return x * scale + shift |
| |
| inputs = [ |
| torch.randn(4, 4, dtype=torch.float, device=device), |
| torch.randn(4, dtype=torch.float, device=device), |
| torch.randn(4, dtype=torch.float, device=device), |
| ] |
| self.checkScript(scaleshift, inputs) |
| |
| @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") |
| @unittest.skipIf(not RUN_CUDA_HALF, "no half support") |
| @unittest.skipIf( |
| GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on" |
| ) |
| def test_cuda_half(self): |
| x = torch.randn(4, 4, dtype=torch.half, device="cuda") |
| y = torch.randn(4, 4, dtype=torch.half, device="cuda") |
| |
| funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp] |
| |
| # Note: non-fused inputs must be float to prevent loss of precision |
| inputs = (x.float(), y.float()) |
| fusion_inputs = (x, y) |
| for fn in funcs: |
| local_inputs = [t.clone().requires_grad_() for t in inputs] |
| local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs] |
| |
| # Verifies outputs |
| fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False) |
| outputs = fn(*local_inputs) |
| fusion_outputs = fusion(*local_fusion_inputs) |
| outputs_half = [t.half() for t in outputs] |
| self.assertEqual(outputs_half, fusion_outputs) |
| |
| # Verifies gradients |
| for output, fusion_output in zip(outputs_half, fusion_outputs): |
| grads = torch.autograd.grad( |
| output.float().sum(), |
| local_inputs, |
| allow_unused=True, |
| retain_graph=True, |
| ) |
| fusion_grads = torch.autograd.grad( |
| fusion_output.sum(), |
| local_fusion_inputs, |
| allow_unused=True, |
| retain_graph=True, |
| ) |
| grads_half = [t.half() for t in grads] |
| self.assertEqual(grads_half, fusion_grads) |
| |
| def test_checks_cat_inputs(self): |
| # Inlining must be enabled: a lone fusion node otherwise causes an error. |
| with set_fusion_group_inlining(True): |
| for device in self.devices: |
| # We shouldn't treat cat nodes as broadcasting. All their inputs |
| # need to be checked for having the same map size, before we can |
| # run the kernel. |
| def f(x, y): |
| return torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0) |
| |
| # NOTE: y is broadcastable to x, but output of f(x, y) should have |
| # shape 3x4, and not 4x4. |
| x = torch.randn(2, 4, dtype=torch.float, device=device) |
| y = torch.randn(1, 4, dtype=torch.float, device=device) |
| |
| scripted = self.checkScript(f, (x, y)) |
| self.assertEqual(scripted(x, y).shape, (3, 4)) |
| self.assertAllFused(scripted.graph_for(x, y)) |
| |
| def test_chunk(self): |
| if self.dynamic_shapes: |
| self.skipTest("TODO: chunk dynamic shapes") |
| |
| for device in self.devices: |
| |
| def fn(x): |
| a, b, c = x.chunk(3, 1) |
| return a * b + c |
| |
| inputs = [torch.randn(10, 6, dtype=torch.float, device=device)] |
| |
| self.checkScript(fn, inputs) |
| self.assertLastGraphAllFused() |
| |
| def test_chunk_correctness(self): |
| if self.dynamic_shapes: |
| self.skipTest("TODO: chunk dynamic shapes") |
| |
| for device in self.devices: |
| |
| def chunk_4_0(x): |
| x0, x1, x2, x3 = x.chunk(4, 0) |
| return x0 + x1 + x2 + x3 |
| |
| def chunk_4_1(x): |
| x0, x1, x2, x3 = x.chunk(4, 1) |
| return x0 + x1 + x2 + x3 |
| |
| def chunk_4_last(x): |
| x0, x1, x2, x3 = x.chunk(4, 2) |
| return x0 + x1 + x2 + x3 |
| |
| fns = [chunk_4_0, chunk_4_1, chunk_4_last] |
| tensors = [ |
| # splitSize = 1 |
| torch.randn(4, 4, 4, dtype=torch.float, device=device), |
| # contiguous case |
| torch.randn(12, 8, 16, dtype=torch.float, device=device), |
| # non-contiguous case |
| torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose( |
| 1, 2 |
| ), |
| ] |
| |
| for tensor in tensors: |
| for fn in fns: |
| self.checkScript(fn, [tensor]) |
| self.assertLastGraphAllFused() |
| |
| def test_chunk_distributes(self): |
| if self.dynamic_shapes: |
| self.skipTest("TODO: chunk dynamic shapes") |
| |
| for device in self.devices: |
| |
| def f(x, y): |
| z1, z2 = (x + y).chunk(2, dim=1) |
| return z1 * z2 |
| |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| |
| ge = self.checkTrace(f, (x, y)) |
| graph = ge.graph_for(x, y) |
| # XXX: The old fuser does broadcast_tensors but the new fuser doesn't. |
| # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \ |
| # .check_count('ConstantChunk', 2, exactly=True).run(str(graph)) |
| FileCheck().check("with " + FUSION_GROUP + "_").check_count( |
| "ConstantChunk", 1, exactly=True |
| ).run(str(graph)) |
| |
| def test_chunk_motion_deduplicates_inputs(self): |
| if self.dynamic_shapes: |
| self.skipTest("TODO: chunk dynamic shapes") |
| |
| for device in self.devices: |
| |
| def func1(x): |
| z = x * x |
| z0, z1 = z.chunk(2) |
| return z0 * z1 |
| |
| def func2(x): |
| z = x * x * x |
| z0, z1 = z.chunk(2) |
| return z0 * z1 |
| |
| inputs = [ |
| torch.tensor([1.1, 1.2], device=device, dtype=torch.float), |
| ] |
| for func in [func1, func2]: |
| self.checkScript(func, inputs) |
| self.assertLastGraphAllFused() |
| |
| def test_chunk_multiple(self): |
| if self.dynamic_shapes: |
| self.skipTest("TODO: chunk dynamic shapes") |
| |
| for device in self.devices: |
| # The arguments are intentionally used out of order as a test to see |
| # if the fusion compiler adds extra args in the correct order |
| def fn(s, x, y, z): |
| z1, z2 = z.chunk(2, 2) |
| x1, x2, x3 = x.chunk(3, 1) |
| y1, y2 = y.chunk(2, 0) |
| return s + x1 + x2 + x3 + y1 + y2 + z1 + z2 |
| |
| inputs = [ |
| torch.randn(5, 2, 3, dtype=torch.float, device=device), |
| torch.randn(5, 6, 3, dtype=torch.float, device=device), |
| torch.randn(10, 2, 3, dtype=torch.float, device=device), |
| torch.randn(5, 2, 6, dtype=torch.float, device=device), |
| ] |
| |
| ge = self.checkScript(fn, inputs) |
| self.assertAllFused(ge.graph_for(*inputs)) |
| |
| def test_minmax(self): |
| # No outer device loop is needed here: the product below already |
| # iterates over self.devices and moves the inputs to each device. |
| |
| def tmax(a, b): |
| return torch.max(2 * a, b) |
| |
| def tmin(a, b): |
| return torch.min(2 * a, b) |
| |
| a = torch.randn(4, 4, dtype=torch.float) |
| b = torch.randn(4, 4, dtype=torch.float) |
| nan = torch.tensor(float("nan"), dtype=torch.float) |
| |
| for f, inputs, device in product( |
| (tmax, tmin), ([a, b], [a, nan], [b, nan]), self.devices |
| ): |
| inputs = [t.to(device) for t in inputs] |
| s = self.checkScript(f, inputs) |
| self.assertAllFused(s.graph_for(*inputs)) |
| |
| def test_clamp(self): |
| for device in self.devices: |
| |
| def func2(a, b): |
| return torch.clamp(a + b, min=0, max=2) |
| |
| def funcInf(a, b): |
| return torch.clamp(a + b, min=0, max=float("inf")) |
| |
| def funcNegInf(a, b): |
| return torch.clamp(a + b, min=float("-inf"), max=0) |
| |
| def funcOptMin(a, b): |
| return torch.clamp(a + b, max=2) |
| |
| def funcOptMax(a, b): |
| return torch.clamp(a + b, min=0) |
| |
| a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True) |
| b = torch.randn(4, 4, dtype=torch.float, device=device) |
| nan = torch.tensor(float("nan"), dtype=torch.float, device=device) |
| |
| funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax) |
| for f, inputs in product(funcs, [[a, b], [a, nan]]): |
| inp1, inp2 = inputs |
| s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING) |
| self.assertAllFused( |
| s.graph_for(inp1, inp2), |
| except_for={"aten::size", "aten::_size_if_not_equal"}, |
| ) |
| c = s(inp1, inp2) |
| with enable_profiling_mode_for_profiling_tests(): |
| warmup_backward(c.sum()) |
| graph = backward_graph(s) |
| self.assertAllFused( |
| graph, |
| except_for={"aten::Float", "aten::_grad_sum_to_size"}.union( |
| autograd_check_set |
| ), |
| ) |
| |
| def test_clamp_double(self): |
| for device in self.devices: |
| |
| def clamp_double(x, eta: float): |
| return 1 - x.clamp(eta, 1 - eta) |
| |
| x = torch.tensor([1.0, 1.0], dtype=torch.double, device=device) |
| eta = 1e-9 |
| s = self.checkScript( |
| clamp_double, |
| (x, eta), |
| profiling=ProfilingMode.PROFILING, |
| atol=1e-10, |
| rtol=1e-5, |
| ) |
| self.assertAllFused(s.graph_for(x, eta), except_for={"aten::sub"}) |
| |
| def test_clamp_int(self): |
| for device in self.devices: |
| |
| def clamp_int(x, eta: int): |
| return x.clamp(0, eta) |
| |
| x = torch.tensor([1, 1], device=device) |
| eta = 1 << 32 |
| s = self.checkScript(clamp_int, (x, eta), profiling=ProfilingMode.PROFILING) |
| self.assertAllFused(s.graph_for(x, eta)) |
| |
| def test_add_bool(self): |
| sizes = [(1,), (2,), (4, 4)] |
| for device, size in product(self.devices, sizes): |
| |
| def f(x, y, z): |
| return x + y + z |
| |
| x = torch.randint(0, 2, size, dtype=torch.bool, device=device) |
| y = torch.randint(0, 2, size, dtype=torch.bool, device=device) |
| z = torch.randint(0, 2, size, dtype=torch.bool, device=device) |
| ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) |
| self.assertAllFused(ge.graph_for(x, y, z)) |
| |
| def test_mul_bool(self): |
| for device in self.devices: |
| |
| def f(x, y, z): |
| return x * y * z |
| |
| x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) |
| y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) |
| z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) |
| |
| ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) |
| self.assertAllFused(ge.graph_for(x, y, z)) |
| |
| def test_div_bool(self): |
| for device in self.devices: |
| |
| def f(x, y, z): |
| return (x + y) / z |
| |
| x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) |
| y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device) |
| z = torch.ones_like(x, dtype=torch.bool, device=device) |
| |
| ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False) |
| self.assertAllFused(ge.graph_for(x, y, z)) |
| |
| def test_bitwise_ops(self): |
| def apply(fn): |
| return lambda x, y, z: fn(fn(x, y), z) |
| |
| binary_ops = [ |
| operator.__and__, |
| operator.__or__, |
| operator.__xor__, |
| operator.__lshift__, |
| operator.__rshift__, |
| ] |
| devices = self.devices |
| for dtype, op, device in product(self.int_dtypes, binary_ops, devices): |
| try: |
| x = self.data_for(dtype, device) |
| y = self.data_for(dtype, device) |
| z = self.data_for(dtype, device) |
| fn = apply(op) |
| ref = fn(x, y, z) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x, y, z)) |
| self.assertEqual(ref, t(x, y, z)) |
| self.assertAllFused(t.graph_for(x, y, z)) |
| except Exception as e: |
| raise RuntimeError( |
| " ".join(["Failed:", str(dtype), op.__name__, device]) |
| ) from e |
| |
| def test_minmax_int_ops(self): |
| def apply(fn): |
| return lambda x, y, z: fn(fn(x, y), z) |
| |
| binary_ops = [torch.min, torch.max] |
| devices = self.devices |
| for dtype, op, device in product(self.int_dtypes, binary_ops, devices): |
| try: |
| x = self.data_for(dtype, device) |
| y = self.data_for(dtype, device) |
| z = self.data_for(dtype, device) |
| fn = apply(op) |
| ref = fn(x, y, z) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x, y, z)) |
| self.assertEqual(ref, t(x, y, z)) |
| self.assertAllFused(t.graph_for(x, y, z)) |
| except Exception as e: |
| raise RuntimeError( |
| " ".join(["Failed:", str(dtype), op.__name__, device]) |
| ) from e |
| |
| def test_comparison_eq_ne(self): |
| for device in self.devices: |
| |
| def f(x, y): |
| mask = (x == 0).type_as(x) |
| z = x * mask + y |
| mask = (x != 0).type_as(x) |
| z = z * mask + y |
| return z |
| |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| |
| ge = self.checkTrace(f, (x, y)) |
| self.assertAllFused(ge.graph_for(x, y)) |
| |
| @staticmethod |
| def fn_test_comparison_gt_lt(x, y): |
| mask = (x > 0).type_as(x) |
| z = x * mask + y |
| mask = (x < 0).type_as(x) |
| z = z * mask + y |
| return z |
| |
| def test_comparison_gt_lt(self): |
| for device in self.devices: |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| |
| ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y)) |
| self.assertAllFused(ge.graph_for(x, y)) |
| |
| def test_comparison_ge_le(self): |
| for device in self.devices: |
| |
| def f(x, y): |
| mask = (x >= 0).type_as(x) |
| z = x * mask + y |
| mask = (x <= 0).type_as(x) |
| z = z * mask + y |
| return z |
| |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| |
| ge = self.checkTrace(f, (x, y)) |
| self.assertAllFused(ge.graph_for(x, y)) |
| x.requires_grad_(True) |
| y.requires_grad_(True) |
| self.assertAllFused( |
| ge.graph_for(x, y), |
| except_for=( |
| "aten::size", |
| "prim::BroadcastSizes", |
| "aten::_size_if_not_equal", |
| ), |
| ) |
| |
| def test_addcmul(self): |
| for device in self.devices: |
| t = torch.randn(1, 4, dtype=torch.float, device=device) |
| t1 = torch.randn(4, 1, dtype=torch.float, device=device) |
| t2 = torch.randn(1, 4, dtype=torch.float, device=device) |
| |
| def foo(t, t1, t2): |
| return t.addcmul(t + 1, t2, value=0.1) |
| |
| ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True) |
| graph = ge.graph_for(t, t1, t2) |
| fusion_groups = self.findFusionGroups(graph) |
| self.assertEqual(len(fusion_groups), 1) |
| FileCheck().check("aten::add(").check("aten::addcmul(").run( |
| str(fusion_groups[0]) |
| ) |
| |
| # TODO: We leak CUDA memory here because the traced graph holds onto a |
| # constant-ified tensor. Since the Python-global CompilationUnit is alive |
| # until the end of the process, the memory is effectively leaked. |
| # We removed the `_cuda` suffix from this test name, which disables |
| # leak-checking. If this is a real problem, we'll need to revisit |
| # TorchScript Function lifetimes in Python. |
| def test_lerp(self): |
| for device in self.devices: |
| start = torch.randn(4, 1, dtype=torch.float, device=device) |
| end = torch.randn(1, 4, dtype=torch.float, device=device) |
| weight = torch.tensor(0.5, dtype=torch.float, device=device) |
| |
| # scalar weight overload |
| def foo_weight_scalar(start, end): |
| return torch.lerp(start + 1, end, 0.5) |
| |
| # tensor weight overload |
| def foo_weight_tensor(start, end): |
| return torch.lerp(start + 1, end, weight) |
| |
| ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end)) |
| graph = ge_weight_scalar.graph_for(start, end) |
| self.assertAllFused(graph) |
| |
| # TODO: uncomment when TE enables support for scalar tensors |
| # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end)) |
| # graph = ge_weight_tensor.graph_for(start, end) |
| # self.assertAllFused(graph) |
| |
| def test_concat(self): |
| # Inlining must be enabled: a single concat node otherwise causes an error. |
| with set_fusion_group_inlining(True): |
| for device in self.devices: |
| hx = torch.randn(3, 20, dtype=torch.float, device=device) |
| cx = torch.randn(3, 20, dtype=torch.float, device=device) |
| |
| def foo(hx, cx): |
| return torch.cat((hx + cx, hx * cx)) |
| |
| ge = self.checkTrace(foo, (hx, cx)) |
| graph = ge.graph_for(hx, cx) |
| self.assertAllFused(graph) |
| # XXX: TE fuser can handle concats in a fusion group. |
| # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) |
| |
| def test_remove_output_used_only_in_size(self): |
| for device in self.devices: |
| |
| def test_fuse(a, b): |
| c = a + b |
| d = c + b |
| return d |
| |
| scripted_f = torch.jit.script(test_fuse) |
| x = torch.ones(1, requires_grad=True, device=device) |
| y = torch.ones(1, requires_grad=True, device=device) |
| warmup_forward(scripted_f, x, y, profiling_count=3) |
| g = scripted_f.graph_for(x, y) |
| diff_nodes = g.findAllNodes("prim::DifferentiableGraph") |
| self.assertEqual(len(diff_nodes), 1) |
| g = diff_nodes[0].g("Subgraph") |
| if_nodes = [n for n in g.nodes() if n.kind() == "prim::If"] |
| self.assertEqual(len(if_nodes), 1) |
| |
| # the if node and the fusion group inside it should only have one output |
| self.assertEqual(len(list(if_nodes[0].outputs())), 1) |
| |
| def test_concat_invariant(self): |
| for device in self.devices: |
| # Invariant: the output of prim::FusedConcat may |
| # not be an input to any node inside the FusionGroup. |
| def fn(x, y, z): |
| x1 = x + y |
| y1 = x - y |
| w = torch.cat([x1, y1]) |
| return w + z |
| |
| x = torch.randn(2, 2, dtype=torch.float, device=device) |
| y = torch.randn(2, 2, dtype=torch.float, device=device) |
| z = torch.randn(4, 2, dtype=torch.float, device=device) |
| ge = self.checkTrace(fn, (x, y, z)) |
| graph = ge.graph_for(x, y, z) |
| self.assertAllFused(graph, except_for={"aten::add"}) |
| # XXX: TE fuser can handle concats inside a fusion group. |
| # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) |
| |
| @staticmethod |
| def fn_test_exp(x, y): |
| return (x + 0.5 * y).exp() |
| |
| def test_exp(self): |
| for device in self.devices: |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| |
| ge = self.checkTrace(self.fn_test_exp, (x, y)) |
| self.assertAllFused(ge.graph_for(x, y)) |
| |
| def test_threshold(self): |
| for device in self.devices: |
| |
| def f(x): |
| return torch.threshold(x, 0, -10) + x + x + x |
| |
| x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device=device) |
| scripted = self.checkScript(f, (x,)) |
| self.assertAllFused(scripted.graph_for(x)) |
| |
| def test_scalar_arg(self): |
| for device in self.devices: |
| |
| def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor: |
| return p * (x * x + x) |
| |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| p = 3 |
| scripted = self.checkScript(fn_test_scalar_arg, (x, p)) |
| self.assertAllFused(scripted.graph_for(x, p)) |
| |
| x.requires_grad_(True) |
| |
| # Use another function, otherwise we will bail out |
| # and won't be able to do fused checks. |
| def fn_test_scalar_arg_requires_grad( |
| x: torch.Tensor, p: float |
| ) -> torch.Tensor: |
| return p * (x * x + x) |
| |
| scripted = torch.jit.script(fn_test_scalar_arg_requires_grad) |
| out = scripted(x, p) |
| out = scripted(x, p) |
| out = scripted(x, p) |
| self.assertAllFused( |
| scripted.graph_for(x, p), |
| except_for=( |
| "aten::size", |
| "prim::BroadcastSizes", |
| "aten::_size_if_not_equal", |
| ), |
| ) |
| |
| @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") |
| @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") |
| def test_fusion_reuse_multi_gpu(self): |
| def fn(x, y): |
| return x * y * x * y |
| |
| inputs_cpu = [ |
| torch.randn(4, 4, dtype=torch.float), |
| torch.randn(4, 4, dtype=torch.float), |
| ] |
| inputs_cuda0 = [x.cuda(0) for x in inputs_cpu] |
| inputs_cuda1 = [y.cuda(1) for y in inputs_cpu] |
| |
| # Should not crash; these should compile different kernels. |
| ge = self.checkScript(fn, inputs_cpu) |
| self.assertAllFused(ge.graph_for(*inputs_cpu)) |
| ge(*inputs_cuda0) |
| ge(*inputs_cuda1) |
| |
| # TODO: we're currently not checking 'device' in the type info when pulling |
| # nodes into a fusion group. We should fix that and re-enable this test. |
| @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") |
| @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") |
| def test_kernel_cache_multi_gpu(self): |
| def not_fusible(x): |
| return x |
| |
| def fn(x, y, z): |
| x_out = x * x * x * x * x # fusion: lambda x. x * x * x * x * x |
| y_out = y * y * y * y * y |
| z_out = z * z * z * z * z |
| return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out) |
| |
| inputs = [ |
| torch.randn(4, 4, dtype=torch.float), |
| torch.randn(4, 4, dtype=torch.float, device="cuda:0"), |
| torch.randn(4, 4, dtype=torch.float, device="cuda:1"), |
| ] |
| |
| prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() |
| |
| # There are 3 FusionGroups. Because they have the same graph, they |
| # should reuse the same KernelSpec in the KernelSpec cache. |
| ge = self.checkScript(fn, inputs) |
| self.assertGraphContainsExactly(ge.graph_for(*inputs), FUSION_GROUP, 3, True) |
| new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs() |
| # XXX: This assumes that the same kernel isn't already used by another test |
| # FIXME: Use the TE fuser's way of querying the cache. |
| # self.assertEqual(new_cache_size - prev_cache_size, 1) |
| |
| @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") |
| def test_nonzero_device_cuda(self): |
| device = "cuda:1" |
| x = torch.tensor([0.4], dtype=torch.float, device=device) |
| y = torch.tensor([0.7], dtype=torch.float, device=device) |
| |
| def doit(x, y): |
| return torch.sigmoid(torch.tanh(x * (x + y) + x)) |
| |
| ge = self.checkTrace(doit, (x, y)) |
| self.assertAllFused(ge.graph_for(x, y)) |
| |
| def test_lstm(self): |
| for device in self.devices: |
| inputs = get_lstm_inputs(device, training=True) |
| module = self.checkScript(LSTMCellS, inputs) |
| self.assertAllFused( |
| module.graph_for(inputs), except_for={"prim::TupleConstruct"} |
| ) |
| |
| def test_lstm_concat(self): |
| # Inlining must be enabled: a lone fusion node otherwise causes an error. |
| with set_fusion_group_inlining(True): |
| for device in self.devices: |
| inputs = get_lstm_inputs(device) |
| ge = self.checkTrace(LSTMCellC, inputs) |
| graph = ge.graph_for(*inputs) |
| except_nodes = {"prim::TupleConstruct", "aten::linear"} |
| # TODO... Chunk |
| if self.dynamic_shapes: |
| except_nodes = except_nodes.union( |
| {"aten::add", "prim::ConstantChunk"} |
| ) |
| self.assertAllFused(ge.graph_for(*inputs), except_for=except_nodes) |
| # XXX: TE fuser can handle concats inside a fusion group. |
| # FileCheck().check("FusedConcat").check_next("return").run(str(graph)) |
| |
| def test_lstm_gates_permutations(self): |
| for device in self.devices: |
| # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. |
| # Test that any permutation of this will still result in one FusionGroup. |
| choices = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"] |
| template = dedent( |
| """ |
| def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): |
| gates = {} + {} + {} + {} |
| ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) |
| return ingate * forgetgate * cellgate * outgate |
| """ |
| ) |
| for permutation in permutations(choices, len(choices)): |
| code = template.format(*permutation) |
| scope = {} |
| exec(code, globals(), scope) |
| cu = torch.jit.CompilationUnit(code) |
| fusion_group_len = 2 if self.dynamic_shapes else 1 |
| inputs = get_lstm_inputs(device, training=False) |
| self.assertEqual(cu.cell(*inputs), scope["cell"](*inputs)) |
| forward_graph = cu.cell.graph_for(*inputs) |
| self.assertGraphContainsExactly( |
| forward_graph, FUSION_GROUP, fusion_group_len |
| ) |
| |
| # TODO: Fuser doesn't work at all when inputs require grad. Fix that |
| def test_lstm_traced(self): |
| for device in self.devices: |
| inputs = get_lstm_inputs(device) |
| ge = self.checkTrace(LSTMCellF, inputs) |
| graph = ge.graph_for(*inputs) |
| fusion_groups = self.findFusionGroups(graph) |
| # TODO: chunk |
| fusion_group_len = 2 if self.dynamic_shapes else 1 |
| self.assertEqual(len(fusion_groups), fusion_group_len) |
| f = FileCheck() |
| if not self.dynamic_shapes: |
| f.check("Chunk") |
| f.check("aten::sigmoid").check("aten::tanh").run( |
| str(fusion_groups[0 if not self.dynamic_shapes else 1]) |
| ) |
| |
| def test_milstm(self): |
| if self.dynamic_shapes: |
| self.skipTest("don't run milstm with dynamic shapes") |
| |
| for device in self.devices: |
| inputs = get_milstm_inputs(device, training=True) |
| module = self.checkScript(MiLSTMCell, inputs) |
| forward_graph = module.graph_for(*inputs) |
| # TODO: chunk |
| fusion_group_len = 2 if self.dynamic_shapes else 1 |
| self.assertGraphContainsExactly( |
| forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True |
| ) |
| FileCheck().check("DifferentiableGraph").check("TupleConstruct").check_next( |
| "return" |
| ).check(FUSION_GROUP).run(str(forward_graph)) |
| hy, cy = module(*inputs) |
| warmup_backward((hy + cy).sum()) |
| |
| @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") |
| @unittest.skip("rand_like is not supported yet") |
| def test_rand_cuda(self): |
| class M(torch.jit.ScriptModule): |
| __constants__ = ["d"] |
| |
| def __init__(self): |
| super().__init__() |
| self.d = torch.device("cuda") |
| |
| @torch.jit.script_method |
| def create(self, x): |
| return x * x + x + torch.rand_like(x) |
| |
| x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda") |
| m = M() |
| out1 = m.create(x) |
| out2 = m.create(x) |
| self.assertNotEqual(out1, out2) |
| self.assertTrue(torch.all(out1 >= 0)) |
| self.assertTrue(torch.all(out1 < 1)) |
| self.assertTrue(torch.all(out2 >= 0)) |
| self.assertTrue(torch.all(out2 < 1)) |
| self.assertAllFused(m.create.graph_for(x)) |
| |
| @staticmethod |
| def fn_test_relu(x, y): |
| return F.relu(x + 0.5 * y) |
| |
| def test_relu(self): |
| for device in self.devices: |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| |
| ge = self.checkTrace(self.fn_test_relu, (x, y)) |
| self.assertAllFused(ge.graph_for(x, y)) |
| |
| def test_erf(self): |
| for device in self.devices: |
| # erf/erfc are only fused on GPU |
| if device == "cpu": |
| continue |
| |
| def fn_test_erf(x): |
| return F.relu(torch.erf(x) - torch.erfc(x)) |
| |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING) |
| self.assertAllFused(ge.graph_for(x)) |
| x.requires_grad_(True) |
| ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING) |
| self.assertAllFused( |
| ge.graph_for(x), |
| except_for=( |
| "aten::size", |
| "prim::BroadcastSizes", |
| "aten::_size_if_not_equal", |
| ), |
| ) |
| |
| @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") |
| @unittest.skip("rand_like is not supported yet") |
| def test_rand_broadcast_cuda(self): |
| def fn_test_rand(x, y): |
| r = torch.rand_like(y) |
| return r * x + x |
| |
| # If using profiling, a different function is needed to test different |
| # shapes, or we'll use a cached script. |
| def fn_test_rand2(x, y): |
| r = torch.rand_like(y) |
| return r * x * x |
| |
| x = torch.randn(4, 4, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 4, dtype=torch.float, device="cuda") |
| script_f = torch.jit.script(fn_test_rand) |
| warmup_forward(script_f, x, y) |
| out = script_f(x, y) |
| self.assertAllFused(script_f.graph_for(x, y)) |
| x.requires_grad_(True) |
| out = script_f(x, y) |
| self.assertAllFused( |
| script_f.graph_for(x, y), |
| except_for=( |
| "aten::size", |
| "prim::BroadcastSizes", |
| "aten::_size_if_not_equal", |
| ), |
| ) |
| |
| # test that broadcasting random produces correct results |
| x = torch.ones(4, 4, dtype=torch.float, device="cuda") |
| y = torch.ones(4, dtype=torch.float, device="cuda") |
| script_f = torch.jit.script(fn_test_rand2) |
| warmup_forward(script_f, x, y) |
| out = script_f(x, y) |
| self.assertEqual(out[0, :] + torch.zeros(4, 4, device="cuda"), out) |
| |
| @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") |
| @unittest.skip("rand_like is not supported yet") |
| def test_rand_diamond(self): |
| def fn_test_diamond(x, y): |
| r = torch.rand_like(y) |
| a = x + r |
| b = y - r |
| return a + b |
| |
| x = torch.randn(4, 4, dtype=torch.float, device="cuda") |
| y = torch.randn(4, 4, dtype=torch.float, device="cuda") |
| script_f = torch.jit.script(fn_test_diamond) |
| warmup_forward(script_f, x, y) |
| out = script_f(x, y) |
| self.assertEqual(out, x + y) |
| |
| def test_scalar(self): |
| def fn(x, y): |
| return 2 * x + y |
| |
| x = torch.tensor(0.1, dtype=torch.float, device="cpu") |
| y = torch.tensor(1, dtype=torch.float, device="cpu") |
| ge = self.checkScript(fn, (x, y)) |
| self.assertAllFused(ge.graph_for(x, y)) |
| |
| def test_inlined_optimized_graph(self): |
| @torch.jit.script |
| def foo(x): |
| return torch.relu(x + x) |
| |
| for _ in range(3): |
| foo(torch.rand([4, 4])) |
| |
| for _ in range(3): |
| foo(torch.rand([10])) |
| |
| for _ in range(3): |
| foo(torch.rand([2, 2, 2])) |
| |
| g = torch.jit.last_executed_optimized_graph() |
| |
| FileCheck().check_count("prim::If", 1, exactly=True).check( |
| "prim::TensorExpr" |
| ).run(g) |
| torch._C._jit_pass_inline(g) |
| f = FileCheck() |
| for _ in range(3): |
| f.check("prim::If").check("prim::TensorExpr") |
| f.run(g) |
| |
| def test_small_constant(self): |
| for device in self.devices: |
| |
| def fn_test_small_constant(x, y): |
| return (1e-8 * x + 5e-9 * y) * 1e8 |
| |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| |
| ge = self.checkTrace(fn_test_small_constant, (x, y)) |
| self.assertAllFused(ge.graph_for(x, y)) |
| |
| # Currently we don't pull constants into fusion groups, because in some |
| # cases it could remove the constant from the original graph and now our |
| # fusion group needs to return that constant for its other users. |
| # Instead of never pulling constants into the fusion group, we should just |
| # be more careful at how we rewrite its users. |
| # TODO: fix that and reenable the test. |
| def test_tensor_scalar_ops(self): |
| for device in self.devices: |
| |
| def should_fuse(x): |
| z = 3.0 |
| y = x + z |
| return x * y |
| |
| def should_fuse_scalar(x, z): |
| y = x + int(z) |
| return x * y |
| |
| inputs = [torch.randn(2, 2, dtype=torch.float, device=device)] |
| ge = self.checkScript(should_fuse, inputs) |
| graph = ge.graph_for(*inputs) |
| fusion_groups = self.findFusionGroups(graph) |
| self.assertEqual(len(fusion_groups), 1) |
| FileCheck().check("aten::add").check("aten::mul").run(str(fusion_groups[0])) |
| |
| inputs = [ |
| torch.randn(2, 2, dtype=torch.float, device=device), |
| torch.tensor(3.0, dtype=torch.float, device=device), |
| ] |
| ge = self.checkScript(should_fuse_scalar, inputs) |
| # Check that the fused graph computes correct results when the scalar |
| # input changes. |
| inputs = [ |
| torch.randn(2, 2, dtype=torch.float, device=device), |
| torch.tensor(7.0, dtype=torch.float, device=device), |
| ] |
| self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs)) |
| # The TE fuser supports fusion of non-constant scalars |
| self.assertGraphContainsExactly( |
| ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True |
| ) |
| |
| def test_where_and_typing(self): |
| for device in self.devices: |
| |
| def f(x, y): |
| mask = x > y |
| res = torch.where(mask, x, y) |
| return mask, res |
| |
| x = torch.randn(4, 4, dtype=torch.double, device=device) |
| y = torch.randn(4, 4, dtype=torch.double, device=device) |
| |
| script_f = self.checkScript(f, (x, y)) |
| self.assertAllFused( |
| script_f.graph_for(x, y), except_for={"prim::TupleConstruct"} |
| ) |
| |
| def test_disabled(self): |
| old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() |
| torch._C._jit_override_can_fuse_on_cpu(False) |
| |
| def fn(a): |
| return a**2 + a |
| |
| x = torch.randn(4, dtype=torch.float, device="cpu") |
| s = self.checkScript(fn, (x,)) |
| g = s.graph_for(x) |
| self.assertEqual(len(self.findFusionGroups(g)), 0) |
| |
| torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state) |
| |
| def data_for(self, dtype, device="cuda", size=None): |
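| """Make a small tensor of `dtype` on `device`; bool and quantized dtypes |
| get special handling.""" |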
| if size is None: |
| v = torch.arange(1, 3, dtype=torch.float, device=device) |
| else: |
| v = torch.rand(*size, device=device) |
| if dtype == torch.bool: |
| return v > 2 |
| elif dtype in [torch.qint8, torch.quint8, torch.qint32]: |
| return torch.quantize_per_tensor(v, 0.1, 1, dtype=dtype) |
| else: |
| return v.to(dtype) |
| |
| def test_torch_to(self): |
| # Test that a no-op cast is not fused |
| @torch.jit.script |
| def foo(x): |
| return x.to(torch.float) |
| |
| foo(torch.tensor([3.0], dtype=torch.float)) |
| foo(torch.tensor([3.0], dtype=torch.float)) |
| FileCheck().check_not("TensorExpr").run( |
| torch.jit.last_executed_optimized_graph() |
| ) |
| |
| # test not fusing non-const inputs |
| @torch.jit.script |
| def foo(x, dtype: int): |
| return x.to(dtype) |
| |
| foo(torch.tensor([3.0], dtype=torch.float), torch.int) |
| foo(torch.tensor([3.0], dtype=torch.float), torch.int) |
| FileCheck().check_not("TensorExpr").run( |
| torch.jit.last_executed_optimized_graph() |
| ) |
| |
| # test not fusing `to(pin_memory=True)` inputs |
| @torch.jit.script |
| def foo(x, dtype: int): |
| return x.to(pin_memory=True) |
| |
| foo(torch.tensor([3.0], dtype=torch.float), torch.int) |
| foo(torch.tensor([3.0], dtype=torch.float), torch.int) |
| FileCheck().check_not("TensorExpr").run( |
| torch.jit.last_executed_optimized_graph() |
| ) |
| |
| # test that cross-device `to` is not fused |
| if torch.cuda.is_available(): |
| |
| @torch.jit.script |
| def foo(x): |
| return x.to(device="cuda") |
| |
| foo(torch.tensor([3.0], dtype=torch.float)) |
| foo(torch.tensor([3.0], dtype=torch.float)) |
| FileCheck().check_not("TensorExpr").run( |
| torch.jit.last_executed_optimized_graph() |
| ) |
| |
| sizes = [(1, 4), (4, 4)] |
| # reuses cast impl, smaller dtype set for faster test |
| dtypes = [ |
| torch.bool, |
| torch.int, |
| torch.float16, |
| torch.float32, |
| torch.float64, |
| ] |
| |
| class MyMod(torch.nn.Module): |
| def __init__(self, dtype): |
| super().__init__() |
| self.dtype = dtype |
| |
| def forward(self, x): |
| return x.to(self.dtype) |
| |
| bad_dtypes = [] |
| for dtype, output_dtype, device, size in product( |
| dtypes, dtypes, self.devices, sizes |
| ): |
| # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| if dtype == output_dtype: |
| continue |
| |
| x = self.data_for(dtype, device, size=size) |
| mod = MyMod(output_dtype) |
| ref = mod.forward(x) |
| # use freezing to make non-Tensor args to `to` constant |
| mod = torch.jit.freeze(torch.jit.script(mod.eval())) |
| warmup_forward(mod.forward, x) |
| self.assertEqual(ref, mod.forward(x)) |
| self.assertLastGraphAllFused() |
| |
| @unittest.skip("Temporarily disabled") |
| def test_masked_fill(self): |
| dtypes = [ |
| torch.int8, |
| torch.int16, |
| torch.int32, |
| torch.int64, |
| # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed |
| # torch.float16, |
| torch.float32, |
| torch.float64, |
| torch.bool, |
| ] |
| sizes = [(2,), (4, 4)] |
| for self_dtype, device, scalar_val, size in product( |
| dtypes, self.devices, [0.4, 3], sizes |
| ): |
| input_v = self.data_for(self_dtype, device, size=size) |
| mask = self.data_for(torch.bool, device, size=size) |
| |
| def fn(input_v, mask): |
| return torch.masked_fill(input_v, mask, scalar_val) |
| |
| ref = fn(input_v, mask) |
| try: |
| t = torch.jit.trace(fn, (input_v, mask)) |
| torch.testing.assert_close(ref, t(input_v, mask)) |
| self.assertLastGraphAllFused() |
| except Exception as e: |
| raise RuntimeError( |
| " ".join( |
| [ |
| "Failed:", |
| str(self_dtype), |
| "masked_fill",  # `op` is undefined here; name the op directly |
| device, |
| str(size), |
| ] |
| ) |
| ) from e |
| |
| def test_isnan(self): |
| x = torch.rand([4]) |
| x[0] = float("nan") |
| inputs = [x, torch.tensor([float("nan"), 0.5])] |
| dtypes = [ |
| torch.int8, |
| torch.int16, |
| torch.int32, |
| torch.int64, |
| torch.float16, |
| torch.float32, |
| torch.float64, |
| torch.bool, |
| ] |
| |
| for inp, device, dtype in product(inputs, self.devices, dtypes): |
| # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| inp = inp.to(device=device, dtype=dtype) |
| try: |
| f = torch.jit.trace(lambda x: x.isnan(), (inp,)) |
| warmup_forward(f, inp) |
| self.assertEqual(f(inp), inp.isnan()) |
| self.assertLastGraphAllFused() |
| except Exception as e: |
| raise RuntimeError( |
| " ".join(["Failed:", str(dtype), "isnan", device]) |
| ) from e |
| |
| def test_gelu(self): |
| def apply(fn): |
| return lambda x, approximate: fn(x, approximate) |
| |
| unary_ops = [ |
| F.gelu, |
| ] |
| sizes = [(1,), (2,), (4, 4)] |
| for dtype, op, device, size in product( |
| self.dtypes, unary_ops, self.devices, sizes |
| ): |
| # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| x = self.data_for(dtype, device, size=size) |
| cond = self.data_for(torch.bool, device) |
| fn = apply(op) |
| ref = fn(x, cond) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x, cond)) |
| torch.testing.assert_close(ref, t(x, cond)) |
| self.assertAllFused(t.graph_for(x, cond)) |
| except Exception as e: |
| raise RuntimeError( |
| " ".join(["Failed:", str(dtype), op.__name__, device, str(size)]) |
| ) from e |
| |
| def test_unary_ops(self): |
| with torch._jit_internal._disable_emit_hooks(): |
| |
| def apply(fn): |
| return lambda x: fn(x) |
| |
| unary_ops = [ |
| torch.lgamma, |
| torch.sigmoid, |
| torch.reciprocal, |
| torch.neg, |
| torch.relu, |
| F.relu6, |
| torch.log, |
| torch.log10, |
| torch.log1p, |
| torch.log2, |
| torch.exp, |
| torch.expm1, |
| torch.erf, |
| torch.erfc, |
| torch.cos, |
| torch.sin, |
| torch.tan, |
| torch.acos, |
| torch.asin, |
| torch.cosh, |
| torch.sinh, |
| torch.atan, |
| torch.tanh, |
| F.hardtanh, |
| F.hardsigmoid, |
| F.hardswish, |
| F.softplus, |
| F.silu, |
| F.mish, |
| F.elu, |
| torch.sqrt, |
| torch.rsqrt, |
| torch.abs, |
| # TODO broken on int8 since |
| # https://github.com/pytorch/pytorch/pull/85144 |
| # RuntimeError: Invalid integral op_type: 23 |
| # torch.ceil, |
| # torch.floor, |
| # torch.round, |
| # torch.trunc, |
| torch.frac, |
| # TODO: broken on ROCm? |
| # F.hardshrink, |
| F.leaky_relu, |
| lambda x: torch.threshold(x, 0, -10), |
| # TODO: broken since type promotion was added |
| # lambda x: torch.clamp(x, -10, 10), |
| ] |
| gpu_only = {torch.erf, torch.erfc} |
| sizes = [(1,), (2,), (4, 4)] |
| for dtype, op, device, size in product( |
| self.dtypes, unary_ops, self.devices, sizes |
| ): |
| # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| # TODO: re-enable; bfloat16 round fails for values at .500 |
| if dtype == torch.bfloat16 and op == torch.round: |
| continue |
| if op in gpu_only and device == "cpu": |
| continue |
| try: |
| x = self.data_for(dtype, device, size=size) |
| fn = apply(op) |
| ref = fn(x) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x,)) |
| torch.testing.assert_close(ref, t(x)) |
| self.assertAllFused(t.graph_for(x)) |
| except Exception as e: |
| raise RuntimeError( |
| " ".join( |
| ["Failed:", str(dtype), op.__name__, device, str(size)] |
| ) |
| ) from e |
| |
| def test_binary_ops(self): |
| def apply(fn): |
| return lambda x, y: fn(x, y) |
| |
| binary_ops = [ |
| operator.__and__, |
| operator.__or__, |
| operator.__xor__, |
| torch.add, |
| torch.sub, |
| torch.mul, |
| torch.min, |
| torch.max, |
| lambda x, y: torch.lerp(x, y, 0.5), |
| torch.atan2, |
| torch.div, |
| torch.eq, |
| torch.ne, |
| torch.ge, |
| torch.gt, |
| torch.lt, |
| torch.fmod, |
| torch.remainder, |
| lambda x, y: y.type_as(x), |
| ] |
| fp_only = [ |
| torch.fmod, |
| torch.remainder, |
| ] |
| devices = self.devices |
| for dtype, op, device in product(self.dtypes, binary_ops, devices): |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| x = self.data_for(dtype, device) |
| y = self.data_for(dtype, device) |
| fn = apply(op) |
| ref = fn(x, y) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x, y)) |
| self.assertEqual(ref, t(x, y)) |
| if op not in fp_only or dtype.is_floating_point: |
| self.assertAllFused(t.graph_for(x, y)) |
| except Exception as e: |
| raise RuntimeError( |
| " ".join(["Failed:", str(dtype), op.__name__, device]) |
| ) from e |
| |
| def test_binary_scalar_ops(self): |
| def apply(fn): |
| return lambda x, y: fn(x, y) |
| |
| ir_template = """ |
| graph(%x : {dtype_x}, %y : {dtype_y}): |
| %z = {op}(%x, %y) |
| return (%z)""" |
| |
| binary_ops = [ |
| "aten::mul", |
| "aten::add", |
| "aten::sub", |
| "aten::div", |
| "aten::lt", |
| "aten::le", |
| "aten::eq", |
| "aten::ne", |
| "aten::gt", |
| "aten::ge", |
| "aten::__or__", |
| "aten::__xor__", |
| "aten::__and__", |
| "aten::__lshift__", |
| "aten::__rshift__", |
| ] |
| dtypes = ["int", "float", "bool"] |
| values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]} |
| devices = self.devices |
| for dtype_x, dtype_y, op, device in product( |
| dtypes, dtypes, binary_ops, devices |
| ): |
| code = ir_template.format(**locals()) |
| |
| # Interpret the graph |
| try: |
| graph = torch._C.parse_ir(code) |
| for x, y in product(values[dtype_x], values[dtype_y]): |
| ref = torch._C._jit_interpret_graph(graph, (x, y)) |
| except Exception: |
| # If we can't interpret this IR, don't bother checking NNC. |
| continue |
| |
| # Compile the graph |
| try: |
| k = torch._C._te.TensorExprKernel(graph) |
| except Exception as e: |
| raise RuntimeError( |
| " ".join(["Compilation failed:", device, str(code)]) |
| ) from e |
| |
| # Run the graph |
| for x, y in product(values[dtype_x], values[dtype_y]): |
| ref = torch._C._jit_interpret_graph(graph, (x, y)) |
| try: |
| res = k.run((x, y)) |
| self.assertEqual(ref, res) |
| except Exception as e: |
| raise RuntimeError( |
| " ".join( |
| ["Failed at runtime:", device, str(x), str(y), str(code)] |
| ) |
| ) from e |
| |
| def test_matmul(self): |
| if self.dynamic_shapes: |
| self.skipTest("don't run matmul with dynamic shapes") |
| |
| def fn(x, y): |
| return torch.matmul(x, y) |
| |
| devices = ["cpu"] # No cuda support for ext calls yet |
| sizes = [ |
| [[128, 128], [128, 128]], |
| [[10, 10], [10, 10]], |
| [[1, 16], [16, 128]], |
| [[128], [128]], |
| [[128], [128, 128]], |
| [[3], [3]], |
| [[3, 4], [4]], |
| [[10, 3, 4], [4]], |
| [[10, 3, 4], [10, 4, 5]], |
| [[10, 3, 4], [4, 5]], |
| ] |
| |
| # Only 2D x 2D matrix multiply is supported. For non-supported sizes we |
| # still want to run results verification to test that we didn't |
| # accidentally fuse it, but we skip the 'is-fused' check. |
| # TODO: add support for other shape combinations and make this set empty: |
| skip_is_fused_check_sizes = [ |
| "[[128], [128]]", |
| "[[128], [128, 128]]", |
| "[[3], [3]]", |
| "[[3, 4], [4]]", |
| "[[10, 3, 4], [4]]", |
| "[[10, 3, 4], [10, 4, 5]]", |
| "[[10, 3, 4], [4, 5]]", |
| ] |
| for dtype, size, device in product(self.dtypes, sizes, devices): |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| size_x, size_y = size |
| x = self.data_for(dtype, device, size=size_x) |
| y = self.data_for(dtype, device, size=size_y) |
| ref = fn(x, y) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x, y)) |
| t(x, y) |
| self.assertEqual(ref, t(x, y)) |
| if str(size) not in skip_is_fused_check_sizes: |
| self.assertAllFused(t.graph_for(x, y)) |
| except Exception as e: |
| raise RuntimeError(" ".join(["Failed:", str(dtype), device])) from e |
| |
| def test_binary_tensor_scalar_ops(self): |
| with torch._jit_internal._disable_emit_hooks(): |
| |
| def apply_with_scalar(fn, scalar): |
| return lambda x: fn(x, scalar) |
| |
| # FIXME: Fails in IR Eval: torch.int64 and_ cpu |
| binary_ops = [ |
| operator.__and__, |
| operator.__or__, |
| operator.__xor__, |
| torch.add, |
| torch.sub, |
| torch.mul, |
| torch.eq, |
| torch.ne, |
| torch.ge, |
| torch.lt, |
| torch.gt, |
| ] |
| devices = self.devices |
| # Maybe we should split this into separate tests to speed it up by |
| # only using scalar values relevant to particular ops |
| scalars = [1.5, 3, 0, -2.0, -1] |
| for dtype, op, device, scalar in product( |
| self.dtypes, binary_ops, devices, scalars |
| ): |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| x = self.data_for(dtype, device) |
| fn = apply_with_scalar(op, scalar) |
| ref = fn(x) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x,)) |
| self.assertEqual(ref, t(x)) |
| self.assertAllFused(t.graph_for(x)) |
| except Exception as e: |
| raise RuntimeError( |
| " ".join(["Failed:", str(dtype), op.__name__, device]) |
| ) from e |
| |
| def test_binary_div_ops(self): |
| def apply_with_scalar(fn, scalar): |
| return lambda x: fn(x, scalar) |
| |
| binary_ops = [ |
| torch.div, |
| torch.remainder, |
| torch.fmod, |
| ] |
| devices = self.devices |
| # Maybe we should split this into separate tests to speed it up by |
| # only using scalar values relevant to particular ops |
| scalars = [1.5, 3, -2.0, -1] # skip 0 |
| for dtype, op, device, scalar in product( |
| self.dtypes, binary_ops, devices, scalars |
| ): |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| x = self.data_for(dtype, device) |
| fn = apply_with_scalar(op, scalar) |
| ref = fn(x) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x,)) |
| self.assertEqual(ref, t(x)) |
| except Exception as e: |
| raise RuntimeError( |
| f"Failed: {dtype} {op.__name__} {device} {scalar}" |
| ) from e |
| |
| def test_binary_pow(self): |
| def apply_with_scalar(fn, scalar): |
| return lambda x: fn(x, scalar) |
| |
| dtypes = [ |
| # FIXME: 'pow' fails with dtype=torch.float16/device=cuda/scalar=0 |
| # torch.float16, |
| torch.float32, |
| torch.float64, |
| # torch.bool intentionally not included |
| ] |
| binary_ops = [ |
| torch.pow, |
| ] |
| # Maybe we should split this into separate tests to speed it up by |
| # only using scalar values relevant to particular ops |
| scalars = [1.5, 3, 0, -2.0, -1] |
| for dtype, op, device, scalar in product( |
| dtypes, binary_ops, self.devices, scalars |
| ): |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| x = self.data_for(dtype, device) |
| fn = apply_with_scalar(op, scalar) |
| ref = fn(x) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
                t = torch.jit.trace(fn, (x,))
| self.assertEqual(ref, t(x)) |
| self.assertAllFused(t.graph_for(x)) |
| except Exception as e: |
                raise RuntimeError(
                    f"Failed: {dtype} {op.__name__} {device} {scalar}"
                ) from e
| |
| def test_ternary_ops(self): |
| def apply(fn): |
| return lambda x, y, z: fn(x, y, z) |
| |
| ternary_ops = [ |
| torch.lerp, |
| torch.addcmul, |
| ] |
| devices = self.devices |
| for dtype, op, device in product(self.dtypes, ternary_ops, devices): |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| x = self.data_for(dtype, device) |
| y = self.data_for(dtype, device) |
| z = self.data_for(dtype, device) |
| fn = apply(op) |
| ref = fn(x, y, z) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x, y, z)) |
| self.assertEqual(ref, t(x, y, z)) |
| self.assertAllFused(t.graph_for(x, y, z)) |
| except Exception as e: |
                raise RuntimeError(f"Failed: {dtype} {op.__name__} {device}") from e
| |
| def test_ternary_norm_ops(self): |
| def apply(fn): |
| return lambda x, y, z: fn(x, y, z) |
| |
| ternary_ops = [ |
| F.batch_norm, |
| ] |
| devices = self.devices |
| for dtype, op, device in product(self.dtypes, ternary_ops, devices): |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| x = self.data_for(dtype, device, size=[5, 3, 128, 128]) |
| y = self.data_for(dtype, device, size=[3]) |
| z = self.data_for(dtype, device, size=[3]) |
| fn = apply(op) |
| ref = fn(x, y, z) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x, y, z)) |
| self.assertEqual(ref, t(x, y, z)) |
| self.assertAllFused(t.graph_for(x, y, z)) |
| except Exception as e: |
                raise RuntimeError(f"Failed: {dtype} {op.__name__} {device}") from e
| |
| @unittest.skip( |
| "FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure" |
| ) |
| def test_list_ops(self): |
| def apply(fn): |
| return lambda x, y, z: fn([x * x, y * y, z * z]) |
| |
| devices = self.devices |
| list_ops = [ |
| torch.cat, |
| ] |
| for dtype, op, device in product(self.dtypes, list_ops, devices): |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| x = self.data_for(dtype, device, size=[5, 4, 1, 7]) |
| y = self.data_for(dtype, device, size=[5, 4, 1, 7]) |
| z = self.data_for(dtype, device, size=[5, 4, 1, 7]) |
| fn = apply(op) |
| ref = fn(x, y, z) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (x, y, z)) |
| self.assertEqual(ref, t(x, y, z)) |
| self.assertAllFused(t.graph_for(x, y, z)) |
| except Exception as e: |
                raise RuntimeError(f"Failed: {dtype} {op.__name__} {device}") from e
| |
| def test_where_ops(self): |
| def apply(fn): |
| return lambda cond, x, y: fn(cond, x, y) |
| |
| ops = [ |
| torch.where, |
| lambda cond, x, y: torch.where(cond, x, 3.1415), |
| lambda cond, x, y: torch.where(cond, 42, y), |
| ] |
| devices = self.devices |
| for dtype, op, device in product(self.dtypes, ops, devices): |
| if dtype in [torch.float16, torch.bfloat16] and device == "cpu": |
| continue |
| try: |
| cond = self.data_for(torch.bool, device) |
| x = self.data_for(dtype, device) |
| y = self.data_for(dtype, device) |
| fn = apply(op) |
| ref = fn(cond, x, y) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| try: |
| t = torch.jit.trace(fn, (cond, x, y)) |
| self.assertEqual(ref, t(cond, x, y)) |
| self.assertAllFused(t.graph_for(cond, x, y)) |
| except Exception as e: |
                raise RuntimeError(f"Failed: {dtype} {op.__name__} {device}") from e
| |
| def test_unsupported_dtypes(self): |
| for device in self.devices: |
| |
| def fn(x): |
| return x * x + x |
| |
| unsupported_dtypes = [ |
| torch.uint8, |
| torch.complex32, |
| torch.complex64, |
| torch.complex128, |
| torch.qint8, |
| torch.quint8, |
| torch.qint32, |
| ] |
| for dtype in unsupported_dtypes: |
| try: |
| x = self.data_for(dtype, device) |
| ref = fn(x) |
| except Exception: |
| # If eager mode doesn't support a dtype/op/device combo, |
| # neither does the fuser. Catch everything to avoid needing to |
| # guess what errors might be thrown by eager. |
| continue |
| t = torch.jit.trace(fn, (x,)) |
| self.assertEqual(ref, t(x)) |
| self.assertEqual(len(self.findFusionGroups(t.graph_for(x))), 0) |
| |
| def test_superslomo(self): |
| devices = self.devices.copy() |
| if not LLVM_ENABLED: |
| devices.remove("cpu") |
| for device in devices: |
| # Test extracted from Super-SloMo: https://github.com/avinashpaliwal/Super-SloMo |
| # A few interesting things happen here: strided inputs of mixed size, |
| # plus outputs of mixed shapes. The latter characteristic happened to |
| # expose a memory corruption bug due to not properly guarding the |
| # outputs. |
| def eager(t0, t1, t2, t3, t4): |
| t5 = torch.mul(t0, t4) |
| t6 = torch.mul(t2, t3) |
| t7 = torch.mul(t6, t1) |
| t9 = torch.add(t5, t7) |
| t11 = torch.add(t0, t6) |
| ft_p = torch.div(t9, t11) |
| return (ft_p, t11, t9, t6) |
| |
| t0 = torch.rand(1, 6, 352, 352, device=device).transpose(0, 1) |
| t1 = torch.rand(6, 3, 352, 352, device=device) |
| t2 = torch.rand(6, device=device)[None, None, None, :].permute(3, 0, 1, 2) |
| t3 = torch.rand(6, 1, 352, 352, device=device) |
| t4 = torch.rand(6, 3, 352, 352, device=device) |
| inputs = [t0, t1, t2, t3, t4] |
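            # t0 and t2 are intentionally non-contiguous (transpose/permute),
            # so the fusion group sees strided inputs of mixed layout.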
| |
| script = torch.jit.script(eager) |
| for _ in range(4): |
                for test, ref in zip(script(*inputs), eager(*inputs)):
                    torch.testing.assert_close(test, ref)
| self.assertAllFused( |
| script.graph_for(*inputs), except_for={"prim::TupleConstruct"} |
| ) |
| |
| def test_sub_gt_and(self): |
| for device in self.devices: |
| |
| def eager(t1, t2, t3, t4, t: float): |
| w = t1 - t2 |
| h = t3 - t4 |
| k = (w > t) & (h > t) |
| assert k.dtype == torch.bool |
| if t > 0.5: |
| # Putting a use of k in a never-executed conditional prevents |
| # profiling its type, which leaves it as "Tensor". If we |
| # propagate Tensor back to the definition of k, we have to be |
| # careful not to create a fusion group containing it. |
| return k + 1 |
| return w |
| |
| t = torch.rand(8, dtype=torch.float, device=device) |
| scripted = self.checkScript(eager, (t, t, t, t, 0.1)) |
| |
| @skipIfTorchDynamo("too slow") |
| def test_chunk_mul_one(self): |
| if self.dynamic_shapes: |
| self.skipTest("TODO: chunk dynamic shapes") |
| |
| for device in self.devices: |
| |
| def eager(x): |
| z, y, w = torch.chunk(x, 3, -1) |
| return z * 3, y, w |
| |
| x = torch.rand(64, 1, 3072, dtype=torch.float, device=device) |
| z, y, w = eager(x) |
| script = self.checkScript(eager, (x,)) |
| |
| def test_eq_unsqueeze_type_as(self): |
| for device in self.devices: |
| |
| def eager(a, b): |
| mask = b == 1 |
| mask = torch.unsqueeze(mask, -1) |
| x = mask.type_as(a) |
| return x, mask |
| |
| a = torch.rand(1, 64, 1024, device=device, dtype=torch.float) |
| b = torch.randint(-2, 2, (1, 64), device=device, dtype=torch.long) |
| script = self.checkScript(eager, (a, b)) |
| |
| def test_neg_pow(self): |
| def eager_tt(a: torch.Tensor, b: torch.Tensor): |
| return torch.neg(torch.pow(a, b)) |
| |
| def eager_ts(a: torch.Tensor, b: float): |
| return torch.neg(torch.pow(a, b)) |
| |
| def eager_st(a: float, b: torch.Tensor): |
| return torch.neg(torch.pow(a, b)) |
| |
| a = torch.rand(1, dtype=torch.float) |
| b = torch.rand(1, dtype=torch.float) |
| s = b.item() |
| script = self.checkScript(eager_tt, (a, b)) |
| # TODO: re-enable fusion, which doesn't work right now. just test correctness for now |
| # self.assertAllFused(script.graph_for(a, b)) |
| script = self.checkScript(eager_ts, (a, s)) |
| # self.assertAllFused(script.graph_for(a, s)) |
| script = self.checkScript(eager_st, (s, b)) |
| # self.assertAllFused(script.graph_for(s, b)) |
| |
| @unittest.skipIf(not LLVM_ENABLED, "Too slow to run with the TE interpreter") |
| def test_conv2d_depthwise(self): |
| if self.dynamic_shapes: |
| self.skipTest("don't run conv with dynamic shapes") |
| |
| def eager(input, weight, bias): |
| return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=72) |
| |
| input = torch.rand((1, 72, 56, 56), dtype=torch.float) |
| weight = torch.rand((72, 1, 3, 3), dtype=torch.float) |
        bias = torch.rand(72, dtype=torch.float)
| |
| script = self.checkScript(eager, (input, weight, bias)) |
| self.assertAllFused(script.graph_for(input, weight, bias)) |
| |
| def test_conv2d(self): |
| if self.dynamic_shapes: |
| self.skipTest("don't run conv with dynamic shapes") |
| |
| def eager(input, weight, bias): |
| return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=1) |
| |
| input = torch.rand((1, 64, 56, 56), dtype=torch.float) |
| weight = torch.rand((64, 64, 3, 3), dtype=torch.float) |
        bias = torch.rand(64, dtype=torch.float)
| |
| script = self.checkScript(eager, (input, weight, bias)) |
| FileCheck().check_not("TensorExpr").run( |
| torch.jit.last_executed_optimized_graph() |
| ) |
| |
| def test_type_as_cat(self): |
| with inline_fusion_groups(): |
| |
| def eager(x, y): |
| return torch.cat((x, y.type_as(x)), dim=1) |
| |
| dtypes = self.dtypes.copy() |
            # The CPU fuser doesn't support float16 or bfloat16.
            dtypes.remove(torch.float16)
            dtypes.remove(torch.bfloat16)
| for dtype1, dtype2 in product(dtypes, dtypes): |
                x = torch.randint(2, (1, 13)).to(dtype1)
| zero = torch.tensor([[0]]).to(dtype2) |
| one = torch.tensor([[1]]).to(dtype2) |
| script = torch.jit.trace(eager, (x, zero)) |
| for _ in range(3): |
| torch.testing.assert_close(script(x, zero), eager(x, zero)) |
| torch.testing.assert_close(script(x, one), eager(x, one)) |
| self.assertAllFused(script.graph_for(x, one)) |
| |
| def test_to_device(self): |
| def eager(x): |
| return x.to(device="cpu").relu() |
| |
| x = torch.rand(8) |
| script = self.checkScript(eager, (x,)) |
| self.assertAllFused(script.graph_for(x)) |
| |
| def test_dims(self): |
| def eager(x, y): |
| return x / (y + 0.0001) |
| |
| x = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided( |
| (1, 1, 768), (768, 1, 1) |
| ) |
| y = torch.tensor([[[2.0]]], dtype=torch.float32) |
| script = self.checkScript(eager, (x, y)) |
| self.assertAllFused(script.graph_for(x, y)) |
| |
| @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") |
| def test_channels_last_dims_dynamic(self): |
| def eager(x, y): |
| return x + (y + 0.0001) |
| |
| indices = [0, 1, 2, 3] |
        sets = [
            subset
            for i in range(len(indices) + 1)
            for subset in combinations(indices, i)
        ]
| |
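        # `sets` holds every subset of the four dims (16 in all); each chosen
        # dim is squeezed to size 1 below to exercise size-1 strides under
        # channels-last.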
        for subset in sets:
            size = [2, 3, 4, 5]
            for index in subset:
                size[index] = 1
| inp = torch.rand(size).to(memory_format=torch.channels_last).cuda() |
| with texpr_enable_strategy([("DYNAMIC", 20)]): |
| foo_s = torch.jit.trace(eager, (inp, inp)) |
| for _ in range(3): |
| out = foo_s(inp, inp) |
| out_eager = eager(inp, inp) |
| self.assertEqual(out_eager, out) |
| self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) |
| g = torch.jit.last_executed_optimized_graph() |
| FileCheck().check("TensorExpr").run(g) |
| |
| def test_exhaust_specializations(self): |
| with texpr_enable_strategy([("STATIC", 1)]): |
| |
| @torch.jit.script |
| def foo(x): |
| return x + x + x |
| |
| for _ in range(3): |
| foo(torch.rand([2, 2])) |
| |
| for _ in range(3): |
| foo(torch.rand([4, 4, 4])) |
| |
| g = torch.jit.last_executed_optimized_graph() |
| torch._C._jit_pass_inline(g) |
| |
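        # With a [("STATIC", 1)] strategy only one static specialization is
        # allowed, so the second input shape exhausts the budget and forces
        # another compilation; after inlining we expect to see exactly two
        # TensorExpr kernels.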
| FileCheck().check_count("TensorExpr", 2, exactly=True).run(g) |
| |
| def test_unsqueeze_var_dim(self): |
| def eager(x, y, z: int): |
| return x * torch.unsqueeze(y, dim=z) |
| |
| x = torch.rand(4, 4, 64).permute(1, 0, 2) |
| y = torch.rand(4, 4) |
| z = 2 |
| script = self.checkScript(eager, (x, y, z)) |
| |
| def _test_fwd_bwd(self, fn): |
| x = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True) |
| xs = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True) |
| script = torch.jit.script(fn) |
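        # Run a short manual-SGD loop on both the eager and scripted
        # versions, applying identical updates to x and xs so the outputs
        # (and hence gradients) remain comparable on every iteration.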
        for _ in range(11):
| y = fn(x) |
| g0 = torch.rand_like(y) |
| y.backward(g0) |
| |
| ys = script(xs) |
| ys.backward(g0) |
| |
| with torch.no_grad(): |
| x -= 0.1 * x.grad |
| xs -= 0.1 * xs.grad |
| x.grad = None |
| xs.grad = None |
| torch.testing.assert_close(y, ys) |
| |
| def test_relu_fwd_bwd(self): |
| def eager(x): |
| return torch.relu(x * 1.01) |
| |
| self._test_fwd_bwd(eager) |
| |
| def test_hardswish_fwd_bwd(self): |
| def eager(x): |
| return F.hardswish(x) * 1.01 |
| |
| self._test_fwd_bwd(eager) |
| |
| def test_hardsigmoid_fwd_bwd(self): |
| def eager(x): |
| return F.hardsigmoid(x) * 1.01 |
| |
| self._test_fwd_bwd(eager) |
| |
| def test_cat_graph_opt(self): |
| def foo(x, y, z): |
| return torch.log(torch.cat([x, y, z])) |
| |
| self.checkScript( |
| foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5])) |
| ) |
        # TODO: not sure why the updated graph isn't reflected in
        # last_executed_optimized_graph
| self.assertLastGraphAllFused() |
| |
| def test_dynamic_cat(self): |
| with inline_fusion_groups(): |
| |
| @torch.jit.script |
| def repro( |
| xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor] |
| ): |
| return [ |
| torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1) |
| for x, y, z in zip(xs, ys, zs) |
| ] |
| |
| for _ in range(3): |
| N = 3 |
| xs = [torch.ones(21) for _ in range(N)] |
| # Note: concat of ys and zs will have the same size for each |
| # pair, even though the individual ys and zs do not. |
| ys = [torch.ones(N - i) for i in range(N)] |
| zs = [torch.ones(i) for i in range(N)] |
| repro(xs, ys, zs) |
| |
| def test_scalar_only_inputs(self): |
| def eager(b: float): |
| a = torch.ones(1) |
| return a * b |
| |
| script = self.checkScript(eager, (1.0,)) |
| |
| def test_cat_2k_args(self): |
| with inline_fusion_groups(): |
| |
| def eager(x): |
| return torch.relu(torch.cat([x for _ in range(2000)])) |
| |
| x = torch.randn(1) |
| trace = self.checkTrace(eager, (x,)) |
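            # A 2000-input cat is expected to make the fuser bail out (likely
            # because a kernel with thousands of parameters would be
            # unwieldy), so no fusion groups should be formed.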
| fusion_groups = self.findFusionGroups(trace.graph_for(x)) |
| self.assertEqual(len(fusion_groups), 0) |
| |
| def test_adaptive_avg_pool2d(self): |
| # TODO: once the adaptive_avg_pool2d is available in OpInfo DB, this |
| # test should be moved there |
| with inline_fusion_groups(): |
| |
| def foo1(x): |
| return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2)) |
| |
| def foo2(x): |
| return torch.nn.functional.adaptive_avg_pool2d(x, (2)) |
| |
| x = torch.randn(4, 4, 4) |
| for foo in [foo1, foo2]: |
| f = torch.jit.trace(foo, (x,)) |
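                # Compile the traced graph directly with TensorExprKernel
                # (bypassing the fuser pass) and compare against eager.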
| kernel = torch._C._te.TensorExprKernel(f.graph) |
| correct_val = f(x) |
| self.assertEqual(kernel.run((x,)), correct_val) |
| |
| def test_unrolled_cat(self): |
| with inline_fusion_groups(): |
| |
| def eager(x): |
| ret = torch.empty(0) |
| for i in range(x.shape[0]): |
| ret = torch.cat([ret, x[i].relu()]) |
| return ret |
| |
| script = torch.jit.script(eager) |
| |
            # Warm up with a size=1 tensor; since the loop iterates only
            # once, the profiling data is "burned in" assuming size=1, and
            # the loop is then unrolled.
| x = torch.ones(1, 1) |
| for _ in range(3): |
| script(x) |
| |
| torch.testing.assert_close(eager(x), script(x)) |
| |
            # Now an input that hits the unrolled path invalidates the
            # burned-in size=1 assumption; the shape guards must fall back
            # rather than produce an incorrectly-sized tensor, which the
            # check below verifies.
| x = torch.ones((8, 1)) |
| torch.testing.assert_close(eager(x), script(x)) |
| |
| @skipIfTorchDynamo("too slow") |
| @unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan") |
| @unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans") |
| def test_batch_norm(self): |
| def test(fn, args): |
| trace = torch.jit.trace(fn, args) |
| self.assertAllFused(trace.graph_for(*args)) |
| # TODO: Are `NaN`'s actually ok here or did this pass silently before, because `equal_nan=True` was the |
| # default? |
| torch.testing.assert_close(fn(*args), trace(*args), equal_nan=True) |
| |
| def bn(i, x): |
| return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu() |
| |
| def bn_no_weight(i, x): |
| return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu() |
| |
| def bn_no_bias(i, x): |
| return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu() |
| |
| def bn_neither(i, x): |
| return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu() |
| |
| for device in self.devices: |
| i = torch.randn(4, 16, 32, 40, device=device) |
| x = torch.randn(16, device=device) |
| for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]: |
| test(fn, (i, x)) |
| |
| def test_profiler(self): |
| @torch.jit.script |
| def test(x, y, z): |
| return x * y + z |
| |
| args = [torch.randn(4) for _ in range(3)] |
| with torch.autograd.profiler.profile() as prof: |
| for _ in range(3): |
| test(*args) |
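        # NNC names its generated kernels after the ops they fuse, so the
        # x * y + z kernel should surface in the profile as "fused_mul_add".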
| self.assertIn("fused_mul_add", prof.table()) |
| |
| def test_skip_grad_in_check(self): |
| @torch.jit.script |
| def foo(x): |
| return (x + 2) / 2 |
| |
| inp = torch.rand([4, 4]) |
| for _ in range(3): |
| foo(inp) |
| |
| inp.requires_grad_(True) |
| with torch.inference_mode(): |
| for _ in range(3): |
| foo(inp) |
| g = torch.jit.last_executed_optimized_graph() |
| torch._C._jit_pass_inline(g) |
| torch._C._jit_pass_inline(g) |
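        # Under inference mode the requires_grad re-check should be skipped,
        # so after inlining only a single specialization guard (prim::If)
        # remains in the graph.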
| FileCheck().check_count("prim::If", 1, exactly=True).run(g) |
| |
| def test_dynamic_shapes(self): |
| from functools import partial |
| |
| n = 10 |
| |
| gen_tensor = ( |
| lambda n: R(1, n), |
| lambda n: R(n, n), |
| lambda n: R(n, n).transpose(0, 1), |
| lambda n: R(n + 1, n + 1, 2)[:n, n, 0], |
| lambda n: R(n, n, 2)[:, :, 0], |
| lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last), |
| ) |
| |
| with texpr_enable_strategy([("DYNAMIC", 20)]): |
| |
| def foo(x, y, z): |
| return torch.sigmoid(torch.tanh(x)) |
| |
| foo.__disable_jit_function_caching__ = True |
| |
| def fi(x, y, z): |
| return torch.tanh(x + y) |
| |
| fi.__disable_jit_function_caching__ = True |
| |
| def fum(x, y, z): |
| return torch.tanh(x + y) + z |
| |
| fum.__disable_jit_function_caching__ = True |
| |
| funcs = [foo, fi, fum] |
| with inline_fusion_groups(): |
| for device in self.devices: |
                    R = partial(torch.randn, device=device)
| |
| for i, func in enumerate(funcs): |
| num_args = i + 1 |
                        for gen in gen_tensor:
| inps = (gen(n), gen(n), gen(n)) |
| func_s = torch.jit.trace(func, inps, check_trace=False) |
| torch._C._jit_pass_erase_shape_information(func_s.graph) |
| for _ in range(2): |
| x, y, z = gen(n), gen(n), gen(n) |
| func_s(x, y, z) |
| |
                            for incr in range(3):
                                func_s(*[gen(n + incr) for _ in range(3)])
| |
| g = torch.jit.last_executed_optimized_graph() |
| torch._C._jit_pass_inline(g) |
| torch._C._jit_pass_dce(g) |
| |
| # We should see only one optimized kernel |
| FileCheck().check_count( |
| "TensorExprDynamicGuard", 1, exactly=True |
| ).run(g) |
| self.assertEqual(func(*inps), func_s(*inps)) |
| |
| gen = gen_tensor[0] |
| inps = (gen(n), gen(n), gen(n)) |
| foo_s = torch.jit.trace(foo, inps) |
| torch._C._jit_pass_erase_shape_information(foo_s.graph) |
| g_prev = None |
| for gen in gen_tensor: |
| for i in range(3): |
| foo_s(*[gen(n + i) for _ in range(3)]) |
| inps = (gen(n), gen(n), gen(n)) |
| self.assertEqual(foo_s(*inps), foo(*inps)) |
| g = torch.jit.last_executed_optimized_graph() |
| torch._C._jit_pass_inline(g) |
| torch._C._jit_pass_dce(g) |
| FileCheck().check_count( |
| "TensorExprDynamicGuard", len(gen_tensor), exactly=True |
| ).run(g) |
| |
| @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA") |
| def test_autocast_up(self): |
| def f(x): |
| y = x._autocast_to_full_precision(True, True) |
| z = torch.exp(y) |
| return z |
| |
| x = torch.rand((2, 2), dtype=torch.half, device="cuda") |
| scr = torch.jit.script(f) |
| scr(x) |
| scr(x) |
| self.assertLastGraphAllFused() |
| |
| @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA") |
| def test_autocast_down(self): |
| def f(x): |
| y = torch.sigmoid(x) |
| z = y._autocast_to_reduced_precision(True, True, torch.half, torch.half) |
| return z |
| |
| x = torch.rand((2, 2), dtype=torch.float, device="cuda") |
| scr = torch.jit.script(f) |
| scr(x) |
| scr(x) |
| self.assertLastGraphAllFused() |
| |
| @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") |
| def test_to_dtype(self): |
| def f(x): |
| y = torch.sigmoid(x) |
| z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16) |
| h = z._autocast_to_full_precision(True, True) |
| i = h.to(dtype=torch.bfloat16) |
| j = i.to(dtype=torch.float32) |
| return j |
| |
| x = torch.rand((2, 2), dtype=torch.float32) |
| scr = torch.jit.trace(f, x) |
| scr(x) |
| scr(x) |
| self.assertLastGraphAllFused() |
| self.assertEqual(f(x), scr(x), atol=4e-3, rtol=4e-3) |
| |
| bf_x = torch.rand((2, 2), dtype=torch.bfloat16) |
| bf_scr = torch.jit.trace(f, bf_x) |
| bf_scr(bf_x) |
| bf_scr(bf_x) |
| graph = bf_scr.graph_for(bf_x) |
| fusion_groups = self.findFusionGroups(graph) |
| self.assertEqual(len(fusion_groups), 2) |
| self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3) |
| |
| def test_with_strict_fusion(self): |
| def success(x): |
| with torch.jit.strict_fusion(): |
| return x + x + x |
| |
| scripted = self.checkScript(success, (torch.rand([4]),)) |
| g = torch.jit.last_executed_optimized_graph() |
| FileCheck().check_not("aten::add").check("prim::TensorExprGroup").run(g) |
| |
| def foo(x): |
| with torch.jit.strict_fusion(): |
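                # aten::rand is not fusible by NNC, so strict_fusion is
                # expected to report it as an unfused operator.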
| return x + x + torch.rand([4]) + 3 |
| |
| with self.assertRaises(Exception) as error_out: |
| foo_s = torch.jit.script(foo) |
| foo_s(torch.rand([4])) |
| foo_s(torch.rand([4])) |
| print(torch.jit.last_executed_optimized_graph()) |
| fc = FileCheck().check("Found unfused operators") |
| fc.check("aten::rand(SymInt[] size") |
| fc.check("torch.rand([4]").run(str(error_out.exception)) |
| |
| with warnings.catch_warnings(record=True) as warns: |
| foo(torch.rand([4])) |
| |
| FileCheck().check("Only works in script mode").run(str(warns[0])) |
| |
| def test_autodiff(x): |
| with torch.jit.strict_fusion(): |
| return torch.rand([4]) + x + x + x |
| |
| foo_s = torch.jit.script(test_autodiff) |
| inp = torch.rand([4], requires_grad=True) |
| with self.assertRaises(Exception) as error_out: |
| for _ in range(3): |
| foo_s(inp) |
| f = FileCheck().check("unfused operators").check("aten::rand") |
| f.run(str(error_out.exception)) |
| |
| def test_separate_fusions(x, y): |
| with torch.jit.strict_fusion(): |
| return x + x + x, y + y + y |
| |
| inp = torch.rand([4], requires_grad=True) |
| with self.assertRaises(Exception) as error_out: |
| for _ in range(3): |
| foo_s = torch.jit.script(test_separate_fusions) |
| foo_s(inp, inp) |
| |
| f = FileCheck().check("Found multiple fusions") |
| f.run(str(error_out.exception)) |
| |
| def test_constant_chunk_shapes(self): |
        # We had an issue where buildShapeExpressions would fail as shown below:
| # |
| # %1 : Tensor = Constant[..] # not supported, we don't build this shape |
| # %2 : Tensor = Constant[..] # not supported |
| # %3 : Tensor = aten::add(%1, %2) # inputs not supported, we don't build shape |
| # ... = prim::ConstantChunk[..](%3) # it forgets to check whether input shapes exist, and fails |
| if self.dynamic_shapes: |
| self.skipTest("TODO: chunk dynamic shapes") |
| |
| for device in self.devices: |
| |
| def f(x, y): |
| r = torch.tensor(4) |
| z1, z2 = (x + y + r).chunk(2, dim=1) |
| return z1 * z2 |
| |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| |
| ge = self.checkTrace(f, (x, y)) |
| graph = ge.graph_for(x, y) |
| |
| # make sure that we are actually testing the right scenario |
| FileCheck().check("with " + FUSION_GROUP + "_").check_count( |
| "ConstantChunk", 1, exactly=True |
| ).run(str(graph)) |
| |
| f_traced = torch.jit.trace(f, (x, y)) |
| |
            for _ in range(4):
| # make sure this doesn't error out |
| res = f_traced(x, y) |
| |
| self.assertEqual(res, f(x, y)) |
| |
| @unittest.skipIf(not RUN_CUDA_HALF, "half-precision NNC fusion requires CUDA") |
| def test_pow_multiple_dtype(self): |
| # https://github.com/pytorch/pytorch/issues/75476 |
| def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: |
| p = torch.sigmoid(p) |
| result = p**gamma |
| return result |
| |
| x = torch.rand((2, 2), dtype=torch.half, device="cuda") |
| |
| ref = fn(x) |
| |
| script_fn = torch.jit.script(fn) |
        for _ in range(4):
| res = script_fn(x) |
| |
| self.assertEqual(ref, res) |
| |
| |
| class TestTEFuserStatic(TestTEFuser): |
| dynamic_shapes = False |
| |
| |
| class TestTEFuserDynamic(TestTEFuser): |
| dynamic_shapes = True |
| |
| |
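# Remove the shared base class so the test loader only collects the
# Static/Dynamic specializations above; the base itself never defines
# `dynamic_shapes`.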
| del TestTEFuser |
| |
| works_list = [ |
| "__radd__", |
| "__rdiv__", |
| "__rmul__", |
| "__rmod__", |
| "abs", |
| "acos", |
| "add", |
| "addcmul", |
| "addmm.decomposed", |
| "asin", |
| "atan", |
| "atan2", |
| "ceil", |
| "clamp", |
| "clamp.scalar", |
| "contiguous", |
| "cos", |
| "cosh", |
| "div.no_rounding_mode", |
| "div.true_rounding", |
| "div.floor_rounding", |
| "div.trunc_rounding", |
| "eq", |
| "erf", |
| "erfc", |
| "exp", |
| "expand", |
| "expand_as", |
| "expm1", |
| "floor", |
| "fmod", |
| "fmod.autodiffed", |
| "ge", |
| "gt", |
| "isnan", |
| "le", |
| "lerp", |
| "lgamma", |
| "log", |
| "log10", |
| "log1p", |
| "log2", |
| "lt", |
| "masked_fill", |
| "max.binary", |
| "mean", |
| "min.binary", |
| "mm", |
| "mul", |
| "ne", |
| "neg", |
| "nn.functional.hardshrink", |
| "nn.functional.hardsigmoid", |
| "nn.functional.hardswish", |
| "nn.functional.softplus", |
| "nn.functional.hardtanh", |
| "nn.functional.leaky_relu", |
| "nn.functional.relu", |
| "nn.functional.relu6", |
| "nn.functional.softsign", |
| "nn.functional.tanhshrink", |
| "nn.functional.threshold", |
| "permute", |
| "pow", |
| "reciprocal", |
| "remainder", |
| "remainder.autodiffed", |
| "reshape", |
| "reshape_as", |
| "round", |
| "rsub", |
| "rsub.rsub_tensor", |
| "rsqrt", |
| "sigmoid", |
| "sign", |
| "sin", |
| "sinh", |
| "sqrt", |
| "sub", |
| "sum", |
| "t", |
| "tan", |
| "tanh", |
| "transpose", |
| "true_divide", |
| "trunc", |
| "unsqueeze", |
| "view", |
| "view_as", |
| "where", |
| "bool", |
| "byte", |
| "char", |
| "double", |
| "float", |
| "half", |
| "int", |
| "long", |
| "short", |
| "bool.channels_last", |
| "byte.channels_last", |
| "char.channels_last", |
| "double.channels_last", |
| "float.channels_last", |
| "half.channels_last", |
| "int.channels_last", |
| "long.channels_last", |
| "short.channels_last", |
| ] |
| |
| known_failures = [ |
| "__rmatmul__", |
| "frac", |
| "matmul", |
| ] |
| |
| # If your OpInfo test causes this test to fail, add it here |
| skip_ops = ["conj"] |
| |
| |
| def get_name(op): |
    parts = [op.name]
    if op.variant_test_name != "":
        parts.append(op.variant_test_name)
    return ".".join(parts)
| |
| |
| # Purpose of this class is to allow super() calls. |
| # super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works. |
| # super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope. |
| # super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation |
| class TestNNCOpInfoParent(JitCommonTestCase): |
| pass |
| |
| |
| class TestNNCOpInfo(TestNNCOpInfoParent): |
| def setUp(self): |
| super(TestNNCOpInfoParent, self).setUp() |
| self.tensorexpr_options = TensorExprTestOptions() |
| |
| def tearDown(self): |
| self.tensorexpr_options.restore() |
| super(TestNNCOpInfoParent, self).tearDown() |
| |
| def te_compile(self, device, dtype, op): |
| if op.name in skip_ops: |
| return |
| sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) |
| for sample_input in sample_inputs_itr: |
| arg_values = [sample_input.input] + list(sample_input.args) |
| kwarg_values = sample_input.kwargs |
| param_names = [] |
| param_values = [] |
| fx_args = [] |
| for idx, v in enumerate(arg_values): |
| if isinstance(v, torch.Tensor): |
| param_names.append(f"arg_{idx}") |
| param_values.append(v) |
| fx_args.append(param_names[-1]) |
| else: |
| fx_args.append(f"{repr(v)}") |
| |
| for k, v in kwarg_values.items(): |
| if isinstance(v, torch.Tensor): |
| param_names.append(k) |
| param_values.append(v) |
| fx_args.append(f"{k} = {k}") |
| else: |
| fx_args.append(f"{k} = {repr(v)}") |
| |
| code = f""" |
| def f({', '.join(param_names)}): |
| return op.op({', '.join(fx_args)})""" |
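            # For a sample like op(tensor, 3), the generated source looks
            # roughly like:
            #   def f(arg_0):
            #       return op.op(arg_0, 3)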
| g = {"torch": torch, "inf": math.inf, "op": op} |
| exec(code, g) |
| f = g["f"] |
| f.__module__ = "test" |
            f(*param_values)
| |
| ts_g = torch.jit.trace(f, param_values) |
| kernel = torch._C._te.TensorExprKernel(ts_g.graph) |
| correct_val = f(*param_values) |
| self.assertEqual(kernel.run(tuple(param_values)), correct_val) |
| self.assertEqual(kernel.fallback(tuple(param_values)), correct_val) |
| |
| @onlyCPU |
| @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") |
| @ops( |
| [op for op in op_db if get_name(op) in works_list], |
| allowed_dtypes=(torch.float,), |
| ) |
| def test_working(self, device, dtype, op): |
| self.te_compile(device, dtype, op) |
| |
| @onlyCPU |
| @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") |
| @ops( |
| [op for op in op_db if get_name(op) in known_failures], |
| allowed_dtypes=(torch.float,), |
| ) |
| def test_failures(self, device, dtype, op): |
| try: |
| self.te_compile(device, dtype, op) |
        except Exception:
            pass
| else: |
| raise RuntimeError( |
| "Expected test to fail. If it now works, move op into works_list" |
| ) |
| |
| @onlyCPU |
| @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") |
| @ops( |
| [op for op in op_db if get_name(op) not in works_list + known_failures], |
| allowed_dtypes=(torch.float,), |
| ) |
| def test_unsupported(self, device, dtype, op): |
| if get_name(op) in skip_ops: |
| return |
| try: |
| with warnings.catch_warnings(): |
                warnings.simplefilter("ignore", torch.jit.TracerWarning)
| self.te_compile(device, dtype, op) |
        except Exception:
            pass
| else: |
| raise RuntimeError( |
| "Expected test to fail. If it now works, move op into works_list" |
| ) |
| |
| @slowTest |
| @onlyCPU |
| @ops(op_db, dtypes=OpDTypes.supported) |
| def test_nnc_correctness(self, device, dtype, op): |
| if not op.supports_tracing: |
| self.skipTest("Requires tracing support") |
| |
        with NoTracerWarnContextManager():
| variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op) |
| |
| for variant, sample in variant_sample_pairs: |
| trace = create_traced_fn(self, variant, cache_traced_fn=True) |
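                # Compute an eager reference, run the traced variant once to
                # warm up the profiling executor, then compare the second
                # (optimized) run against the reference.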
| ref = variant( |
| *clone_inputs((sample.input, *sample.args)), **sample.kwargs |
| ) |
| |
| trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) |
| val = trace( |
| *clone_inputs((sample.input, *sample.args)), **sample.kwargs |
| ) |
| |
| atol = 2e-1 if dtype == torch.bfloat16 else 1e-5 |
| rtol = 2e-1 if dtype == torch.bfloat16 else 1e-5 |
| self.assertEqual(ref, val, atol=atol, rtol=rtol) |
| |
| # https://github.com/pytorch/pytorch/issues/35600 |
| # each torch.jit.trace adds state to the _python_cu compilation unit |
| # since this test traces a lot of functions, out-of-memory can occur |
| # if the CU is not cleared. |
| torch.jit._state._python_cu.drop_all_functions() |
| |
| |
| # CPU fuser not currently used in fbcode |
| only_for = ("cuda") if IS_FBCODE else ("cpu", "cuda") |
| instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for) |
| |
| |
| # Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent) |
| class TestLoopnestRandomizationParent(JitTestCase): |
| pass |
| |
| |
| class TestLoopnestRandomization(TestLoopnestRandomizationParent): |
| def setUp(self): |
| super(TestLoopnestRandomizationParent, self).setUp() |
| self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() |
| self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu() |
| self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu() |
| |
| torch._C._jit_override_can_fuse_on_cpu(True) |
| # TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle |
| # torch._C._jit_set_te_must_use_llvm_cpu(True) |
| torch._C._jit_override_can_fuse_on_gpu(True) |
| |
| self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) |
| self.old_profiling_mode = torch._C._get_graph_executor_optimize(True) |
| |
| self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() |
| torch._C._debug_set_fusion_group_inlining(False) |
| |
| self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() |
| torch._C._jit_set_texpr_fuser_enabled(True) |
| |
| self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() |
| torch._C._jit_set_te_must_use_llvm_cpu(False) |
| |
        # Set the seed to 1. This exercises the codepath through random
        # transformations.
| os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "1" |
| |
| def tearDown(self): |
| torch._C._jit_set_profiling_executor(self.old_profiling_executor) |
| torch._C._get_graph_executor_optimize(self.old_profiling_mode) |
| |
| torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) |
| torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) |
| torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state) |
| torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) |
| |
| torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) |
| torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) |
| |
| # Set it back to 0. |
| os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "0" |
| super(TestLoopnestRandomizationParent, self).tearDown() |
| |
| @onlyCPU |
| @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel") |
| def test_relu(self, device): |
| def fn_test_relu(x, y): |
| return F.relu(x + 0.5 * y) |
| |
| x = torch.randn(4, 4, dtype=torch.float, device=device) |
| y = torch.randn(4, 4, dtype=torch.float, device=device) |
| |
| fn = fn_test_relu |
| traced_fn = torch.jit.trace(fn, (x, y)) |
| |
| ref = fn(x, y) |
| res = traced_fn(x, y) |
        self.assertTrue(torch.allclose(ref, res))
| |
| |
instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("cpu",))
| |
| |
| if __name__ == "__main__": |
| run_tests() |