Jon Janzen | 0384780 | 2022-06-03 14:19:26 -0700 | [diff] [blame] | 1 | load("@fbcode_macros//build_defs:native_rules.bzl", "buck_genrule") |
| 2 | load( |
| 3 | "//caffe2/caffe2:defs_hip.bzl", |
| 4 | "get_caffe2_hip_headers", |
| 5 | "get_caffe2_hip_srcs", |
| 6 | ) |
| 7 | load(":ufunc_defs.bzl", "aten_ufunc_names") |
| 8 | |
# Glob patterns for the files that live directly under aten/src/ATen/cuda.
# These lists are only patterns; they are expanded with native.glob() by the
# getter functions further down in this file.

# Headers (.h and .cuh).
ATEN_CUDA_H_PATTERN = [
    "aten/src/ATen/cuda/*.h",
    "aten/src/ATen/cuda/detail/*.h",
    "aten/src/ATen/cuda/nvrtc_stub/*.h",
    "aten/src/ATen/cuda/*.cuh",
    "aten/src/ATen/cuda/detail/*.cuh",
]

# C++ sources (.cpp).
ATEN_CUDA_CPP_PATTERN = [
    "aten/src/ATen/cuda/*.cpp",
    "aten/src/ATen/cuda/detail/*.cpp",
    "aten/src/ATen/cuda/nvrtc_stub/*.cpp",
]

# CUDA sources (.cu).
ATEN_CUDA_CU_PATTERN = [
    "aten/src/ATen/cuda/*.cu",
    "aten/src/ATen/cuda/detail/*.cu",
]
| 27 | |
# Glob patterns for the cuDNN and MIOpen wrappers.

# aten/src/ATen/cudnn: headers and C++ sources.
ATEN_CUDNN_H_PATTERN = [
    "aten/src/ATen/cudnn/*.h",
    "aten/src/ATen/cudnn/*.cuh",
]

ATEN_CUDNN_CPP_PATTERN = [
    "aten/src/ATen/cudnn/*.cpp",
]

# aten/src/ATen/miopen: headers and C++ sources.
ATEN_MIOPEN_H_PATTERN = [
    "aten/src/ATen/miopen/*.h",
    "aten/src/ATen/miopen/*.cuh",
]

ATEN_MIOPEN_CPP_PATTERN = [
    "aten/src/ATen/miopen/*.cpp",
]

# C++ sources under aten/src/ATen/native/{cudnn,miopen}.
ATEN_NATIVE_CUDNN_CPP_PATTERN = [
    "aten/src/ATen/native/cudnn/*.cpp",
]

ATEN_NATIVE_MIOPEN_CPP_PATTERN = [
    "aten/src/ATen/native/miopen/*.cpp",
]
| 45 | |
# Glob patterns for the CUDA pieces under aten/src/ATen/native.

# CUDA sources (.cu). The transformers entry is recursive ("**").
ATEN_NATIVE_CUDA_CU_PATTERN = [
    "aten/src/ATen/native/cuda/*.cu",
    "aten/src/ATen/native/nested/cuda/*.cu",
    "aten/src/ATen/native/quantized/cuda/*.cu",
    "aten/src/ATen/native/sparse/cuda/*.cu",
    "aten/src/ATen/native/transformers/**/*.cu",
]

# C++ sources (.cpp).
ATEN_NATIVE_CUDA_CPP_PATTERN = [
    "aten/src/ATen/native/cuda/*.cpp",
    "aten/src/ATen/native/cuda/linalg/*.cpp",
    "aten/src/ATen/native/nested/cuda/*.cpp",
    "aten/src/ATen/native/sparse/cuda/*.cpp",
    "aten/src/ATen/native/transformers/cuda/*.cpp",
]

