|  | # Owner(s): ["module: inductor"] | 
|  |  | 
|  | import ctypes | 
|  | import unittest | 
|  |  | 
|  | import torch | 
|  |  | 
|  | from torch._inductor import config | 
|  | from torch._inductor.codecache import AsyncCompile, CUDACodeCache | 
|  | from torch._inductor.codegen.cuda.cuda_env import nvcc_exist | 
|  | from torch._inductor.exc import CUDACompileError | 
|  | from torch._inductor.test_case import TestCase as InductorTestCase | 
|  |  | 
# CUDA C++ source compiled by every test in this module. It defines:
#   * `saxpy_device`, a kernel computing y[i] = a * x[i] + y[i] for each i < n;
#   * `saxpy`, an `extern "C"` host function exported with default visibility
#     (so it can be looked up by name in the built .so). It launches the kernel
#     with ceil(n / 256) blocks of 256 threads and always returns 0.
# NOTE(review): `saxpy` neither checks the launch status nor synchronizes, so
# callers are responsible for ordering any reads of `y` after the launch.
_SOURCE_CODE = r"""

#include <stdio.h>

__global__
void saxpy_device(int n, float a, float *x, float *y)
{
int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i < n) y[i] = a*x[i] + y[i];
}

extern "C" {

__attribute__((__visibility__("default")))
int saxpy(int n, float a, float *x, float *y) {
// Perform SAXPY
saxpy_device<<<(n+255)/256, 256>>>(n, a, x, y);
return 0;
}

}
"""
|  |  | 
|  |  | 
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup")
class TestCUDACodeCache(InductorTestCase):
    """Tests for compiling and loading CUDA sources through Inductor's code cache.

    Each test compiles ``_SOURCE_CODE`` (or a deliberately broken variant of it).
    The tests that also execute the compiled ``saxpy`` need a CUDA device and are
    skipped when none is available.
    """

    def _check_saxpy(self, lib, n, a):
        """Call the exported ``saxpy`` in ``lib`` on fresh tensors and verify it.

        Args:
            lib: the loaded shared library obtained from the code cache; its
                ``saxpy`` attribute is called with ctypes-wrapped arguments.
            n: number of float32 elements in the ``x`` and ``y`` vectors.
            a: scalar multiplier.

        ``saxpy`` updates ``y`` in place to ``a * x + y`` and returns 0.
        """
        x = torch.rand(n, dtype=torch.float32, device="cuda")
        y = torch.rand(n, dtype=torch.float32, device="cuda")
        # Compute the reference out of place *before* the call, because
        # `saxpy` overwrites `y`.
        expected_y = a * x + y
        res = lib.saxpy(
            ctypes.c_int(n),
            ctypes.c_float(a),
            # The kernel takes raw float pointers, so pass the device addresses.
            ctypes.c_void_p(x.data_ptr()),
            ctypes.c_void_p(y.data_ptr()),
        )
        self.assertEqual(res, 0)
        # `saxpy` launches the kernel and returns without synchronizing; wait
        # for the launch to finish before reading `y`.
        torch.cuda.synchronize()
        torch.testing.assert_close(y, expected_y)

    @unittest.skipIf(not torch.cuda.is_available(), "requires a CUDA device")
    def test_cuda_load(self):
        # Test both .o and .so compilation.
        _object_file_path, object_hash_key, source_code_path0 = CUDACodeCache.compile(
            _SOURCE_CODE, "o"
        )
        dll_wrapper, so_hash_key, source_code_path1 = CUDACodeCache.load(
            _SOURCE_CODE, "so"
        )
        # The .o and .so builds of the same source must get distinct hash keys
        # and source paths.
        self.assertNotEqual(source_code_path0, source_code_path1)
        self.assertNotEqual(object_hash_key, so_hash_key)

        # Test load and call functions in .so.
        self._check_saxpy(dll_wrapper, n=10, a=5.0)

    def test_compilation_error(self):
        # Renaming only the kernel *definition* leaves the launch inside `saxpy`
        # referring to an undeclared `saxpy_device`, so compilation must fail.
        error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1)
        with self.assertRaises(CUDACompileError):
            CUDACodeCache.compile(error_source_code, "o")

    @unittest.skipIf(not torch.cuda.is_available(), "requires a CUDA device")
    def test_async_compile(self):
        async_compile = AsyncCompile()
        compiled_res = async_compile.cuda(_SOURCE_CODE, "so")
        async_compile.wait(globals())

        # Test load and call functions in .so; `compiled_res` is a local
        # future, so resolve it explicitly.
        self._check_saxpy(compiled_res.result(), n=5, a=2.0)
|  |  | 
|  |  | 
if __name__ == "__main__" and nvcc_exist():
    # The tests build CUDA sources with nvcc; without it, exit without running
    # anything.
    from torch._inductor.test_case import run_tests

    run_tests("cuda")