blob: 3d6cae8830893a9b6d412d5520f9a6311b0926c6 [file] [log] [blame]
Jon Janzen03847802022-06-03 14:19:26 -07001load("@fbcode_macros//build_defs:native_rules.bzl", "buck_genrule")
2load(
3 "//caffe2/caffe2:defs_hip.bzl",
4 "get_caffe2_hip_headers",
5 "get_caffe2_hip_srcs",
6)
7load(":ufunc_defs.bzl", "aten_ufunc_names")
8
# Header globs for the core ATen CUDA runtime: aten/src/ATen/cuda, its
# detail/ helpers and the NVRTC stub. Plain .h headers come first, then .cuh.
ATEN_CUDA_H_PATTERN = [
    "aten/src/ATen/cuda/*.h",
    "aten/src/ATen/cuda/detail/*.h",
    "aten/src/ATen/cuda/nvrtc_stub/*.h",
] + [
    "aten/src/ATen/cuda/*.cuh",
    "aten/src/ATen/cuda/detail/*.cuh",
]

# C++ (.cpp) source globs for the same three directories.
ATEN_CUDA_CPP_PATTERN = [
    "aten/src/ATen/cuda/*.cpp",
    "aten/src/ATen/cuda/detail/*.cpp",
    "aten/src/ATen/cuda/nvrtc_stub/*.cpp",
]

# CUDA (.cu) source globs; nvrtc_stub/ is not globbed for .cu files.
ATEN_CUDA_CU_PATTERN = [
    "aten/src/ATen/cuda/*.cu",
    "aten/src/ATen/cuda/detail/*.cu",
]
27
# --- cuDNN (aten/src/ATen/cudnn and aten/src/ATen/native/cudnn) ---

# cuDNN headers, both .h and .cuh.
ATEN_CUDNN_H_PATTERN = [
    "aten/src/ATen/cudnn/*.h",
    "aten/src/ATen/cudnn/*.cuh",
]

# cuDNN C++ sources.
ATEN_CUDNN_CPP_PATTERN = [
    "aten/src/ATen/cudnn/*.cpp",
]

# Native-operator cuDNN C++ sources.
ATEN_NATIVE_CUDNN_CPP_PATTERN = [
    "aten/src/ATen/native/cudnn/*.cpp",
]

# --- MIOpen (aten/src/ATen/miopen and aten/src/ATen/native/miopen) ---

# MIOpen headers, both .h and .cuh.
ATEN_MIOPEN_H_PATTERN = [
    "aten/src/ATen/miopen/*.h",
    "aten/src/ATen/miopen/*.cuh",
]

# MIOpen C++ sources.
ATEN_MIOPEN_CPP_PATTERN = [
    "aten/src/ATen/miopen/*.cpp",
]

# Native-operator MIOpen C++ sources.
ATEN_NATIVE_MIOPEN_CPP_PATTERN = [
    "aten/src/ATen/native/miopen/*.cpp",
]
45
# NOTE(review): nested/cuda is globbed for .cu and .cpp but not for headers,
# and quantized/cuda for .cu and .h but not for .cpp. Confirm this is
# intentional.

# CUDA (.cu) sources for native operators. transformers/ is globbed
# recursively (**); the other directories are globbed one level deep.
ATEN_NATIVE_CUDA_CU_PATTERN = [
    "aten/src/ATen/native/cuda/*.cu",
    "aten/src/ATen/native/nested/cuda/*.cu",
    "aten/src/ATen/native/quantized/cuda/*.cu",
    "aten/src/ATen/native/sparse/cuda/*.cu",
    "aten/src/ATen/native/transformers/**/*.cu",
]

# C++ (.cpp) sources for native CUDA operators, including cuda/linalg.
ATEN_NATIVE_CUDA_CPP_PATTERN = [
    "aten/src/ATen/native/cuda/*.cpp",
    "aten/src/ATen/native/cuda/linalg/*.cpp",
    "aten/src/ATen/native/nested/cuda/*.cpp",
    "aten/src/ATen/native/sparse/cuda/*.cpp",
    "aten/src/ATen/native/transformers/cuda/*.cpp",
]

# Headers for native CUDA operators, plus native cuDNN headers.
# native/cudnn and native/cuda are globbed recursively, and so are the
# transformers .cuh headers.
ATEN_NATIVE_CUDA_H_PATTERN = [
    "aten/src/ATen/native/cudnn/**/*.h",
    "aten/src/ATen/native/cuda/**/*.h",
    "aten/src/ATen/native/cuda/**/*.cuh",
    "aten/src/ATen/native/sparse/cuda/*.h",
    "aten/src/ATen/native/sparse/cuda/*.cuh",
    "aten/src/ATen/native/quantized/cuda/*.h",
    "aten/src/ATen/native/transformers/cuda/*.h",
    "aten/src/ATen/native/transformers/**/*.cuh",
]
72
# T66678203: Clang CUDA rollout.
# .cu files in this list are excluded from get_aten_cuda_srcs() and returned
# separately by get_aten_cuda_clang_srcs(). The HIP build still hipifies them
# together with the other .cu files.
ATEN_CUDA_CLANG_CU_PATTERN = ["aten/src/ATen/native/cuda/DistributionBernoulli.cu"]
77
### CUDA files
def get_aten_cuda_headers():
    """Return the CUDA headers: ATen CUDA runtime, native CUDA ops, then cuDNN.

    Returns:
        The concatenated results of native.glob for ATEN_CUDA_H_PATTERN,
        ATEN_NATIVE_CUDA_H_PATTERN and ATEN_CUDNN_H_PATTERN, in that order.
    """
    headers = []
    for pattern_list in [ATEN_CUDA_H_PATTERN, ATEN_NATIVE_CUDA_H_PATTERN, ATEN_CUDNN_H_PATTERN]:
        headers += native.glob(pattern_list)
    return headers
84
def get_aten_cuda_srcs():
    """Return the .cu sources for the CUDA build, minus the Clang-rollout files.

    Files matching ATEN_CUDA_CLANG_CU_PATTERN are excluded from the native
    CUDA glob; get_aten_cuda_clang_srcs() returns them instead.

    Returns:
        ATen CUDA runtime .cu files followed by native CUDA operator .cu files.
    """
    runtime_cu = native.glob(ATEN_CUDA_CU_PATTERN)
    native_ops_cu = native.glob(ATEN_NATIVE_CUDA_CU_PATTERN, exclude = ATEN_CUDA_CLANG_CU_PATTERN)
    return runtime_cu + native_ops_cu
92
def get_aten_cuda_clang_srcs():
    """Return the .cu sources in the Clang CUDA rollout (ATEN_CUDA_CLANG_CU_PATTERN)."""
    clang_cu_srcs = native.glob(ATEN_CUDA_CLANG_CU_PATTERN)
    return clang_cu_srcs
95
# CPU+CUDA files
def get_all_cuda_srcs():
    """Return every source for the CUDA build: C++ (.cpp) files plus .cu files.

    The result includes the C++ sources, not just the .cu files. The .cpp part
    comes from the native cuDNN, cuDNN, native MIOpen, ATen CUDA and native
    CUDA globs, in that order. The .cu part comes from get_aten_cuda_srcs(),
    so the Clang-rollout files are not included.

    Returns:
        A list of source paths: all .cpp globs followed by the .cu sources.
    """
    cpp_pattern_lists = [
        ATEN_NATIVE_CUDNN_CPP_PATTERN,
        ATEN_CUDNN_CPP_PATTERN,
        ATEN_NATIVE_MIOPEN_CPP_PATTERN,
        ATEN_CUDA_CPP_PATTERN,
        ATEN_NATIVE_CUDA_CPP_PATTERN,
    ]
    cpp_srcs = []
    for pattern_list in cpp_pattern_lists:
        cpp_srcs += native.glob(pattern_list)
    return cpp_srcs + get_aten_cuda_srcs()
106
### HIP files
# Files that must be hipified
def get_aten_hip_srcs():
    """Return (before, after) paths for the CUDA sources that must be hipified.

    Unlike get_aten_cuda_srcs(), the Clang-rollout .cu files are NOT excluded,
    because HIP does not use Clang for them. MIOpen sources are not part of
    this list; get_aten_hip_native_srcs() returns them.

    Returns:
        A 2-tuple: element 0 is passed through unchanged from
        get_caffe2_hip_srcs. Element 1 is that helper's hipified path list with
        every "aten/src/" substring removed.
    """
    to_hipify = (
        # CU -> HIP files.
        native.glob(ATEN_CUDA_CU_PATTERN) +
        native.glob(ATEN_NATIVE_CUDA_CU_PATTERN) +
        # C++ files.
        native.glob(ATEN_NATIVE_CUDNN_CPP_PATTERN) +
        native.glob(ATEN_CUDNN_CPP_PATTERN) +
        native.glob(ATEN_CUDA_CPP_PATTERN) +
        native.glob(ATEN_NATIVE_CUDA_CPP_PATTERN)
    )
    hip_result = get_caffe2_hip_srcs(include_patterns = [], include_files = to_hipify, project_dir = "")
    hipified = [path.replace("aten/src/", "") for path in hip_result[1]]
    return (hip_result[0], hipified)
126
def get_aten_hip_headers():
    """Return (before, after) paths for the CUDA headers that must be hipified.

    Only the ATen CUDA headers and the native CUDA headers are hipified. cuDNN
    headers (ATEN_CUDNN_H_PATTERN) are not included, unlike in
    get_aten_cuda_headers().

    Returns:
        A 2-tuple: element 0 is passed through unchanged from
        get_caffe2_hip_headers. Element 1 is that helper's hipified path list
        with every "aten/src/" substring removed.
    """

    # cuDNN headers are deliberately left out. This used to be a dead
    # `ATEN_CUDNN_H = []` next to a commented-out native.glob(ATEN_CUDNN_H_PATTERN).
    srcs = native.glob(ATEN_CUDA_H_PATTERN) + native.glob(ATEN_NATIVE_CUDA_H_PATTERN)

    # Get hipified file names (before, after).
    ret = get_caffe2_hip_headers(include_patterns = [], include_files = srcs, project_dir = "")
    return (ret[0], [f.replace("aten/src/", "") for f in ret[1]])
136
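# Native HIP-aware files: these are already written for HIP, so they are
# returned as plain globs instead of being passed through the hipify helpers.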
def get_aten_hip_native_srcs():
    """Return the HIP-native C++ sources without hipifying them.

    Returns:
        aten/src/ATen/hip/impl .cpp files, then MIOpen .cpp files, then native
        MIOpen .cpp files.
    """
    hip_impl_cpp = native.glob(["aten/src/ATen/hip/impl/*.cpp"])
    miopen_cpp = native.glob(ATEN_MIOPEN_CPP_PATTERN) + native.glob(ATEN_NATIVE_MIOPEN_CPP_PATTERN)
    return hip_impl_cpp + miopen_cpp
143
def get_aten_hip_native_headers():
    """Return the HIP-native headers without hipifying them.

    Returns:
        aten/src/ATen/hip/impl .h files followed by the MIOpen headers.
    """
    return native.glob(["aten/src/ATen/hip/impl/*.h"]) + native.glob(ATEN_MIOPEN_H_PATTERN)
148
def get_aten_hip_ufunc_generated_cuda_sources(gencode_pattern = "{}"):
    """Declare one rename genrule per ufunc and return the genrule labels.

    For every name in aten_ufunc_names, a buck_genrule copies the generated
    UfuncCUDA_<name>.cu to UfuncCUDA_<name>.hip. The contents do not need to
    be hipified at this point. The ".hip" extension is still required because,
    unlike OSS, the compiler is selected based on the file extension.

    Args:
        gencode_pattern: Format string that produces each genrule's single src
            when formatted with the .cu file name. The default "{}" uses the
            bare file name.

    Returns:
        The list of ":aten_ufunc_hip_renamed_<name>" labels, in
        aten_ufunc_names order.
    """
    labels = []
    for ufunc in aten_ufunc_names:
        rule_name = "aten_ufunc_hip_renamed_{}".format(ufunc)
        cu_file = "UfuncCUDA_{}.cu".format(ufunc)
        buck_genrule(
            name = rule_name,
            srcs = [gencode_pattern.format(cu_file)],
            # Plain copy: only the file extension changes.
            bash = 'cp "$SRCDIR/{}" "$OUT"'.format(cu_file),
            out = "UfuncCUDA_{}.hip".format(ufunc),
            default_outs = [],
        )
        labels.append(":" + rule_name)
    return labels