# Headers (.h and .cuh); note that the native cuDNN headers are listed here too.
ATEN_NATIVE_CUDA_H_PATTERN = [
    "aten/src/ATen/native/cudnn/**/*.h",
    "aten/src/ATen/native/cuda/**/*.h",
    "aten/src/ATen/native/cuda/**/*.cuh",
    "aten/src/ATen/native/sparse/cuda/*.h",
    "aten/src/ATen/native/sparse/cuda/*.cuh",
    "aten/src/ATen/native/quantized/cuda/*.h",
    "aten/src/ATen/native/transformers/cuda/*.h",
    "aten/src/ATen/native/transformers/**/*.cuh",
]
| 72 | |
# T66678203: Clang CUDA rollout
# Files listed here are excluded from get_aten_cuda_srcs() and returned by
# get_aten_cuda_clang_srcs() instead. The HIP source list does not exclude them.
ATEN_CUDA_CLANG_CU_PATTERN = [
    "aten/src/ATen/native/cuda/DistributionBernoulli.cu",
]
| 77 | |
| 78 | ### Cuda Files |
def get_aten_cuda_headers():
    """Collects the header files used by the ATen CUDA build.

    Returns:
        A list of globbed paths: matches of ATEN_CUDA_H_PATTERN, then
        ATEN_NATIVE_CUDA_H_PATTERN, then ATEN_CUDNN_H_PATTERN.
    """
    headers = []

    # Each pattern list is globbed on its own so the groups stay in this order.
    for patterns in [ATEN_CUDA_H_PATTERN, ATEN_NATIVE_CUDA_H_PATTERN, ATEN_CUDNN_H_PATTERN]:
        headers.extend(native.glob(patterns))
    return headers
| 84 | |
def get_aten_cuda_srcs():
    """Collects the ATen .cu sources, leaving out the clang-rollout files.

    Returns:
        A list of globbed paths: matches of ATEN_CUDA_CU_PATTERN followed by
        matches of ATEN_NATIVE_CUDA_CU_PATTERN, with anything matching
        ATEN_CUDA_CLANG_CU_PATTERN excluded from the latter.
    """
    core_cu = native.glob(ATEN_CUDA_CU_PATTERN)

    # Files in ATEN_CUDA_CLANG_CU_PATTERN are returned separately by
    # get_aten_cuda_clang_srcs(), so keep them out of this list.
    native_cu = native.glob(ATEN_NATIVE_CUDA_CU_PATTERN, exclude = ATEN_CUDA_CLANG_CU_PATTERN)
    return core_cu + native_cu
| 92 | |
def get_aten_cuda_clang_srcs():
    """Collects the .cu sources listed in ATEN_CUDA_CLANG_CU_PATTERN.

    Returns:
        A list of globbed paths; these are the same files that
        get_aten_cuda_srcs() excludes.
    """
    clang_cu = native.glob(ATEN_CUDA_CLANG_CU_PATTERN)
    return clang_cu
| 95 | |
| 96 | # CPU+CUDA file |
| 97 | # Note that these sources and headers include the CPU lists too |
def get_all_cuda_srcs():
    """Collects the C++ and CUDA sources for the CUDA build.

    Returns:
        A list of globbed paths: the .cpp groups in the order listed below,
        followed by everything returned by get_aten_cuda_srcs().
    """
    cpp_pattern_groups = [
        ATEN_NATIVE_CUDNN_CPP_PATTERN,
        ATEN_CUDNN_CPP_PATTERN,
        ATEN_NATIVE_MIOPEN_CPP_PATTERN,
        ATEN_CUDA_CPP_PATTERN,
        ATEN_NATIVE_CUDA_CPP_PATTERN,
    ]

    all_srcs = []
    for patterns in cpp_pattern_groups:
        all_srcs.extend(native.glob(patterns))

    # The .cu files (minus the clang-rollout ones) come last.
    all_srcs.extend(get_aten_cuda_srcs())
    return all_srcs
| 106 | |
| 107 | ### HIP files |
| 108 | # Files that must be hipified |
def get_aten_hip_srcs():
    """Collects the CUDA sources that are fed to the hipify helper.

    Returns:
        A 2-tuple built from the result of get_caffe2_hip_srcs(): element 0 is
        passed through unchanged, and element 1 has every "aten/src/"
        occurrence removed from each entry. Going by the helper's usage here,
        the pair is (file names before hipify, file names after hipify).
    """

    # CU -> HIP files. HIP does not use clang for ATEN_CUDA_CLANG_CU_PATTERN,
    # so nothing is excluded from the native glob here.
    cu_files = native.glob(ATEN_CUDA_CU_PATTERN) + native.glob(ATEN_NATIVE_CUDA_CU_PATTERN)

    # CPU files.
    cpp_files = (
        native.glob(ATEN_NATIVE_CUDNN_CPP_PATTERN) +
        native.glob(ATEN_CUDNN_CPP_PATTERN) +
        native.glob(ATEN_CUDA_CPP_PATTERN) +
        native.glob(ATEN_NATIVE_CUDA_CPP_PATTERN)
    )

    hipified = get_caffe2_hip_srcs(
        include_patterns = [],
        include_files = cu_files + cpp_files,
        project_dir = "",
    )
    before = hipified[0]
    after = [path.replace("aten/src/", "") for path in hipified[1]]
    return (before, after)
