# Owner(s): ["oncall: pt2"]
import functools
import os
import sys
import textwrap
import unittest

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._inductor import config
from torch._inductor.codecache import HalideCodeCache
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import parallel_num_threads
from torch.testing._internal.common_utils import IS_CI, IS_MACOS, IS_WINDOWS
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils._triton import has_triton


if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_halide yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

try:
    import halide

    HAS_HALIDE = halide is not None
except ImportError:
    HAS_HALIDE = False


try:
    from . import test_torchinductor
except ImportError:
    import test_torchinductor


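# Patches the inductor config so codegen goes through the Halide backend on
# both CPU and CUDA; applied below as a class decorator to rerun the stock
# inductor test suites under Halide.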
make_halide = config.patch(
    {
        "halide.scan_kernels": True,
        "cpu_backend": "halide",
        "cuda_backend": "halide",
    }
)


@unittest.skipUnless(HAS_HALIDE, "requires halide")
class HalideTests(TestCase):
    def test_codecache(self):
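        # Compile a 1D elementwise-add kernel directly through HalideCodeCache,
        # letting the Mullapudi2016 autoscheduler choose the schedule.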
        fn = HalideCodeCache.generate_halide(
            HalideMeta(
                argtypes=[
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr1",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="out_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                ],
                target="host-no_runtime",
                scheduler="Mullapudi2016",
                scheduler_flags={
                    "parallelism": parallel_num_threads(),
                },
            ),
            textwrap.dedent(
                """
                import halide as hl

                @hl.generator(name="kernel")
                class Kernel:
                    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
                    in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
                    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

                    def generate(g):
                        in_ptr0 = g.in_ptr0
                        in_ptr1 = g.in_ptr1
                        out_ptr0 = g.out_ptr0
                        xindex = hl.Var('xindex')
                        x0 = xindex
                        tmp0 = hl.Func()
                        tmp0[xindex] = in_ptr0[x0]
                        tmp1 = hl.Func()
                        tmp1[xindex] = in_ptr1[x0]
                        tmp2 = hl.Func()
                        tmp2[xindex] = tmp0[xindex] + tmp1[xindex]
                        out_ptr0[x0] = tmp2[xindex]

                        assert g.using_autoscheduler()
                        in_ptr0.set_estimates([hl.Range(0, 1024)])
                        in_ptr1.set_estimates([hl.Range(0, 1024)])
                        out_ptr0.set_estimates([hl.Range(0, 1024)])

                __name__ == '__main__' and hl.main()
                """
            ),
        )
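        # The compiled wrapper adds the first two buffers into the third.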
        a = torch.randn(1024)
        b = torch.randn(1024)
        c = torch.randn(1024)
        fn(a, b, c)
        self.assertEqual(c, a + b)

    def test_manual_schedule(self):
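        # Same elementwise-add kernel, but scheduled by hand (scheduler=None)
        # instead of through an autoscheduler.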
        fn = HalideCodeCache.generate_halide(
            HalideMeta(
                argtypes=[
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr1",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="out_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                ],
                target="host-no_runtime",
                scheduler=None,
            ),
            textwrap.dedent(
                """
                import halide as hl

                @hl.generator(name="kernel")
                class Kernel:
                    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
                    in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
                    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

                    def generate(g):
                        in_ptr0 = g.in_ptr0
                        in_ptr1 = g.in_ptr1
                        out_ptr0 = g.out_ptr0
                        xindex = hl.Var('xindex')
                        x0 = xindex
                        tmp0 = hl.Func()
                        tmp0[xindex] = in_ptr0[x0]
                        tmp1 = hl.Func()
                        tmp1[xindex] = in_ptr1[x0]
                        tmp2 = hl.Func()
                        tmp2[xindex] = tmp0[xindex] + tmp1[xindex]
                        out_ptr0[x0] = tmp2[xindex]

                        assert not g.using_autoscheduler()
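                        # Manual schedule: split the 1024-wide loop into 32-wide
                        # chunks, run the chunks in parallel, and vectorize within
                        # each chunk; tmp2 is computed per chunk, tmp1 is inlined.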
                        i = hl.Var()
                        j = hl.Var()
                        out_ptr0.compute_root()
                        out_ptr0.split(xindex, i, j, 32)
                        out_ptr0.parallel(i)
                        out_ptr0.vectorize(j)
                        tmp2.compute_at(out_ptr0, i)
                        tmp2.store_at(out_ptr0, i)
                        tmp1.compute_inline()

                __name__ == '__main__' and hl.main()
                """
            ),
        )
        a = torch.randn(1024)
        b = torch.randn(1024)
        c = torch.randn(1024)
        fn(a, b, c)
        self.assertEqual(c, a + b)

    @unittest.skipUnless(has_triton(), "requires triton")
    def test_random_consistency(self):
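        # The Halide and Triton backends should produce bit-identical random
        # tensors when seeded identically.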
        seed = 1234
        shape = (3, 3)
        dtype = torch.float32

        for rand_fn in (
            functools.partial(torch.rand, shape, dtype=dtype, device="cuda"),
            functools.partial(torch.randn, shape, dtype=dtype, device="cuda"),
            functools.partial(
                torch.randint,
                -1000,
                1000,
                size=shape,
                dtype=torch.int64,
                device="cuda",
            ),
        ):

            @torch.compile(backend="inductor", options={"cuda_backend": "halide"})
            def get_rand_halide():
                return rand_fn()

            @torch.compile(backend="inductor", options={"cuda_backend": "triton"})
            def get_rand_triton():
                return rand_fn()

            torch.manual_seed(seed)
            halide_output = get_rand_halide()
            torch.manual_seed(seed)
            triton_output = get_rand_triton()

            self.assertEqual(halide_output, triton_output)


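# Re-instantiate the generic inductor test suites with the Halide backend
# patched in via make_halide.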
if test_torchinductor.HAS_CPU and HAS_HALIDE:
    SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
    CpuHalideTests = make_halide(test_torchinductor.CpuTests)

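# The GPU suites are opt-in: set TEST_HALIDE_GPU=1 to run them.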
if (
    test_torchinductor.HAS_GPU
    and HAS_HALIDE
    and os.environ.get("TEST_HALIDE_GPU") == "1"
):
    SweepInputsGPUHalideTest = make_halide(test_torchinductor.SweepInputsGPUTest)
    GPUHalideTests = make_halide(test_torchinductor.GPUTests)

if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS and HAS_HALIDE:
        run_tests(needs="filelock")