| # Owner(s): ["module: inductor"] |
| import contextlib |
| import itertools |
| import sys |
| import unittest |
| from typing import Callable |
| from unittest.mock import patch |
| |
| import numpy as np |
| import sympy |
| import torch |
| import torch._dynamo |
| from torch._C import FileCheck |
| from torch._dynamo.utils import same |
| from torch._inductor import codecache, config, metrics |
| from torch._inductor.codegen.cpp import ( |
| CppOverrides, |
| CppVecKernelChecker, |
| CppVecOverrides, |
| ) |
| from torch._inductor.compile_fx import compile_fx_inner, complex_memory_overlap |
| from torch._inductor.graph import GraphLowering |
| from torch._inductor.ir import InterpreterShim |
| from torch._inductor.utils import timed |
| from torch._inductor.virtualized import V |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.nn import functional as F |
| from torch.testing._internal.common_utils import IS_MACOS, slowTest |
| from torch.utils._python_dispatch import TorchDispatchMode |
| |
| try: |
| try: |
| from . import test_torchinductor |
| except ImportError: |
| import test_torchinductor |
| except unittest.SkipTest: |
| if __name__ == "__main__": |
| sys.exit(0) |
| raise |
| |
| |
| vec_dtypes = test_torchinductor.vec_dtypes |
| run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code |
| TestCase = test_torchinductor.TestCase |
| aten = torch.ops.aten |
| |
| |
| class CPUReproTests(TestCase): |
| def test_conv_stride_constraints(self): |
| for fmt in [torch.channels_last, torch.contiguous_format]: |
| # TorchDispatch doesn't work in our cuda invocation for some reason |
| m = torch.nn.Conv2d(5, 6, [3, 3]) |
| |
| def fn(inp, weight): |
| return ( |
| F.conv2d( |
| inp, weight, None, m.stride, m.padding, m.dilation, m.groups |
| ), |
| ) |
| |
| inp = torch.randn([2, 5, 16, 16]) |
| inps = [inp, m.weight.to(memory_format=fmt)] |
| fn_fx = make_fx(fn)(*inps) |
| fn_compiled = compile_fx_inner(fn_fx, inps) |
| test_self = self |
| conv_seen = False |
| |
| class RecordFunctions(TorchDispatchMode): |
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
| kwargs = kwargs if kwargs else {} |
| if func == torch.ops.aten.convolution.default: |
| test_self.assertTrue(args[0].is_contiguous(memory_format=fmt)) |
| test_self.assertTrue(args[1].is_contiguous(memory_format=fmt)) |
| nonlocal conv_seen |
| conv_seen = True |
| |
| return func(*args, **kwargs) |
| |
| with RecordFunctions(): |
| out = fn_compiled(inps) |
| |
| self.assertTrue(conv_seen) |
| |
| def test_inplace_squeeze_needed(self): |
| mod = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.LayerNorm(10), |
| torch.nn.ReLU(), |
| ).eval() |
| |
| @torch._dynamo.optimize("inductor") |
| def fn(x): |
| return mod(x) |
| |
| v = torch.randn(10) |
| result = fn(v) |
        # TODO: OMP parallel reduction order is not deterministic.
        # Hence, the accuracy can vary from run to run. For the short term,
        # we increase the tolerance and will fix it later by using
        # ATen parallel.
| assert same(result, mod(v), tol=5e-1) |
| |
| def test_cat_mul(self): |
| # https://github.com/pytorch/pytorch/issues/93365 |
| def fn(p0, p1): |
| y1 = torch.cat([p0, p1], dim=0) |
| y2 = torch.mul(y1, y1) |
| return y1, y2 |
| |
| p0 = torch.randn(3, 4) |
| p1 = torch.randn(3, 4) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(p0, p1) |
| real_out = fn(p0, p1) |
| compiled_out = opt_fn(p0, p1) |
| assert same(real_out, compiled_out) |
| |
| def test_pow_cos(self): |
| # https://github.com/pytorch/pytorch/issues/98149 |
| def fn(x): |
| t = x.pow(5) |
| return torch.cos(t) |
| |
| x = torch.tensor([4], dtype=torch.uint8) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(x) |
| real_out = fn(x) |
| compiled_out = opt_fn(x) |
| assert same(real_out, compiled_out) |
| |
| def test_reduce_with_masked(self): |
| # https://github.com/pytorch/pytorch/issues/96484 |
| def fn(a, b): |
| a = torch.nn.functional.pad(a, (0, -1)) |
| c = a + b |
| return c.min(0).values |
| |
| a = torch.randn([2]) |
| b = torch.randn([2]) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(a, b) |
| real_out = fn(a, b) |
| compiled_out = opt_fn(a, b) |
| assert same(real_out, compiled_out) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_sigmoid_with_reduction(self): |
| def fn(x): |
| x = torch.ops.aten.sigmoid.default(x) |
| return torch.ops.aten.mean.dim(x, [-1, -2], True) |
| |
| x = torch.randn((1, 8, 8, 8)) |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(x) |
| |
| real_out = fn(x) |
| compiled_out = opt_fn(x) |
| assert same(real_out, compiled_out, equal_nan=True) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_decomposed_dequant_relu_quant(self): |
| def fn(x, scale, zero_point, use_dequant, use_quant): |
| # For quantized_decomposed.dequantize_per_tensor |
| # Refer to torch/ao/quantization/fx/_decomposed.py |
| if use_dequant: |
| x = (x.to(torch.float32) - zero_point) * scale |
| |
| x = torch.relu(x) |
| |
| # For quantized_decomposed.quantize_per_tensor |
| # Refer to torch/ao/quantization/fx/_decomposed.py |
| if use_quant: |
| inv_scale = 1.0 / scale |
| x = torch.clamp(torch.round(x * inv_scale) + zero_point, 0, 255).to( |
| torch.uint8 |
| ) |
| return x |
| |
| use_dequant_list = [False, True] |
| use_quant_list = [False, True] |
| for use_dequant, use_quant in itertools.product( |
| use_dequant_list, use_quant_list |
| ): |
| x = torch.clamp( |
| torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255 |
| ) |
| if use_dequant: |
| x = x.to(torch.uint8) |
| zero_point = 100 |
| scale = 0.01 |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(x, scale, zero_point, use_dequant, use_quant) |
| |
| real_out = fn(x, scale, zero_point, use_dequant, use_quant) |
| compiled_out = opt_fn(x, scale, zero_point, use_dequant, use_quant) |
| assert same(real_out, compiled_out, equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_inplace_add_alpha(self): |
| def fn(x, y): |
| aten.add_.Tensor(x, y, alpha=0.55) |
| return (x,) |
| |
| x1 = torch.zeros(10) |
| x2 = torch.zeros(10) |
| x3 = torch.zeros(10) |
| y = torch.randn(10) |
| fn_fx = make_fx(fn)(x1, y) |
| fn_compiled = compile_fx_inner(fn_fx, [x1, y]) |
| fn(x2, y) |
| fn_compiled([x3, y]) |
| assert same(x2, x3) |
| |
| def test_int_div(self): |
| def fn(x, y): |
| s3 = x.size(1) |
| a = torch.zeros((1 + s3) // 2) |
| a += y |
| return a, s3 |
| |
| p0 = torch.randint(5, (1, 8)) |
| p1 = torch.randn(1) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(p0, p1) |
| real_out = fn(p0, p1) |
| compiled_out = opt_fn(p0, p1) |
| assert same(real_out, compiled_out) |
| |
| def test_no_op_squeeze(self): |
| @torch._dynamo.optimize("inductor") |
| def forward(arg0_1): |
| return torch.ops.aten.squeeze.dim(arg0_1, 1) |
| |
| x = torch.randn((10, 20)) |
| assert same(x, forward(x)) |
| |
| def test_parallel_num_threads(self): |
| @torch._dynamo.optimize("inductor") |
| def fn(x1, x2): |
| return x1 + x2 |
| |
        @contextlib.contextmanager
        def set_num_threads(num_threads):
            orig_num_threads = torch.get_num_threads()
            torch.set_num_threads(num_threads)
            try:
                yield
            finally:
                # Restore the original thread count even if the body raises.
                torch.set_num_threads(orig_num_threads)
| |
| x1 = torch.randn((10, 20)) |
| x2 = torch.randn((10, 20)) |
| with set_num_threads(1): |
| assert same(x1 + x2, fn(x1, x2)) |
| with set_num_threads(4): |
| assert same(x1 + x2, fn(x1, x2)) |
| |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_timed_cpu_only(self): |
| timed(lambda: torch.randn(10), ()) |
| |
| def test_complex_memory_overlap(self): |
| dense = torch.zeros(64, 32) |
| self.assertFalse(complex_memory_overlap(dense)) |
| self.assertFalse(complex_memory_overlap(dense.t())) |
| |
| strided = dense.split(4, dim=1) |
| self.assertFalse(complex_memory_overlap(strided[0])) |
| self.assertFalse(complex_memory_overlap(strided[0].t())) |
| |
| unsqueezed = dense.unsqueeze(1) |
| self.assertFalse(complex_memory_overlap(unsqueezed)) |
| self.assertFalse(complex_memory_overlap(unsqueezed.permute(1, 2, 0))) |
| |
| gathered = dense.index_select(0, torch.IntTensor([1, 0, 1])) |
| self.assertFalse(complex_memory_overlap(gathered)) |
| self.assertFalse(complex_memory_overlap(gathered.t())) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @torch._dynamo.config.patch(dynamic_shapes=True) |
| def test_vec_dynamic_shapes(self): |
| def fn(x): |
| return torch.softmax(x, -1) |
| |
| value = torch.randn((2, 10)) |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(value) |
| |
| real_out = fn(value) |
| compiled_out = opt_fn(value) |
| assert same(real_out, compiled_out, equal_nan=True) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_auto_simd(self): |
| vec_avx512 = codecache.supported_vec_isa_list[0] |
| vec_avx2 = codecache.supported_vec_isa_list[1] |
| self.assertTrue(vec_avx512.bit_width() == 512) |
| self.assertTrue(vec_avx2.bit_width() == 256) |
| self.assertTrue(vec_avx512.nelements() == 16) |
| self.assertTrue(vec_avx2.nelements() == 8) |
| self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) |
| self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) |
| |
| with config.patch({"cpp.simdlen": None}): |
| isa = codecache.pick_vec_isa() |
| if vec_avx512 in codecache.valid_vec_isa_list(): |
| self.assertTrue(isa == vec_avx512) |
| else: |
| self.assertTrue(isa == vec_avx2) |
| |
| with config.patch({"cpp.simdlen": 0}): |
| isa = codecache.pick_vec_isa() |
| self.assertFalse(isa) |
| |
| with config.patch({"cpp.simdlen": 1}): |
| isa = codecache.pick_vec_isa() |
| self.assertFalse(isa) |
| |
| with config.patch({"cpp.simdlen": 257}): |
| isa = codecache.pick_vec_isa() |
| self.assertFalse(isa) |
| |
| with config.patch({"cpp.simdlen": 513}): |
| isa_list = codecache.valid_vec_isa_list() |
| if vec_avx512 in isa_list: |
                isa = codecache.pick_vec_isa()
                self.assertFalse(isa)
| |
| with config.patch({"cpp.simdlen": 512}): |
| isa_list = codecache.valid_vec_isa_list() |
| if vec_avx512 in isa_list: |
| isa = codecache.pick_vec_isa() |
| self.assertTrue(isa == vec_avx512) |
| |
| with config.patch({"cpp.simdlen": 256}): |
| isa_list = codecache.valid_vec_isa_list() |
| if vec_avx2 in isa_list: |
| isa = codecache.pick_vec_isa() |
| self.assertTrue(isa == vec_avx2) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_masked_fill_softmax(self): |
| def fn(value, mask): |
| mask = mask.to(torch.bool) |
| x = torch.masked_fill(value, mask, -33.0) |
| return torch.softmax(x, -1) |
| |
| for dtype in vec_dtypes: |
| value = torch.randn((2, 17), dtype=dtype) |
| mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8) |
| with config.patch({"cpp.simdlen": None}): |
| for cpp_wrapper_flag in [True, False]: |
| with config.patch({"cpp_wrapper": cpp_wrapper_flag}): |
| torch._dynamo.reset() |
| metrics.reset() |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| opt_fn(value, mask) |
| |
| real_out = fn(value, mask) |
| compiled_out = opt_fn(value, mask) |
| assert same(real_out, compiled_out, equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count >= 1 |
| |
| def test_load_same_bool_tensor_twice(self): |
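        # Smoke test: the same bool mask is used twice by the compiled function; we only
        # check that compilation and execution succeed (there is no output comparison).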
| @torch._dynamo.optimize("inductor") |
| def fn(a, b): |
| x = torch.masked_fill(a, b, -33.0) |
| y = torch.masked_fill(a, b, -33.0) |
| return x, y |
| |
| value = torch.randn((2, 17)) |
| mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool) |
| fn(value, mask) |
| |
| def test_cpu_vec_cosim(self): |
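        # Coverage check for the vectorized overrides: every op implemented in
        # CppOverrides should either have a counterpart in CppVecOverrides or be one of
        # the known-unvectorized ops listed in `diff` below.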
| cpp_vec_op_list = [] |
| cpp_op_list = [] |
| |
| for k, v in CppVecOverrides.__dict__.items(): |
| if isinstance(v, staticmethod): |
| cpp_vec_op_list.append(k) |
| for k, v in CppOverrides.__dict__.items(): |
| if isinstance(v, staticmethod): |
| cpp_op_list.append(k) |
| |
| diff = [ |
| "index_expr", |
| "signbit", |
| "isinf", |
| "mod", |
| "masked", |
| "randn", |
| "isnan", |
| "rand", |
| "bitwise_and", |
| "bitwise_or", |
| "bitwise_xor", |
| ] |
| union = {*cpp_vec_op_list, *diff} |
| self.assertTrue(set(cpp_op_list).issubset(union)) |
| |
| def test_atomic_add_bf16(self): |
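        # The backward of gather scatters gradients back into the input, which the CPU
        # backend presumably lowers to an atomic add; verify that this path produces
        # correct bfloat16 gradients.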
| def fn(test_args): |
| res = torch.gather(**test_args) |
| return res |
| |
| input_tensor_for_ref = torch.tensor( |
| [[3.0, -5.0]], dtype=torch.bfloat16, requires_grad=True |
| ) |
| input_tensor_for_opt = torch.tensor( |
| [[3.0, -5.0]], dtype=torch.bfloat16, requires_grad=True |
| ) |
| |
| test_args_for_ref = { |
| "input": input_tensor_for_ref, |
| "dim": 1, |
| "index": torch.tensor([[1]]), |
| } |
| test_args_for_opt = { |
| "input": input_tensor_for_opt, |
| "dim": 1, |
| "index": torch.tensor([[1]]), |
| } |
| |
| opt_fn = torch.compile(fn) |
| |
| ref_fwd = fn(test_args_for_ref) |
| res_fwd = opt_fn(test_args_for_opt) |
| self.assertEqual(res_fwd, ref_fwd) |
| |
| torch.manual_seed(1) |
| bwd_tensor_for_ref = torch.randn(ref_fwd.shape, dtype=torch.bfloat16) |
| torch.manual_seed(1) |
| bwd_tensor_for_opt = torch.randn(res_fwd.shape, dtype=torch.bfloat16) |
| self.assertEqual(bwd_tensor_for_ref, bwd_tensor_for_opt) |
| |
| ref_fwd.backward(bwd_tensor_for_ref) |
| res_fwd.backward(bwd_tensor_for_opt) |
| |
| ref_grad = test_args_for_ref["input"].grad |
| res_grad = test_args_for_opt["input"].grad |
| self.assertEqual(ref_grad, res_grad) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_new_vec_op_cpu_only(self): |
| def fn(x): |
| return (torch.log1p(torch.expm1(torch.erf(x))),) |
| |
| for dtype in vec_dtypes: |
| torch.manual_seed(0) |
| x = torch.randn((2, 9), dtype=dtype) |
| x[0, 0] = torch.nan |
| x[1, -1] = torch.nan |
| |
| tol = 1e-2 if dtype == torch.bfloat16 else 1e-4 |
| |
| with config.patch({"cpp.simdlen": None}): |
| for cpp_wrapper_flag in [True, False]: |
| with config.patch({"cpp_wrapper": cpp_wrapper_flag}): |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x) |
| compiled = compile_fx_inner(traced, [x]) |
| assert same(fn(x)[0], compiled([x])[0], equal_nan=True, tol=tol) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_vec_cpu_only_for_all_available_isa(self): |
| def fn(x): |
| return (torch.sin(torch.cos(torch.erf(x))),) |
| |
| x = torch.randn((2, 9)) |
| x[0, 0] = torch.nan |
| x[1, -1] = torch.nan |
| |
| bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] + [None] |
| for item in bit_widths: |
| with config.patch({"cpp.simdlen": item}): |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x) |
| compiled = compile_fx_inner(traced, [x]) |
| assert same(fn(x)[0], compiled([x])[0], equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @slowTest |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test__adaptive_avg_pool2d(self): |
| def wrap_fn(oh, ow): |
| def fn(x): |
| return torch._adaptive_avg_pool2d(x, (oh, ow)) |
| |
| return fn |
| |
| bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] |
| ih = [16, 65] |
| iw = ih |
| oh = ih |
| ow = ih |
| for _ih, _iw, _oh, _ow, _simd_len, dtype in itertools.product( |
| ih, iw, oh, ow, bit_widths, vec_dtypes |
| ): |
| x = torch.randn(2, 3, _ih, _iw, dtype=dtype).to( |
| memory_format=torch.channels_last |
| ) |
| _fn = wrap_fn(_oh, _ow) |
| with config.patch({"cpp.simdlen": _simd_len}): |
| torch._dynamo.reset() |
| metrics.reset() |
| compiled = torch.compile(_fn) |
| compiled(x) |
| assert same(_fn(x), compiled(x), equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_vec_logical_and_or(self): |
| def wrap_fn(op: Callable): |
| def fn(x: torch.Tensor, y: torch.Tensor): |
| return torch.where(op(x, y), 1.0, 0.0) |
| |
| return fn |
| |
| for dtype in vec_dtypes: |
| x = torch.randn(64, dtype=dtype) |
| y = torch.randn(64, dtype=dtype) |
| logical_fns = [torch.logical_and, torch.logical_or] |
| for logical_fn in logical_fns: |
| _fn = wrap_fn(logical_fn) |
| torch._dynamo.reset() |
| metrics.reset() |
| compiled = torch.compile(_fn) |
| |
| compiled(x, y) |
| assert same(_fn(x, y), compiled(x, y), equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_vec_compare_op_cpu_only(self): |
| def fn(x): |
| y1 = torch.eq(x, 1.0) |
| x = torch.where(y1, x, -x) |
| y2 = torch.ne(x, 0.0) |
| x = torch.where(y2, x, -x) |
| y3 = torch.lt(x, 5.0) |
| x = torch.where(y3, x, x - 1.0) |
| y4 = torch.gt(x, -2.0) |
| x = torch.where(y4, x, x + 1.0) |
| y5 = torch.le(x, 8.0) |
| x = torch.where(y5, x, x - 1.0) |
| y6 = torch.ge(x, -3.0) |
| x = torch.where(y6, x, x + 1.0) |
| y7 = x == 1.0 |
| x = torch.where(y7, x, -x) |
| y8 = x != 0.0 |
| x = torch.where(y8, x, -x) |
| y9 = x < 5.0 |
| x = torch.where(y9, x, x - 1.0) |
| y10 = x > -2.0 |
| x = torch.where(y10, x, x + 1.0) |
| y11 = x <= 8.0 |
| x = torch.where(y11, x, x - 1.0) |
| y12 = x >= -3.0 |
| x = torch.where(y12, x, x + 1.0) |
| return (x,) |
| |
| for dtype in vec_dtypes: |
| x = torch.randn((2, 9), dtype=dtype) |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x) |
| compiled = compile_fx_inner(traced, [x]) |
| assert same(fn(x)[0], compiled([x])[0], equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| assert ( |
| metrics.generated_kernel_count |
| - metrics.generated_cpp_vec_kernel_count |
| ) == 0 |
| |
| def test_skip_cpp_codegen(self): |
| with config.patch({"disable_cpp_codegen": True}): |
| inps = torch.ones([20]), torch.rand([20]) |
| |
| def f(x, y): |
| return x + y + torch.tensor(1) |
| |
| f_opt = torch.compile()(f) |
| |
| code = run_and_get_cpp_code(f_opt, inps[0], inps[1]) |
| FileCheck().check_not("void kernel").run(code) |
| |
| self.assertEqual( |
| f(*inps), |
| f_opt(*inps), |
| ) |
| |
| # constant needs to be propagated on fallback |
| def f(x): |
| return x[torch.tensor(1) :] * 2 |
| |
| f_opt = torch.compile()(f) |
| code = run_and_get_cpp_code(f_opt, inps[0]) |
| FileCheck().check_not("void kernel").run(code) |
| self.assertEqual(f_opt(inps[0]), f(inps[0])) |
| |
| class Model(torch.nn.Module): |
| def __init__( |
| self, |
| ): |
| super().__init__() |
| |
| def forward(self, v1: torch.Tensor): |
| vx = v1.min(dim=1).values |
| v2 = torch.randn_like(vx) |
| return v2 |
| |
| model = Model() |
| x = torch.rand(10, 3, 0) |
| model_f = torch.compile()(model) |
| |
| self.assertEqual(model(x), model_f(x)) |
| |
| def test_redundant_to_node_elimination_bf16(self): |
| def fn(x, y): |
| res = x + y |
| res = torch.mean(res) |
| return (res,) |
| |
| x = torch.randn((2, 9), dtype=torch.bfloat16) |
| y = torch.randn((2, 9), dtype=torch.bfloat16) |
| |
| for torch_compile_debug in [True, False]: |
| with config.patch( |
| {"trace.enabled": torch_compile_debug, "cpp.simdlen": None} |
| ): |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x, y) |
| compiled = compile_fx_inner(traced, [x, y]) |
| assert same(fn(x, y)[0], compiled([x, y])[0], equal_nan=True, tol=1e-2) |
| if codecache.valid_vec_isa_list(): |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_do_not_insert_to_dtype_for_memory_copy_only_kernel(self): |
| def fn(x): |
| res = x.clone() |
| return (res,) |
| |
| x = torch.randn((100, 100), dtype=torch.bfloat16) |
| |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x) |
| compiled = compile_fx_inner(traced, [x]) |
| assert same(fn(x)[0], compiled([x])[0]) |
| assert metrics.cpp_to_dtype_count == 0 |
| if codecache.valid_vec_isa_list(): |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_insert_to_dtype_count(self): |
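        # relu on a bfloat16 tensor is expected to convert to float32 for the compute and
        # back to bfloat16 for the store, i.e. exactly two to_dtype conversions in the
        # generated kernel.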
| def fn(x): |
| res = x.relu() |
| return (res,) |
| |
| x = torch.randn((100, 100), dtype=torch.bfloat16) |
| |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x) |
| compiled = compile_fx_inner(traced, [x]) |
| assert same(fn(x)[0], compiled([x])[0]) |
| assert metrics.cpp_to_dtype_count == 2 |
| if codecache.valid_vec_isa_list(): |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_cpp_vec_constant_checker(self): |
| _graph: torch.fx.Graph = torch.fx.Graph() |
| a: torch.fx.Node = _graph.create_node("placeholder", "ops") |
| iv: torch.fx.Node = _graph.create_node("placeholder", "iv") |
| fv: torch.fx.Node = _graph.create_node("placeholder", "fv") |
| b: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "constant", |
| args=( |
| a, |
| iv, |
| torch.int64, |
| ), |
| ) |
| c: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "constant", |
| args=( |
| a, |
| fv, |
| torch.double, |
| ), |
| ) |
| d: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "ge", |
| args=( |
| a, |
| b, |
| b, |
| ), |
| ) |
| _graph.output((d, c)) |
| |
| def get_index(): |
| return "" |
| |
| submodules = {"get_index": get_index} |
| |
| graph_lowering = GraphLowering( |
| torch.fx.GraphModule(submodules, _graph), |
| shape_env=None, |
| num_static_inputs=0, |
| ) |
| with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( |
| graph_lowering |
| ): |
            # Constant values within the int32/float32 ranges (including +/-inf for float)
            # should keep SIMD vectorization enabled; out-of-range values should disable it.
| tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float) |
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| i32_iinfo = np.iinfo(np.int32) |
| f32_iinfo = np.finfo(np.float32) |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max |
| ) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min |
| ) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min, np.inf |
| ) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min, -np.inf |
| ) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min |
| ) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max |
| ) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min * (1 + 1e-5) |
| ) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| vec_checker.simd_vec = True |
| InterpreterShim(_graph, submodules).run( |
| V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max * (1 + 1e-5) |
| ) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_cpp_vec_index_expr_checker(self): |
| _graph: torch.fx.Graph = torch.fx.Graph() |
| a: torch.fx.Node = _graph.create_node("placeholder", "ops") |
| b: torch.fx.Node = _graph.create_node("call_module", "get_index", args=()) |
| c: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "index_expr", |
| args=( |
| a, |
| b, |
| torch.int64, |
| ), |
| ) |
| d: torch.fx.Node = _graph.create_node( |
| "call_method", |
| "ge", |
| args=( |
| a, |
| c, |
| c, |
| ), |
| ) |
| _graph.output(d) |
| |
| def get_index(): |
| return "" |
| |
| submodules = {"get_index": get_index} |
| graph_lowering = GraphLowering( |
| torch.fx.GraphModule(submodules, _graph), |
| shape_env=None, |
| num_static_inputs=0, |
| ) |
| with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( |
| graph_lowering |
| ): |
| itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")] |
| |
| tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float) |
            # The innermost loop variable is used in the index_expr
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| |
| def get_index(): |
| return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] |
| |
| ranges = [0, 100, 200] |
| vec_checker.itervars = itervars[:2] |
| vec_checker.ranges = ranges[:2] |
| submodules = {"get_index": get_index} |
| InterpreterShim(_graph, submodules).run(V.get_ops_handler()) |
| self.assertFalse(vec_checker.simd_vec) |
| |
            # The innermost loop variable is irrelevant (not used in the index_expr)
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| |
| def get_index(): |
| return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] |
| |
| ranges = [0, 100, 200] |
| vec_checker.itervars = itervars |
| vec_checker.ranges = ranges |
| submodules = {"get_index": get_index} |
| InterpreterShim(_graph, submodules).run(V.get_ops_handler()) |
| self.assertTrue(vec_checker.simd_vec) |
| |
| i32_iinfo = np.iinfo(np.int32) |
| _max_value = i32_iinfo.max + 1 |
| ranges = [_max_value, _max_value, _max_value] |
            # The innermost loop variable is irrelevant, but the max index value is
            # greater than the max value of INT32
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| |
| def get_index(): |
| return itervars[0] |
| |
| submodules = {"get_index": get_index} |
| vec_checker.itervars = itervars |
| vec_checker.ranges = ranges |
| InterpreterShim(_graph, submodules).run(V.get_ops_handler()) |
| self.assertFalse(vec_checker.simd_vec) |
| |
            # The innermost loop variable is irrelevant, but the min index value is
            # less than the min value of INT32
| with CppVecKernelChecker( |
| args=None, num_threads=1, tiling_factor=tiling_factor |
| ) as vec_checker: |
| |
| def get_index(): |
| return -itervars[0] - 2 |
| |
| submodules = {"get_index": get_index} |
| vec_checker.itervars = itervars |
| vec_checker.ranges = ranges |
| InterpreterShim(_graph, submodules).run(V.get_ops_handler()) |
| self.assertFalse(vec_checker.simd_vec) |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_maxpool2d_cpu_only(self): |
| for dtype in vec_dtypes: |
| input = torch.randn(10, 32, 20, 20, dtype=dtype).to( |
| memory_format=torch.channels_last |
| ) |
| maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) |
| |
| def func(x): |
| return maxpool(x) |
| |
| with patch.object(config.cpp, "simdlen", None): |
| torch._dynamo.reset() |
| metrics.reset() |
| graph = torch.compile(func, backend="inductor") |
| graph(input) |
| assert same(graph(input), func(input), equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_maxpool2d_with_pre_loop_collapse_cpu_only(self): |
| x1 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last) |
| x2 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last) |
| maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) |
| |
| def func(x1, x2): |
| y = x1 + x2 |
| return maxpool(y) |
| |
| with patch.object(config.cpp, "simdlen", None): |
| torch._dynamo.reset() |
| metrics.reset() |
| graph = torch.compile(func, backend="inductor") |
| graph(x1, x2) |
| assert same(graph(x1, x2), func(x1, x2), equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_sign_cpu_only(self): |
| def fn(x): |
| return (torch.sign(x),) |
| |
| for dtype in vec_dtypes: |
| x = torch.randn((2, 9), dtype=dtype) |
| x[0, 0] = torch.nan |
| x[1, -1] = torch.nan |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x) |
| compiled = compile_fx_inner(traced, [x]) |
| assert same(fn(x)[0], compiled([x])[0], equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_reduction_cpu_only(self): |
| def fn(x): |
| return (torch.argmax(x, -1),) |
| |
| for dtype in vec_dtypes: |
| x = torch.randn((10, 10), dtype=dtype) |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x) |
| compiled = compile_fx_inner(traced, [x]) |
| assert same(fn(x)[0], compiled([x])[0], equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 0 |
| |
    # Currently, we enable AVX2 and AVX512 for vectorization. If the platform does not
    # support them, vectorization will not work and this test case is skipped. To support
    # ARM or other platforms, we just need to add the ISA info to
    # codecache.supported_vec_isa_list and include the proper ATen vectorization header file.
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| def test_vec_kernel_cpu_only(self): |
| def fn(x1, x2): |
            # Currently, there are some limitations as follows.
            # rsqrt:
            #   assert [both a fallback and a decomp for same kernel: aten.rsqrt.default]
            # round:
            #   couldn't find symbolic meta function/decomposition
            # fmod/logical_and/logical_or:
            #   vec kernel does not support to_type
| x = torch.abs(x1) |
| x = torch.sin(x) |
| x = torch.neg(x) |
| x = torch.square(x) |
| x = torch.sigmoid(x) |
| x = torch.relu(x) |
| x = torch.cos(x) |
| x = torch.exp(x) |
| x = torch.sqrt(x) |
| x = torch.add(x, x1) |
| x = torch.sub(x, x2) |
| x = torch.mul(x, x1) |
| x = torch.div(x, x1) |
| x = torch.pow(x, 10) |
| x = torch.log(x) |
| x = torch.floor(x) |
| x = torch.ceil(x) |
| x = torch.trunc(x) |
| x = torch.lgamma(x) |
| x = torch.fmod(x, x2) |
| x = torch.sign(x) |
| res = x + x2 |
| return (res,) |
| |
| for dtype in vec_dtypes: |
| torch.manual_seed(0) |
| x1 = torch.randn((5, 20), dtype=dtype) |
| x2 = torch.randn((5, 20), dtype=dtype) |
| |
| tol = 1e-2 if dtype == torch.bfloat16 else 1e-4 |
| with config.patch({"cpp.simdlen": 1}): |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x1, x2) |
| compiled = compile_fx_inner(traced, [x1, x2]) |
| assert same( |
| fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True, tol=tol |
| ) |
| assert metrics.generated_cpp_vec_kernel_count == 0 |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| traced = make_fx(fn)(x1, x2) |
| compiled = compile_fx_inner(traced, [x1, x2]) |
| assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| with config.patch({"cpp.simdlen": None}): |
| torch._dynamo.reset() |
| metrics.reset() |
| x1 = torch.randn(10, 20).permute(1, 0) |
| x2 = torch.randn((20, 10)) |
| traced = make_fx(fn)(x1, x2) |
| compiled = compile_fx_inner(traced, [x1, x2]) |
| assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| torch._dynamo.reset() |
| metrics.reset() |
| x1 = torch.randn((10, 7)) |
| x2 = torch.randn((10, 7)) |
| traced = make_fx(fn)(x1, x2) |
            compiled = compile_fx_inner(traced, [x1, x2])
| assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| @unittest.skipIf( |
| sys.platform != "linux", "cpp kernel profile only support linux now" |
| ) |
| @patch("torch.cuda.is_available", lambda: False) |
| @config.patch({"cpp.enable_kernel_profile": True}) |
| def test_cpp_kernel_profile(self): |
| from torch.profiler import profile |
| |
| @torch._dynamo.optimize("inductor", nopython=True) |
| def fn(a, b): |
| return a + b |
| |
| a = torch.rand((100,)) |
| b = torch.rand((100,)) |
| with profile() as prof: |
| fn(a, b) |
| |
| kernel_profile_events = [] |
| for e in prof.profiler.function_events: |
| if "kernel_cpp_0" in e.name: |
| kernel_profile_events.append(e.name) |
| assert len(kernel_profile_events) > 0 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| def test_channel_shuffle_cl_output(self): |
| """code and shape extracted from shufflenet_v2_x1_0""" |
| |
| def channel_shuffle(x, groups): |
| batchsize, num_channels, height, width = x.size() |
| channels_per_group = num_channels // groups |
| x = x.view(batchsize, groups, channels_per_group, height, width) |
| x = torch.transpose(x, 1, 2).contiguous() |
| x = x.view(batchsize, -1, height, width) |
| return x.contiguous(memory_format=torch.channels_last) |
| |
| for simdlen in (None, 256, 1): |
| with config.patch({"cpp.simdlen": simdlen}): |
| torch._dynamo.reset() |
| metrics.reset() |
| x = torch.randn(64, 58, 28, 28) |
| opt_fn = torch._dynamo.optimize("inductor")(channel_shuffle) |
| self.assertTrue(same(channel_shuffle(x, 2), opt_fn(x, 2))) |
| if simdlen != 1: |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| @slowTest |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| def test_transpose_with_norm(self): |
| """a sub-module from TIMM gmlp_s16_224""" |
| |
| class Model(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear( |
| in_features=256, out_features=1536, bias=True |
| ) |
| self.act = torch.nn.GELU() |
| self.norm = torch.nn.LayerNorm(768) |
| self.proj = torch.nn.Linear(196, 196) |
| self.fc = torch.nn.Linear(in_features=768, out_features=256, bias=True) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.act(x) |
| u, v = x.chunk(2, dim=-1) |
| v = self.norm(v) |
| v = self.proj(v.transpose(-1, -2)) |
| y = u * v.transpose(-1, -2) |
| return self.fc(y) |
| |
| x = torch.randn(128, 196, 256) |
| for simdlen in (None, 256, 1): |
| with config.patch({"cpp.simdlen": simdlen}): |
| for eval_mode in [True, False]: |
| torch._dynamo.reset() |
| metrics.reset() |
| m = Model().eval() if eval_mode else Model() |
| opt_fn = torch._dynamo.optimize("inductor")(m) |
| same(m(x), opt_fn(x)) |
| if simdlen != 1: |
| assert metrics.generated_cpp_vec_kernel_count == 6 |
| |
| @unittest.skipIf( |
| not codecache.valid_vec_isa_list(), "Does not support vectorization" |
| ) |
| def test_transpose_copy(self): |
| def fn(a): |
| return a.t().contiguous() |
| |
| for simdlen in (None, 256, 1): |
| with config.patch({"cpp.simdlen": simdlen}): |
| for dtype in (torch.float, torch.bfloat16): |
| for shape in ( |
| (7, 7), |
| (8, 8), |
| (9, 9), |
| (16, 16), |
| (17, 17), |
| (32, 32), |
| (33, 33), |
| ): |
| torch._dynamo.reset() |
| metrics.reset() |
| x = torch.randn(shape, dtype=dtype) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| self.assertTrue(same(fn(x), opt_fn(x))) |
| if simdlen != 1: |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| def test_transpose_non_contiguous(self): |
| def fn(a): |
            # From part of timm HaloAttn
            # (https://github.com/rwightman/pytorch-image-models/blob/main/timm/layers/halo_attn.py#L97).
            # Covers the accuracy issue reported in https://github.com/pytorch/pytorch/issues/94269.
| as_strided = torch.ops.aten.as_strided.default( |
| a, [1, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680] |
| ) |
| as_strided_1 = torch.ops.aten.as_strided.default( |
| as_strided, |
| [1, 384, 2, 2, 12, 12], |
| [153600, 1, 61440, 3072, 7680, 384], |
| ) |
| clone_1 = torch.ops.aten.clone.default( |
| as_strided_1, memory_format=torch.contiguous_format |
| ) |
| _unsafe_view_1 = torch.ops.aten._unsafe_view.default( |
| clone_1, [8, 48, 4, 144] |
| ) |
| permute_2 = torch.ops.aten.permute.default(_unsafe_view_1, [0, 2, 3, 1]) |
| split_with_sizes = torch.ops.aten.split_with_sizes.default( |
| permute_2, [16, 32], -1 |
| ) |
| getitem = split_with_sizes[0] |
| getitem_1 = split_with_sizes[1] |
| permute_3 = torch.ops.aten.permute.default(getitem, [0, 1, 3, 2]) |
| expand_1 = torch.ops.aten.expand.default(permute_3, [8, 4, 16, 144]) |
| clone_3 = torch.ops.aten.clone.default( |
| expand_1, memory_format=torch.contiguous_format |
| ) |
| return clone_3 |
| |
| metrics.reset() |
| x = torch.randn(1, 384, 20, 20).to(memory_format=torch.channels_last) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| self.assertTrue(same(fn(x), opt_fn(x))) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_invalid_index_of_empty_tensor(self): |
| def fn(a): |
| b = a[[0]] |
| return b |
| |
| a = torch.tensor([]) |
| with self.assertRaises(RuntimeError): |
| torch.compile(fn)(a) |
| |
| def test_ir_node_str(self): |
| @torch.compile |
| def fn(x: torch.Tensor) -> torch.Tensor: |
| return x.sin(), torch.nn.Softmax(dim=1)(x.cos()) |
| |
| def run_node_alt(*args, **kwargs): |
| rv = run_node(*args, **kwargs) |
| strings.append(str(rv)) |
| return rv |
| |
| strings = [] |
| run_node = GraphLowering.run_node |
| with patch.object(GraphLowering, "run_node", run_node_alt): |
| fn(torch.randn([8, 128])) |
| self.assertGreater(len(strings), 3) |
| |
| def test_vertical_sum_cpu_only(self): |
| def fn1(a): |
| return a.sum(dim=0) |
| |
| def fn2(a): |
| return a.sum(dim=1) |
| |
| metrics.reset() |
| x = torch.randn(100, 100) |
| opt_fn1 = torch._dynamo.optimize("inductor")(fn1) |
| self.assertTrue(same(fn1(x), opt_fn1(x))) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| metrics.reset() |
| x = torch.randn(100, 100, 100) |
| opt_fn2 = torch._dynamo.optimize("inductor")(fn2) |
| self.assertTrue(same(fn2(x), opt_fn2(x))) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| def test_transpose_vertical_sum_cpu_only(self): |
| def fn(a, b): |
| c = a * b |
| return c.sum(dim=1) |
| |
| metrics.reset() |
| x = torch.randn(100, 50, 50) |
| y = torch.randn(100, 50, 50).transpose(1, 2) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| self.assertTrue(same(fn(x, y), opt_fn(x, y))) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| def test_transpose_sum2d_cpu_only(self): |
| def fn(a, b): |
| c = a * b |
| return c.sum() |
| |
| metrics.reset() |
| x = torch.randn(50, 50) |
| y = torch.randn(50, 50).transpose(0, 1) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| self.assertTrue(same(fn(x, y), opt_fn(x, y))) |
| assert metrics.generated_cpp_vec_kernel_count == 2 |
| |
| def test_transpose_sum_outer(self): |
| # https://github.com/pytorch/pytorch/issues/98573 |
| def fn(a): |
| return a.transpose(2, 3).sum(dim=1).contiguous() |
| |
| metrics.reset() |
| x = torch.randn(10, 50, 50, 50) |
| opt_fn = torch._dynamo.optimize("inductor")(fn) |
| self.assertTrue(same(fn(x), opt_fn(x))) |
| assert metrics.generated_cpp_vec_kernel_count == 1 |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| from torch.testing._internal.inductor_utils import HAS_CPU |
| |
| if HAS_CPU and not IS_MACOS: |
| run_tests(needs="filelock") |