| 126 | |
def get_aten_hip_headers():
    """Collects the CUDA headers that are fed to the hipify helper.

    Returns:
        A 2-tuple built from the result of get_caffe2_hip_headers(): element 0
        is passed through unchanged, and element 1 has every "aten/src/"
        occurrence removed from each entry. Going by the helper's usage here,
        the pair is (file names before hipify, file names after hipify).
    """

    # cuDNN headers (ATEN_CUDNN_H_PATTERN) are not part of this list; the glob
    # for them was already disabled, so only the two groups below are used.
    srcs = native.glob(ATEN_CUDA_H_PATTERN) + native.glob(ATEN_NATIVE_CUDA_H_PATTERN)

    # Get hipified file names (before, after). Pass `srcs` directly instead of
    # rebuilding the same concatenation a second time.
    ret = get_caffe2_hip_headers(include_patterns = [], include_files = srcs, project_dir = "")
    return (ret[0], [f.replace("aten/src/", "") for f in ret[1]])
| 136 | |
| 137 | # Native HIP-aware files |
def get_aten_hip_native_srcs():
    """Collects the C++ sources that are already HIP-aware.

    Returns:
        A list of globbed paths: aten/src/ATen/hip/impl/*.cpp, then matches of
        ATEN_MIOPEN_CPP_PATTERN, then matches of ATEN_NATIVE_MIOPEN_CPP_PATTERN.
    """
    native_srcs = native.glob(["aten/src/ATen/hip/impl/*.cpp"])
    native_srcs.extend(native.glob(ATEN_MIOPEN_CPP_PATTERN))
    native_srcs.extend(native.glob(ATEN_NATIVE_MIOPEN_CPP_PATTERN))
    return native_srcs
| 143 | |
def get_aten_hip_native_headers():
    """Collects the headers that are already HIP-aware.

    Returns:
        A list of globbed paths: aten/src/ATen/hip/impl/*.h followed by
        matches of ATEN_MIOPEN_H_PATTERN.
    """
    return native.glob(["aten/src/ATen/hip/impl/*.h"]) + native.glob(ATEN_MIOPEN_H_PATTERN)
| 148 | |
def get_aten_hip_ufunc_generated_cuda_sources(gencode_pattern = "{}"):
    """Declares genrules that copy each generated ufunc ".cu" file to ".hip".

    The contents of these CUDA files do not need to be hipified at this point,
    but they must be renamed from ".cu" to ".hip" because, unlike OSS, a
    compiler is selected based on a file extension.

    Args:
        gencode_pattern: format string with a single "{}" placeholder. It is
            applied to each "UfuncCUDA_<name>.cu" file name to form the
            genrule's source entry. The default "{}" uses the bare file name.

    Returns:
        A list of ":aten_ufunc_hip_renamed_<name>" labels, one per entry of
        aten_ufunc_names, in the same order.
    """
    rule_labels = []
    for ufunc in aten_ufunc_names:
        rule_name = "aten_ufunc_hip_renamed_{}".format(ufunc)
        cu_file = "UfuncCUDA_{}.cu".format(ufunc)

        # One genrule per ufunc: plain copy of the .cu file to a .hip output.
        buck_genrule(
            name = rule_name,
            srcs = [gencode_pattern.format(cu_file)],
            bash = 'cp "$SRCDIR/{}" "$OUT"'.format(cu_file),
            out = "UfuncCUDA_{}.hip".format(ufunc),
            default_outs = [],
        )
        rule_labels.append(":aten_ufunc_hip_renamed_{}".format(ufunc))
    return rule_labels