|  | # Owner(s): ["module: cuda"] | 
|  |  | 
|  | import torch | 
|  | from torch.cuda.jiterator import _create_jit_fn as create_jit_fn | 
|  | from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn | 
|  | import sys | 
|  | from itertools import product | 
|  | from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA | 
|  | from torch.testing._internal.common_dtype import all_types_and_complex_and | 
|  | from torch.testing._internal.common_device_type import ( | 
|  | skipCUDAIfRocm, skipCUDAIf, instantiate_device_type_tests, dtypes, toleranceOverride, tol) | 
|  | from torch.testing._internal.common_cuda import _get_torch_cuda_version | 
|  |  | 
|  | if not TEST_CUDA: | 
|  | print('CUDA not available, skipping tests', file=sys.stderr) | 
|  | TestCase = object  # noqa: F811 | 
|  |  | 
|  |  | 
|  | code_string = "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }" | 
|  | jitted_fn = create_jit_fn(code_string, alpha=1, beta=1) | 
|  |  | 
def ref_fn(x, y, alpha=1, beta=1):
    """Eager reference for the shared fused kernel.

    Returns ``alpha * x + beta * y`` computed with ordinary Python/PyTorch
    arithmetic, so ``jitted_fn`` output can be checked against it.
    """
    scaled_x = alpha * x
    scaled_y = beta * y
    return scaled_x + scaled_y
|  |  | 
class TestPythonJiterator(TestCase):
    """Tests for the Python jiterator frontend (``torch.cuda.jiterator``).

    Each test JIT-compiles a small templated kernel from a C++ source string
    and compares its output against an eager PyTorch reference.
    """

    def _check_fused_kernel(self, device, dtype_pair, layouts):
        """Compare the shared ``jitted_fn`` with ``ref_fn`` on strided views.

        ``dtype_pair`` gives the dtype of each of the two inputs and
        ``layouts`` gives one ``(shape, stride)`` pair per input, forwarded to
        ``Tensor.as_strided``.
        """
        (a_dtype, b_dtype), (a_layout, b_layout) = dtype_pair, layouts
        # 9-element buffers are large enough for every (shape, stride) pair
        # used by the tests in this class.
        a = torch.rand(9, device=device).mul(10).type(a_dtype).as_strided(*a_layout)
        b = torch.rand(9, device=device).mul(10).type(b_dtype).as_strided(*b_layout)

        self.assertEqual(ref_fn(a, b), jitted_fn(a, b))

    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        """Fused kernel on contiguous inputs for every pair of input dtypes."""
        self._check_fused_kernel(device, dtypes, shape_strides)

    @skipCUDAIfRocm
    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    @skipCUDAIf(_get_torch_cuda_version() < (11, 6), "On cuda 11.3, nvrtcCompileProgram is taking too long to "
                "compile jiterator generated kernels for non-contiguous input that requires dynamic-casting.")
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        """Fused kernel on non-contiguous inputs for every pair of input dtypes."""
        self._check_fused_kernel(device, dtypes, shape_strides)

    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16 : tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        """Per-call overrides of the kernel's ``alpha``/``beta`` extra arguments."""
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        # None means "don't pass it", keeping the default of 1 set when the
        # shared kernel was created.
        extra_args = {
            name: value
            for name, value in (("alpha", alpha), ("beta", beta))
            if value is not None
        }

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        """A ``bool`` extra argument (default False) overridden per call."""
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        conditional_fn = create_jit_fn(code_string, is_train=False)

        def ref_conditional(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = ref_conditional(a, b, is_train=is_train)
        result = conditional_fn(a, b, is_train=is_train)
        self.assertEqual(expected, result)

    def test_multiple_functors(self, device):
        """A source string containing a helper functor plus the entry kernel.

        The kernel invoked here is ``main_fn`` (three inputs), which calls
        the helper ``fn``.
        """
        code_string = '''
        template <typename T> T fn(T x, T mask) { return x * mask; }
        template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }
        '''
        main_fn = create_jit_fn(code_string)

        def ref_main(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = ref_main(a, b, c)
        result = main_fn(a, b, c)
        self.assertEqual(expected, result)

    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, device, num_inputs):
        """A generated kernel summing ``num_inputs`` tensors elementwise."""
        inputs = [torch.rand(3, device=device).mul(10) for _ in range(num_inputs)]

        input_string = ",".join(f"T i{i}" for i in range(num_inputs))
        function_body = "+".join(f"i{i}" for i in range(num_inputs))
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        sum_fn = create_jit_fn(code_string)

        # Reference: elementwise sum over all inputs.
        expected = torch.sum(torch.stack(inputs), dim=0)
        result = sum_fn(*inputs)

        self.assertEqual(expected, result)

    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, device, num_outputs):
        """A generated kernel writing ``input + i`` into output ``i``."""
        x = torch.rand(3, device=device)

        output_string = ", ".join(f"T& out{i}" for i in range(num_outputs))
        function_body = "".join(f"out{i} = input + {i};\n" for i in range(num_outputs))
        # NB: return type must be void, otherwise ROCm silently fails
        code_string = f"template <typename T> void my_kernel(T input, {output_string}) {{ {function_body} }}"

        multi_output_fn = create_multi_output_jit_fn(code_string, num_outputs)

        # The reference mirrors the jitted function's calling convention: a
        # bare tensor for a single output, a tuple otherwise.
        expected = tuple(x + i for i in range(num_outputs))
        if num_outputs == 1:
            expected = expected[0]
        result = multi_output_fn(x)

        # Compare the outputs as a whole. Indexing with [i] would, for a
        # single output, compare only the first scalar element of the tensor.
        self.assertEqual(expected, result)

    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        """Malformed kernel signatures are rejected when the function is created.

        The cases are a space inside the function name, and the return type
        run together with the function name.
        """
        with self.assertRaises(Exception):
            create_jit_fn(code_string)
|  |  | 
|  |  | 
# Generate CUDA-only, device-specific variants of the generic
# TestPythonJiterator class into this module's namespace (globals()).
instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda")

# Allow running this file directly as a script.
if __name__ == '__main__':
    run_tests()