[TensorExpr] Add CUDA codegen. (#34227)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34227
This PR adds CUDA support to tensor expressions.
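At a high level, the new backend builds a LoopNest over the output tensors, binds loop axes to GPU block and thread indices, compiles the resulting statement to CUDA C with NVRTC, and launches the kernel through the CUDA driver API. Below is a minimal sketch of the C++ API, abridged from the added test in test/cpp/tensorexpr/test_cuda.cpp; the wrapper function and the pre-allocated device pointers are illustrative only, and CUDA error checking is omitted.

    #include "torch/csrc/jit/tensorexpr/buffer.h"
    #include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
    #include "torch/csrc/jit/tensorexpr/schedule.h"
    #include "torch/csrc/jit/tensorexpr/tensor.h"

    using namespace torch::jit::tensorexpr;
    using namespace torch::jit::tensorexpr::schedule;

    // Adds two device-resident float vectors of length N. The pointers are
    // assumed to be allocated with cudaMalloc by the caller.
    void vector_add_example(float* c_dev, float* a_dev, float* b_dev, int N) {
      KernelScope kernel_scope;
      Buffer a_buf("a", kFloat, {N});
      Buffer b_buf("b", kFloat, {N});
      Tensor* c = Compute(
          "c", {{N, "n"}},
          [&](const VarHandle& n) { return a_buf(n) + b_buf(n); });

      // Split the single loop and bind the pieces to blockIdx.x / threadIdx.x.
      LoopNest l({c});
      std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
      Stmt* n_outer;
      Stmt* n_inner;
      l.SplitWithMask(loops[0], 128, &n_outer, &n_inner);  // 128 = thread-block size
      l.SetGPUBlockIndex(n_outer, 0);
      l.SetGPUThreadIndex(n_inner, 0);

      // The constructor generates the CUDA C source and compiles it with NVRTC;
      // calling the codegen launches the kernel on the current CUDA stream.
      CudaCodeGen cuda_cg(l.root_stmt(), c, a_buf, b_buf);
      cuda_cg(c_dev, a_dev, b_dev);
    }

On the Python side, the fuser picks this backend for CUDA inputs; the new cuda_codegen_created / cuda_codegen_executed counters and the _jit_{get,set}_te_cuda_pointwise_* knobs are exposed to Python for testing and tuning.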
Differential Revision: D20251836
Test Plan: Imported from OSS
Pulled By: ZolotukhinM
fbshipit-source-id: ab36a55834cceff30c8371fef6cca1054a32f017
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index e7b8e7a..776ea8d 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -557,6 +557,7 @@
${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp
${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp
)
add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})
target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB})
@@ -574,6 +575,7 @@
${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp
${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp
)
if (USE_NCCL)
list(APPEND Caffe2_HIP_SRCS
diff --git a/test/cpp/tensorexpr/gtest.cpp b/test/cpp/tensorexpr/gtest.cpp
index 7c660ee..62e2998 100644
--- a/test/cpp/tensorexpr/gtest.cpp
+++ b/test/cpp/tensorexpr/gtest.cpp
@@ -12,5 +12,14 @@
TH_FORALL_TESTS(TENSOREXPR_GTEST)
#undef TENSOREXPR_GTEST
+#ifdef USE_CUDA
+#define TENSOREXPR_GTEST_CUDA(name) \
+ TEST(TensorExprTest, name##_CUDA) { \
+ test##name(); \
+ }
+TH_FORALL_TESTS_CUDA(TENSOREXPR_GTEST_CUDA)
+#undef TENSOREXPR_GTEST_CUDA
+#endif
+
} // namespace jit
} // namespace torch
diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp
new file mode 100644
index 0000000..7854d02
--- /dev/null
+++ b/test/cpp/tensorexpr/test_cuda.cpp
@@ -0,0 +1,333 @@
+#ifdef USE_CUDA
+
+#include <sstream>
+#include <stdexcept>
+#include "test/cpp/tensorexpr/test_base.h"
+
+#include <cmath>
+
+#include "test/cpp/tensorexpr/padded_buffer.h"
+#include "torch/csrc/jit/tensorexpr/buffer.h"
+#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
+#include "torch/csrc/jit/tensorexpr/schedule.h"
+#include "torch/csrc/jit/tensorexpr/tensor.h"
+
+#include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/util/Half.h>
+
+namespace torch {
+namespace jit {
+using namespace torch::jit::tensorexpr;
+using namespace torch::jit::tensorexpr::schedule;
+
+template <typename ctype>
+void testCudaTestVectorAdd01_impl() {
+ KernelScope kernel_scope;
+ const int num_iter = 3;
+ const int block_count = 16;
+ const int block_size = 128;
+ Dtype dtype = ToDtype<ctype>();
+ Buffer a_buf("a", dtype, {num_iter, block_count, block_size});
+ Buffer b_buf("b", dtype, {num_iter, block_count, block_size});
+ Tensor* c = Compute(
+ "c",
+ {
+ {num_iter, "n"},
+ {block_count, "b_id"},
+ {block_size, "t_id"},
+ },
+ [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
+ return a_buf(n, b_id, t_id) + b_buf(n, b_id, t_id);
+ });
+ LoopNest l({c});
+ std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
+ l.SetGPUBlockIndex(loops[1], 0);
+ l.SetGPUThreadIndex(loops[2], 0);
+ Stmt* stmt = l.root_stmt();
+ CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
+ const int N = block_count * block_size * num_iter;
+ PaddedBuffer<ctype> a_v(N);
+ PaddedBuffer<ctype> b_v(N);
+ PaddedBuffer<ctype> c_v(N);
+ PaddedBuffer<ctype> c_ref(N);
+
+ for (int i = 0; i < N; i++) {
+ a_v(i) = ctype(i);
+ b_v(i) = ctype(i * 3 + 7);
+ c_ref(i) = a_v(i) + b_v(i);
+ }
+
+ // TODO: move gpu support into PaddedBuffer
+ ctype* a_dev = nullptr;
+ cudaMalloc(&a_dev, N * sizeof(ctype));
+ ctype* b_dev = nullptr;
+ cudaMalloc(&b_dev, N * sizeof(ctype));
+ ctype* c_dev = nullptr;
+ cudaMalloc(&c_dev, N * sizeof(ctype));
+ cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice);
+ cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice);
+ cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice);
+ cudaDeviceSynchronize();
+
+ cuda_cg(c_dev, a_dev, b_dev);
+
+ cudaDeviceSynchronize();
+ cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost);
+ cudaDeviceSynchronize();
+
+ ExpectAllNear(c_v, c_ref, 1e-5);
+
+ cudaFree(a_dev);
+ cudaFree(b_dev);
+ cudaFree(c_dev);
+}
+
+void testCudaTestVectorAdd01() {
+ // floating types.
+ testCudaTestVectorAdd01_impl<float>();
+ testCudaTestVectorAdd01_impl<at::Half>();
+ testCudaTestVectorAdd01_impl<double>();
+
+ // integer types.
+ testCudaTestVectorAdd01_impl<int8_t>();
+ testCudaTestVectorAdd01_impl<uint8_t>();
+ testCudaTestVectorAdd01_impl<int16_t>();
+ testCudaTestVectorAdd01_impl<int32_t>();
+ testCudaTestVectorAdd01_impl<int64_t>();
+}
+
+static void testCudaTestVectorAdd02_impl(int N, int block_size) {
+ KernelScope kernel_scope;
+ Buffer a_buf("a", kFloat, {N});
+ Buffer b_buf("b", kFloat, {N});
+ Tensor* c = Compute(
+ "c",
+ {
+ {N, "N"},
+ },
+ [&](const VarHandle& n) { return a_buf(n) + b_buf(n); });
+ LoopNest l({c});
+ Stmt* n_outer;
+ Stmt* n_inner;
+ std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
+ l.SplitWithMask(loops[0], block_size, &n_outer, &n_inner);
+ l.SetGPUBlockIndex(n_outer, 0);
+ l.SetGPUThreadIndex(n_inner, 0);
+ Stmt* stmt = l.root_stmt();
+ CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
+ PaddedBuffer<float> a_v(N);
+ PaddedBuffer<float> b_v(N);
+ PaddedBuffer<float> c_v(N);
+ PaddedBuffer<float> c_ref(N);
+
+ for (int i = 0; i < N; i++) {
+ a_v(i) = i;
+ b_v(i) = i * 3 + 7;
+ c_ref(i) = a_v(i) + b_v(i);
+ }
+
+ // TODO: move gpu support into PaddedBuffer
+ float* a_dev = nullptr;
+ cudaMalloc(&a_dev, N * sizeof(float));
+ float* b_dev = nullptr;
+ cudaMalloc(&b_dev, N * sizeof(float));
+ float* c_dev = nullptr;
+ cudaMalloc(&c_dev, N * sizeof(float));
+ cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
+ cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
+ cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
+ cudaDeviceSynchronize();
+
+ cuda_cg(c_dev, a_dev, b_dev);
+
+ cudaDeviceSynchronize();
+ cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost);
+ cudaDeviceSynchronize();
+
+ ExpectAllNear(c_v, c_ref, 1e-5);
+
+ cudaFree(a_dev);
+ cudaFree(b_dev);
+ cudaFree(c_dev);
+}
+
+void testCudaTestVectorAdd02() {
+ testCudaTestVectorAdd02_impl(1024, 128);
+ testCudaTestVectorAdd02_impl(1030, 128);
+}
+
+void testCudaDynamicShape2D() {
+ KernelScope kernel_scope;
+ auto testWithSize = [](int32_t M, int32_t N) {
+ VarHandle m("m", kInt);
+ VarHandle n("n", kInt);
+ Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
+ Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
+ Tensor* c = Compute(
+ "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
+ return a(i, j) + b(i, j);
+ });
+ LoopNest l({c});
+ Stmt* s = l.root_stmt();
+ CudaCodeGen cg(s, {a, b, c, m, n});
+
+ std::vector<float> aData(M * N, 1.0f);
+ std::vector<float> bData(M * N, 2.0f);
+ std::vector<float> cData(M * N, 0.0f);
+ float* aDev = nullptr;
+ float* bDev = nullptr;
+ float* cDev = nullptr;
+ cudaMalloc(&aDev, aData.size() * sizeof(aData[0]));
+ cudaMalloc(&bDev, bData.size() * sizeof(bData[0]));
+ cudaMalloc(&cDev, cData.size() * sizeof(cData[0]));
+ cudaMemcpy(
+ aDev,
+ aData.data(),
+ aData.size() * sizeof(aData[0]),
+ cudaMemcpyHostToDevice);
+ cudaMemcpy(
+ bDev,
+ bData.data(),
+ bData.size() * sizeof(bData[0]),
+ cudaMemcpyHostToDevice);
+ cudaMemcpy(
+ cDev,
+ cData.data(),
+ cData.size() * sizeof(cData[0]),
+ cudaMemcpyHostToDevice);
+ cudaDeviceSynchronize();
+
+ cg.call({aDev, bDev, cDev, M, N});
+ cudaDeviceSynchronize();
+
+ cudaMemcpy(
+ cData.data(),
+ cDev,
+ cData.size() * sizeof(cData[0]),
+ cudaMemcpyDeviceToHost);
+ cudaDeviceSynchronize();
+
+ ExpectAllNear(cData, std::vector<float>(M * N, 3.0f), 1e-7);
+
+ cudaFree(aDev);
+ cudaFree(bDev);
+ cudaFree(cDev);
+ };
+ testWithSize(32, 32);
+ testWithSize(1, 16);
+ testWithSize(27, 13);
+}
+
+void testCudaTestRand01() {
+ KernelScope kernel_scope;
+ const int num_iter = 3;
+ const int block_count = 16;
+ const int block_size = 128;
+ Tensor* c = Compute(
+ "c",
+ {
+ {num_iter, "n"},
+ {block_count, "b_id"},
+ {block_size, "t_id"},
+ },
+ [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
+ return Intrinsics::make(IntrinsicsOp::kRand, kFloat);
+ });
+ LoopNest l({c});
+ std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
+ l.SetGPUBlockIndex(loops[1], 0);
+ l.SetGPUThreadIndex(loops[2], 0);
+ Stmt* stmt = l.root_stmt();
+ CudaCodeGen cuda_cg(stmt, c);
+ const int N = block_count * block_size * num_iter;
+ PaddedBuffer<float> c_v(N);
+
+ // TODO: move gpu support into PaddedBuffer
+ float* c_dev = nullptr;
+ cudaMalloc(&c_dev, N * sizeof(float));
+ cudaDeviceSynchronize();
+
+ cuda_cg(c_dev);
+
+ cudaDeviceSynchronize();
+ cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost);
+ cudaDeviceSynchronize();
+
+ float sum1 = 0;
+ float sum2 = 0;
+ float sum3 = 0;
+ for (int i = 0; i < N; i++) {
+ float v = c_v.data()[i];
+ sum1 += v;
+ sum2 += v * v;
+ sum3 += v * v * v;
+ EXPECT_TRUE(v >= 0 && v < 1) << "invalid value: " << i << ", " << v;
+ }
+ sum1 /= N;
+ sum2 /= N;
+ sum3 /= N;
+ float sum1_mean = 1.f / 2;
+ float sum2_mean = 1.f / 3;
+ float sum3_mean = 1.f / 4;
+
+ EXPECT_NEAR(sum1, sum1_mean, 2e-2);
+ EXPECT_NEAR(sum2, sum2_mean, 2e-2);
+ EXPECT_NEAR(sum3, sum3_mean, 2e-2);
+ cudaFree(c_dev);
+}
+
+void testCudaDynamicShapeSplit() {
+ KernelScope ks;
+ constexpr int N = 4096;
+ VarHandle n("n", kInt);
+ Buffer a(VarHandle("a", kHandle), kFloat, {n});
+ Tensor* b =
+ Compute("b", {{n, "n"}}, [&](const VarHandle& i) { return a(i) * 2.0f; });
+ LoopNest l({b});
+ Stmt* outer;
+ Stmt* inner;
+ std::vector<Stmt*> loops = l.getLoopStmtsFor(b);
+ l.SplitWithMask(loops[0], 1024, &outer, &inner);
+ l.SetGPUBlockIndex(outer, 0);
+ l.SetGPUThreadIndex(inner, 0);
+ Stmt* s = l.root_stmt();
+ CudaCodeGen cg(s, {a, b, n});
+
+ std::vector<float> aData(N, 1.0f);
+ std::vector<float> bData(N, 1.0f);
+ float* aDev = nullptr;
+ float* bDev = nullptr;
+ cudaMalloc(&aDev, aData.size() * sizeof(aData[0]));
+ cudaMalloc(&bDev, bData.size() * sizeof(bData[0]));
+ cudaMemcpy(
+ aDev,
+ aData.data(),
+ aData.size() * sizeof(aData[0]),
+ cudaMemcpyHostToDevice);
+ cudaMemcpy(
+ bDev,
+ bData.data(),
+ bData.size() * sizeof(aData[0]),
+ cudaMemcpyHostToDevice);
+ cudaDeviceSynchronize();
+
+ cg.call({aDev, bDev, N});
+ cudaDeviceSynchronize();
+
+ cudaMemcpy(
+ bData.data(),
+ bDev,
+ bData.size() * sizeof(aData[0]),
+ cudaMemcpyDeviceToHost);
+ cudaDeviceSynchronize();
+
+ ExpectAllNear(bData, std::vector<float>(N, 2.0f), 1e-7);
+
+ cudaFree(aDev);
+ cudaFree(bDev);
+}
+
+} // namespace jit
+} // namespace torch
+
+#endif
diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h
index 17aff26..19642d6 100644
--- a/test/cpp/tensorexpr/tests.h
+++ b/test/cpp/tensorexpr/tests.h
@@ -87,6 +87,11 @@
_(ATenltInt)
#define TH_FORALL_TESTS_CUDA(_) \
+ _(CudaTestVectorAdd01) \
+ _(CudaTestVectorAdd02) \
+ _(CudaDynamicShape2D) \
+ _(CudaTestRand01) \
+ _(CudaDynamicShapeSplit)
#define DECLARE_TENSOREXPR_TEST(name) void test##name();
TH_FORALL_TESTS(DECLARE_TENSOREXPR_TEST)
diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py
index 381d68a..361340d 100644
--- a/test/test_tensorexpr.py
+++ b/test/test_tensorexpr.py
@@ -34,6 +34,16 @@
return value - self.start_value
+class CudaCodeGenCreated(ExecutionCounter):
+ def __init__(self):
+ super(CudaCodeGenCreated, self).__init__("cuda_codegen_created")
+
+
+class CudaCodeGenExecuted(ExecutionCounter):
+ def __init__(self):
+ super(CudaCodeGenExecuted, self).__init__("cuda_codegen_executed")
+
+
class SimpleIREvalExecuted(ExecutionCounter):
def __init__(self):
super(SimpleIREvalExecuted, self).__init__("simple_ir_eval_executed")
@@ -80,7 +90,7 @@
c = torch.addcmul(torch.add(x, y), z, w)
return c
- device_options = ["cpu"]
+ device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
for dev in device_options:
rand_a = torch.rand(1024, dtype=torch.float, device=dev)
rand_b = torch.rand(1024, dtype=torch.float, device=dev)
@@ -102,6 +112,79 @@
np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6)
+ def test_three_arg_cuda(self):
+ if not torch.cuda.is_available():
+ return
+ cuda_cg_executed = CudaCodeGenExecuted()
+ cuda_cg_created = CudaCodeGenCreated()
+
+ def test(x, y, z):
+ aaa = torch.add(x, y)
+ bbb = torch.add(aaa, z)
+ return bbb
+
+ M = 32
+ N = 32
+ traced = torch.jit.trace(
+ test,
+ (
+ torch.rand(M, N, device="cuda"),
+ torch.rand(M, N, device="cuda"),
+ torch.rand(M, N, device="cuda"),
+ ),
+ )
+
+ a = torch.rand(M, N, device="cuda")
+ b = torch.rand(M, N, device="cuda")
+ c = torch.rand(M, N, device="cuda")
+ x = traced(a, b, c)
+ npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
+ np.testing.assert_allclose(npr, x.cpu().numpy())
+ assert cuda_cg_executed.elapsed_value() >= 1
+ assert cuda_cg_created.elapsed_value() >= 1
+
+
+ def test_broadcast_cuda(self):
+ if not torch.cuda.is_available():
+ return
+
+ def test_body(M, N, L, K):
+ if not torch.cuda.is_available():
+ return
+ cuda_cg_executed = CudaCodeGenExecuted()
+ cuda_cg_created = CudaCodeGenCreated()
+
+ def test(x, y, z):
+ v1 = torch.add(x, y)
+ v2 = torch.add(v1, z)
+ return v2
+
+ a_shape = [M, N]
+ b_shape = [L, M, 1]
+ c_shape = [K, L, 1, 1]
+ traced = torch.jit.trace(
+ test,
+ (
+ torch.rand(*a_shape, device="cuda"),
+ torch.rand(*b_shape, device="cuda"),
+ torch.rand(*c_shape, device="cuda"),
+ ),
+ )
+
+ a = torch.rand(*a_shape, device="cuda")
+ b = torch.rand(*b_shape, device="cuda")
+ c = torch.rand(*c_shape, device="cuda")
+ x = traced(a, b, c)
+ npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
+ np.testing.assert_allclose(npr, x.cpu().numpy())
+ assert cuda_cg_executed.elapsed_value() >= 1
+ assert cuda_cg_created.elapsed_value() >= 1
+
+ test_configs = [[36, 17, 63, 33], [32, 32, 32, 32]]
+ for test_config in test_configs:
+ test_body(*test_config)
+
+
def test_all_combos(self):
def easy(x, y, z):
a = torch.add(x, y)
@@ -426,7 +509,7 @@
c = torch.lt(x, y)
return c
- device_options = ["cpu"]
+ device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
for dev in device_options:
traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev)))
a = torch.ones(1024, dtype=torch.int32, device=dev)
@@ -451,7 +534,7 @@
def test(x):
return torch.clamp(x + 3.0, 0.0, 6.0)
- device_options = ["cpu"]
+ device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
for dev in device_options:
traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
@@ -463,7 +546,7 @@
def test(x):
return torch.clamp(F.relu(x), 0, 0.5)
- device_options = ["cpu"]
+ device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
for dev in device_options:
traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
a = 20.0 * torch.rand(1024, device=dev) - 10.0
@@ -598,7 +681,7 @@
# test_tanh_backward,
test_type_as,
}
- device_options = ["cpu"]
+ device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
for torch_fn in fns:
for dev in device_options:
rand_a = torch.rand(1024, device=dev)
@@ -776,7 +859,7 @@
test_neg,
test_relu,
}
- device_options = ["cpu"]
+ device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
for torch_fn in fns:
for dev in device_options:
@@ -797,6 +880,26 @@
np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
+ def test_rand_like(self):
+ devices = ["cuda"] if torch.cuda.is_available() else []
+ N = 1 << 16
+
+ def run_rand_like(x, y):
+ return torch.rand_like(torch.add(x, y))
+
+ for device in devices:
+ x = torch.rand(N, device=device)
+ traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)
+ x_v = traced(x, x)
+ x_np = x.cpu().numpy()
+ x1_mean = np.mean(x_np)
+ x2_mean = np.mean(x_np ** 2)
+ x3_mean = np.mean(x_np ** 3)
+ np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2)
+ np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2)
+ np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2)
+
+
def test_nans(self):
def test_max(x, y):
return torch.max(2 * x, 2 * y)
@@ -898,6 +1001,10 @@
def test_cat_cpu(self):
self._test_cat('cpu')
+ @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
+ def test_cat_cuda(self):
+ self._test_cat('cuda')
+
def test_scalar(self):
@torch.jit.script
def test_float(x, y, z, a, b):
@@ -1001,8 +1108,66 @@
assert interp.elapsed_value() == 1
+ @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
+ @unittest.skip("dynamic shapes are not quite there yet")
+ def test_dynamic_shape(self):
+ with num_profiled_runs(2):
+ @torch.jit.script
+ def test(x, y, z):
+ return x * y * z
+ cuda = CudaCodeGenCreated()
+ x, y, z = [torch.rand(4, 8).cuda() for _ in range(3)]
+ ref = test(x, y, z)
+ _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
+ res = test(x, y, z)
+ np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())
+ assert cuda.elapsed_value() == 1
+
+ # A wild broadcast appears.
+ x = torch.rand(4, 8).cuda()
+ y = torch.rand(1, 8).cuda()
+ z = torch.rand(4, 1).cuda()
+ res = test(x, y, z)
+ xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
+ np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
+ assert cuda.elapsed_value() == 1
+
+ # Mismatched shapes shouldn't reach codegen.
+ x = torch.rand(4, 8).cuda()
+ y = torch.rand(4, 8).cuda()
+ z = torch.rand(5, 8).cuda()
+ try:
+ res = test(x, y, z)
+ except RuntimeError as e:
+ assert "The size of tensor a (4) must match" in e.args[0]
+ assert cuda.elapsed_value() == 1
+
+ # Changing a static dimension fails guards.
+ # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
+ # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
+ # res = test(x, y, z)
+ # print(test.graph_for(x, y, z))
+ # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
+ # assert cuda.elapsed_value() == 1
+
+ @unittest.skip("guarding on static shapes is not working")
+ def test_guard_fails(self):
+ @torch.jit.script
+ def test(x, y, z):
+ return x * y * z
+ cuda = CudaCodeGenExecuted()
+ r1 = test(*[torch.rand(4).cuda() for _ in range(3)])
+ assert cuda.elapsed_value() == 0
+ r2 = test(*[torch.rand(4).cuda() for _ in range(3)])
+ assert cuda.elapsed_value() == 1
+ r3 = test(*[torch.rand(4).cuda() for _ in range(3)])
+ assert cuda.elapsed_value() == 2
+ r4 = test(*[torch.rand(7).cuda() for _ in range(3)])
+ print(test.graph_for(*[torch.rand(7).cuda() for _ in range(3)]))
+ assert cuda.elapsed_value() == 2
+
def test_bitwise_ops(self):
- devices = ["cpu"]
+ devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]
def run_and(x, y):
return x & (x & y)
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 11ecc78..d548a35 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -219,6 +219,7 @@
"torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
"torch/csrc/autograd/profiler_cuda.cpp",
"torch/csrc/autograd/functions/comm.cpp",
+ "torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
]
torch_cpp_srcs = [
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 02a652d..722c2c7 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -60,6 +60,7 @@
#include <torch/csrc/jit/python/python_tree_views.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
+#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <c10/macros/Export.h>
#include <caffe2/serialize/inline_container.h>
@@ -414,6 +415,42 @@
ExecutionTriggerList::GetInstance().FindByName(trigger_name);
return trigger->value();
})
+ .def(
+ "_jit_get_te_cuda_pointwise_loop_levels",
+ []() -> int {
+ using namespace torch::jit::tensorexpr;
+ return GetTECudaPointwiseLoopLevels();
+ })
+ .def(
+ "_jit_set_te_cuda_pointwise_loop_levels",
+ [](int level) {
+ using namespace torch::jit::tensorexpr;
+ return GetTECudaPointwiseLoopLevels() = level;
+ })
+ .def(
+ "_jit_get_te_cuda_pointwise_block_count",
+ []() -> int {
+ using namespace torch::jit::tensorexpr;
+ return GetTECudaPointwiseBlockCount();
+ })
+ .def(
+ "_jit_set_te_cuda_pointwise_block_count",
+ [](int block_count) {
+ using namespace torch::jit::tensorexpr;
+ return GetTECudaPointwiseBlockCount() = block_count;
+ })
+ .def(
+ "_jit_get_te_cuda_pointwise_block_size",
+ []() -> int {
+ using namespace torch::jit::tensorexpr;
+ return GetTECudaPointwiseBlockSize();
+ })
+ .def(
+ "_jit_set_te_cuda_pointwise_block_size",
+ [](int block_size) {
+ using namespace torch::jit::tensorexpr;
+ return GetTECudaPointwiseBlockSize() = block_size;
+ })
.def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
.def(
"_jit_fuser_get_fused_kernel_code",
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
new file mode 100644
index 0000000..c3f1332
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -0,0 +1,695 @@
+#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
+#include "torch/csrc/jit/tensorexpr/cuda_half_support.h"
+
+#include "ATen/CUDAGenerator.h"
+#include "c10/cuda/CUDAFunctions.h"
+#include "torch/csrc/jit/tensorexpr/cuda_random.h"
+#include "torch/csrc/jit/tensorexpr/eval.h"
+#include "torch/csrc/jit/tensorexpr/execution_counter.h"
+
+#define DEBUG_PRINT 0
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+DEFINE_TRIGGER(cuda_codegen_created);
+DEFINE_TRIGGER(cuda_codegen_executed);
+
+// A RAII wrapper to manage a variable and name pair in the look-up table.
+// TODO: move this to a more shared place.
+class ScopedVarName {
+ public:
+ ScopedVarName(VarNameMap* mapping, const Var* var, const std::string& name)
+ : mapping_(mapping), var_(var) {
+ auto iter = mapping->find(var);
+ if (iter != mapping->end()) {
+ throw std::runtime_error("Duplicate var entry: " + var->name_hint());
+ }
+ mapping->insert(std::make_pair(var, name));
+ }
+
+ ScopedVarName(
+ UniqueNameManager* manager,
+ const Var* var,
+ const std::string& name)
+ : ScopedVarName(&manager->unique_name_mapping_, var, name) {}
+
+ ScopedVarName(const ScopedVarName&) = delete;
+ ScopedVarName& operator=(const ScopedVarName&) = delete;
+
+ ~ScopedVarName() noexcept(false) {
+ mapping_->erase(var_);
+ }
+
+ private:
+ VarNameMap* mapping_ = nullptr;
+ const Var* var_ = nullptr;
+};
+
+static int as_int(const Expr* expr) {
+ auto v = dynamic_cast<const IntImm*>(expr);
+ TORCH_CHECK(v, "Expression is not an integer constant");
+ return v->value();
+}
+
+static bool is_zero(const Expr* expr) {
+ return as_int(expr) == 0;
+}
+
+static const at::cuda::NVRTC& nvrtc() {
+ return at::globalContext().getNVRTC();
+}
+
+static void getMajorMinor(
+ const cudaDeviceProp* const prop,
+ int& major,
+ int& minor) {
+ using CudaVersion = std::pair<int, int>;
+ CudaVersion nvrtc_version;
+ AT_CUDA_NVRTC_CHECK(
+ nvrtc().nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second));
+
+ AT_ASSERT(nvrtc_version.first >= 6);
+
+ CudaVersion dev_version = CudaVersion(prop->major, prop->minor);
+ CudaVersion max_dev_version(dev_version);
+ if (nvrtc_version.first <= 7) { // 7 supports 2-5.x
+ max_dev_version = CudaVersion(5, 0);
+ } else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x
+ max_dev_version = CudaVersion(6, 0);
+ } else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2
+ max_dev_version = CudaVersion(7, 2);
+ } else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5
+ max_dev_version = CudaVersion(7, 5);
+ }
+ if (dev_version > max_dev_version) {
+ dev_version = max_dev_version;
+ }
+ major = dev_version.first;
+ minor = dev_version.second;
+}
+
+void CudaPrinter::visit(const For* v) {
+ const LoopOptions& loop_options = v->loop_options();
+ if (loop_options.is_gpu_block_index()) {
+ ScopedVarName var_name(
+ name_manager(), v->var(), loop_options.gpu_block_index_str());
+ v->body()->accept(this);
+ int gpu_block_index = loop_options.gpu_block_index();
+ if (gpu_block_extents_.size() <= gpu_block_index) {
+ gpu_block_extents_.resize(gpu_block_index + 1);
+ }
+ if (!is_zero(v->start())) {
+ throw std::runtime_error(
+ "start must be zero for gpu_block_index: " +
+ std::to_string(ExprHandle(v->start())));
+ }
+ gpu_block_extents_[gpu_block_index] = v->stop();
+ } else if (loop_options.is_gpu_thread_index()) {
+ ScopedVarName var_name(
+ name_manager(), v->var(), loop_options.gpu_thread_index_str());
+ v->body()->accept(this);
+ int gpu_thread_index = loop_options.gpu_thread_index();
+ if (gpu_thread_extents_.size() <= gpu_thread_index) {
+ gpu_thread_extents_.resize(gpu_thread_index + 1);
+ }
+ if (!is_zero(v->start())) {
+ throw std::runtime_error(
+ "start must be zero for gpu_block_index: " +
+ std::to_string(ExprHandle(v->start())));
+ }
+ gpu_thread_extents_[gpu_thread_index] = v->stop();
+ } else {
+ IRPrinter::visit(v);
+ }
+}
+
+void CudaPrinter::visit(const Intrinsics* v) {
+ if (v->op_type() == IntrinsicsOp::kRand) {
+ os() << "Uint32ToFloat(" << *rand_func_ << "())";
+ return;
+ }
+
+ std::string func_name = v->func_name();
+
+ // get type of resulting expression.
+ ScalarType returnType = v->param(0)->dtype().scalar_type();
+ for (int i = 1; i < v->nparams(); ++i) {
+ returnType = promoteTypes(returnType, v->param(i)->dtype().scalar_type());
+ }
+
+ if (returnType == ScalarType::Half || returnType == ScalarType::Float) {
+ func_name = func_name + "f";
+ }
+
+ os() << func_name << "(";
+ for (int i = 0; i < v->nparams(); i++) {
+ if (i > 0) {
+ os() << ", ";
+ }
+ os() << *v->param(i);
+ }
+ os() << ")";
+}
+
+void CudaPrinter::visit(const Load* v) {
+ // TODO: find a better heuristic for whether to use __ldg. Support different dtypes.
+ if (v->dtype().scalar_type() == ScalarType::Half) {
+ os() << "__half2float(" << *v->base_handle() << "[" << *v->index() << "])";
+ } else {
+ os() << "__ldg(" << *v->base_handle() << " + " << *v->index() << ")";
+ }
+}
+
+void CudaPrinter::visit(const Store* v) {
+ os() << *v->base_handle() << "[" << *v->index() << "] = ";
+ if (v->value()->dtype().scalar_type() == ScalarType::Half) {
+ os() << "__float2half(" << *v->value() << ");";
+ } else {
+ os() << *v->value() << ";";
+ }
+}
+
+void CudaPrinter::visit(const Max* v) {
+ auto dtype = v->dtype().scalar_type();
+ switch (dtype) {
+ case ScalarType::Half:
+ // doing Half math in float.
+ case ScalarType::Float:
+ os() << "fmaxf";
+ break;
+ case ScalarType::Double:
+ os() << "fmax";
+ break;
+ default:
+ os() << "max";
+ break;
+ }
+ os() << "(";
+ v->lhs()->accept(this);
+ os() << ",";
+ v->rhs()->accept(this);
+ os() << ")";
+}
+
+void CudaPrinter::visit(const Min* v) {
+ auto dtype = v->dtype().scalar_type();
+ switch (dtype) {
+ case ScalarType::Half:
+ // doing Half math in float.
+ case ScalarType::Float:
+ os() << "fminf";
+ break;
+ case ScalarType::Double:
+ os() << "fmin";
+ break;
+ default:
+ os() << "min";
+ break;
+ }
+ os() << "(";
+ v->lhs()->accept(this);
+ os() << ",";
+ v->rhs()->accept(this);
+ os() << ")";
+}
+
+std::string cudaDtypeCppString(const Dtype& dtype) {
+ switch (dtype.scalar_type()) {
+ case ScalarType::Half:
+ return "half";
+ case ScalarType::Char:
+ return "char";
+ case ScalarType::Byte:
+ return "unsigned char";
+ case ScalarType::Short:
+ return "short";
+ case ScalarType::Long:
+ return "long";
+ default:; /* nothing */
+ }
+ return dtype.ToCppString();
+}
+
+void CudaPrinter::visit(const LetStmt* v) {
+ const Var* var = v->var();
+ if (var->dtype().scalar_type() == ScalarType::Half) {
+ // we do math in floats so use that.
+ os() << "float";
+ } else {
+ os() << cudaDtypeCppString(var->dtype());
+ }
+ os() << " " << *var << " = " << *v->value() << "; " << std::endl;
+ v->body()->accept(this);
+}
+
+void CudaPrinter::visit(const IfThenElse* v) {
+ os() << "((";
+ v->condition()->accept(this);
+ os() << ") ? ";
+ v->true_value()->accept(this);
+ os() << " : ";
+ v->false_value()->accept(this);
+ os() << ")";
+}
+
+class PrioritizeLoad : public IRMutator {
+ public:
+ const Expr* mutate(const Load* v) override {
+ // Look at the declaration of this variable for more details.
+ if (nested_if_then_else_ > 0) {
+ return IRMutator::mutate(v);
+ }
+ MemLoadList& load_list = load_stack_.back();
+ const Var* load_new_var = new Var("v", v->dtype());
+ const Expr* new_value = IRMutator::mutate(v);
+ load_list.push_back(std::make_pair(load_new_var, new_value));
+ return load_new_var;
+ }
+
+ // TODO: merge this with the IRMutator::mutate version.
+ Stmt* mutate(const For* v) override {
+ const Var* var = v->var();
+ const Expr* start = v->start();
+ const Expr* stop = v->stop();
+ Stmt* body = v->body();
+ LoopOptions loop_options = v->loop_options();
+ const Var* var_new = dynamic_cast<const Var*>(var->accept_mutator(this));
+ const Expr* start_new = start->accept_mutator(this);
+ const Expr* stop_new = stop->accept_mutator(this);
+ PushList();
+ Stmt* body_new = body->accept_mutator(this);
+ if (!body_new) {
+ return nullptr;
+ }
+ Stmt* body_with_loads = AddMemLoadsFromList(body_new);
+ PopList();
+ if (var == var_new && start == start_new && stop == stop_new &&
+ body == body_with_loads) {
+ return (Stmt*)v;
+ }
+ return new For(var_new, start_new, stop_new, body_with_loads, loop_options);
+ }
+
+ Stmt* mutate(const LetStmt* v) override {
+ const Var* var = v->var();
+ const Expr* value = v->value();
+ Stmt* body = v->body();
+ const Var* var_new = dynamic_cast<const Var*>(var->accept_mutator(this));
+ if (var_new == nullptr) {
+ throw std::runtime_error("LetStmt var must be variable");
+ }
+ const Expr* value_new = value->accept_mutator(this);
+ PushList();
+ Stmt* body_new = body->accept_mutator(this);
+ Stmt* body_with_loads = AddMemLoadsFromList(body_new);
+ PopList();
+ if (var == var_new && value == value_new && body == body_with_loads) {
+ return (Stmt*)v;
+ }
+ return new LetStmt(var_new, value_new, body_with_loads);
+ }
+
+ Stmt* mutate(const Cond* v) override {
+ const Expr* cond_old = v->condition();
+ Stmt* true_old = v->true_stmt();
+ Stmt* false_old = v->false_stmt();
+
+ const Expr* cond_new = cond_old->accept_mutator(this);
+ PushList();
+ Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old;
+ Stmt* true_with_loads = AddMemLoadsFromList(true_new);
+ PopList();
+ PushList();
+ Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old;
+ Stmt* false_with_loads = AddMemLoadsFromList(false_new);
+ PopList();
+
+ if (cond_old == cond_new && true_old == true_with_loads &&
+ false_old == false_with_loads) {
+ return (Stmt*)v;
+ }
+ return new Cond(cond_new, true_with_loads, false_with_loads);
+ }
+
+ const Expr* mutate(const IfThenElse* v) override {
+ nested_if_then_else_++;
+ const Expr* new_v = IRMutator::mutate(v);
+ nested_if_then_else_--;
+ return new_v;
+ }
+
+ Stmt* Process(Stmt* stmt) {
+ this->PushList();
+ Stmt* stmt_v = stmt;
+ Stmt* stmt_new = stmt_v->accept_mutator(this);
+ Stmt* stmt_with_loads = AddMemLoadsFromList(stmt_new);
+ this->PopList();
+ return stmt_with_loads;
+ }
+
+ private:
+ using MemLoadEntry = std::pair<const Var*, const Expr*>;
+ using MemLoadList = std::vector<MemLoadEntry>;
+ using MemoryLoadStack = std::vector<MemLoadList>;
+
+ void PushList() {
+ load_stack_.push_back(MemLoadList());
+ }
+
+ void PopList() {
+ load_stack_.pop_back();
+ }
+
+ Stmt* AddMemLoadsFromList(Stmt* stmt) {
+ MemLoadList& load_list = load_stack_.back();
+ Stmt* stmt_v = stmt;
+ for (auto iter = load_list.rbegin(); iter != load_list.rend(); iter++) {
+ const MemLoadEntry& entry = *iter;
+ const Var* var_ptr = entry.first;
+ stmt_v = new LetStmt(var_ptr, entry.second, stmt_v);
+ }
+ return stmt_v;
+ }
+
+ MemoryLoadStack load_stack_;
+ // TODO: For now, we are not moving the loads with the IfThenElse.
+ // Eventually, we should switch to a more generic structure like:
+ // int v2 = IfThenElse(cond, true_v, false_v) + 2 ->
+ //
+ // int v;
+ // if (cond) {
+ // v = true_v;
+ // } else {
+ // v = false_v;
+ // }
+ // int v2 = v + 2;
+ int nested_if_then_else_ = 0;
+};
+
+class HasRand : public IRVisitor {
+ public:
+ HasRand(Stmt* stmt) : stmt_(stmt) {
+ stmt_->accept(this);
+ }
+
+ bool has_rand() const {
+ return has_rand_;
+ }
+
+ private:
+ void visit(const Intrinsics* v) override {
+ if (v->op_type() == IntrinsicsOp::kRand) {
+ has_rand_ = true;
+ } else {
+ IRVisitor::visit(v);
+ }
+ }
+ Stmt* stmt_;
+ bool has_rand_ = false;
+};
+
+std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
+ // We are using a global counter here to make sure different instances of
+ // CudaCodeGen have different names.
+ static int64_t counter = 0;
+ ++counter;
+ int64_t value = counter;
+ return func_prefix + "_" + std::to_string(value);
+}
+
+void CudaCodeGen::Initialize() {
+ // TODO: handle multiple kernels.
+ // TODO: handle dynamic dimension.
+ // TODO: call nvrtc.
+ HasRand has_rand_func(stmt());
+ has_random_ = has_rand_func.has_rand();
+ printer_ = std::make_unique<CudaPrinter>(&oss_, has_random_);
+
+ os() << "#define NAN __int_as_float(0x7fffffff)\n"
+ "#define POS_INFINITY __int_as_float(0x7f800000)\n"
+ "#define NEG_INFINITY __int_as_float(0xff800000)\n";
+ if (has_random_) {
+ os() << philox_random_string << std::endl;
+ }
+
+ // Check whether the statement uses the Half type; if so, add the
+ // half_support_literal.
+ CudaHalfChecker halfChecker;
+ stmt()->accept(&halfChecker);
+ if (halfChecker.hasHalf()) {
+ os() << fuser::cuda::half_support_literal << std::endl;
+ }
+
+ std::string func_name = GetUniqueFuncName("func");
+ os() << "extern \"C\" __global__" << std::endl << "void " << func_name << "(";
+ const std::vector<BufferArg> buffer_args = this->buffer_args();
+ for (size_t i = 0; i < buffer_args.size(); i++) {
+ if (i > 0) {
+ os() << ", ";
+ }
+ const BufferArg& buffer_arg = buffer_args[i];
+ const Var* var = buffer_arg.var();
+ Dtype dtype = buffer_arg.dtype();
+
+ os() << cudaDtypeCppString(dtype) << (buffer_arg.isVar() ? " " : "* ")
+ << name_manager()->get_unique_name(var);
+ }
+ const Var* rand_seed;
+ const Var* rand_offset;
+ if (has_random_) {
+ // TODO: switch to kUint64 when it is available.
+ rand_seed = new Var("rand_seed", kInt);
+ rand_offset = new Var("rand_offset", kInt);
+ std::string uint64_str = "unsigned long long";
+ os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " "
+ << *rand_offset;
+ }
+ os() << ") {";
+ os() << std::endl;
+
+ if (has_random_) {
+ const Var* idx = new Var("idx", kInt);
+ os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;"
+ << std::endl;
+ const Var* rand_func = printer_->rand_func();
+ os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
+ << *rand_offset << ");" << std::endl;
+ os() << std::endl;
+ }
+
+ Stmt* stmt_v = stmt();
+ PrioritizeLoad prioritize_load;
+ stmt_v = prioritize_load.Process(stmt_v);
+ stmt_v->accept(printer_.get());
+ os() << std::endl;
+ os() << "}";
+
+ // Check that all block extents have been set.
+ const std::vector<const Expr*>& gpu_block_extents =
+ printer_->gpu_block_extents();
+ const std::vector<const Expr*>& gpu_thread_extents =
+ printer_->gpu_thread_extents();
+ for (size_t i = 0; i < gpu_block_extents.size(); i++) {
+ if (!gpu_block_extents[i]) {
+ throw std::runtime_error("Missing gpu_block_index: " + std::to_string(i));
+ }
+ }
+
+#if DEBUG_PRINT
+ std::cout << "stmt: " << std::endl;
+ std::cout << oss_.str() << std::endl;
+ std::cout << "block(";
+ for (size_t i = 0; i < gpu_block_extents.size(); i++) {
+ if (i > 0) {
+ std::cout << ", ";
+ }
+ std::cout << *gpu_block_extents[i];
+ }
+ std::cout << "), thread(";
+ for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
+ if (i > 0) {
+ std::cout << ", ";
+ }
+ std::cout << *gpu_thread_extents[i];
+ }
+ std::cout << ")" << std::endl;
+#endif
+
+ CompileToNVRTC(oss_.str(), func_name);
+ USE_TRIGGER(cuda_codegen_created);
+}
+
+void CudaCodeGen::call(const std::vector<CallArg>& args) {
+ CHECK_EQ(args.size(), buffer_args().size());
+
+ // TODO: move as much of this into the constructors.
+ const std::vector<const Expr*>& gpu_block_extents =
+ printer_->gpu_block_extents();
+ const std::vector<const Expr*>& gpu_thread_extents =
+ printer_->gpu_thread_extents();
+ CHECK(gpu_block_extents.size() <= 3);
+ CHECK(gpu_thread_extents.size() <= 3);
+ std::vector<int> gpu_block_extents_v(3, 1);
+ std::vector<int> gpu_thread_extents_v(3, 1);
+ // evaluate all the block/thread extents into values
+ // TODO: eventually, codegen these calculations and make them part of the
+ // module.
+ for (size_t i = 0; i < gpu_block_extents.size(); i++) {
+ ExprEval<SimpleIREvaluator> eval(
+ ExprHandle(gpu_block_extents[i]), buffer_args());
+ gpu_block_extents_v[i] = eval.value<int>(args);
+ }
+ for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
+ ExprEval<SimpleIREvaluator> eval(
+ ExprHandle(gpu_thread_extents[i]), buffer_args());
+ gpu_thread_extents_v[i] = eval.value<int>(args);
+ }
+
+ // Skip launching the kernel if there are no elements to process.
+ for (int extent : gpu_block_extents_v) {
+ if (extent == 0) {
+ return;
+ }
+ }
+
+ // Bind the buffer addresses into arguments
+ auto const& buffer_args = this->buffer_args();
+ int ptr_count = buffer_args.size();
+ if (has_random_) {
+ ptr_count += 2;
+ }
+ std::vector<void*> args_data(buffer_args.size());
+ std::vector<void*> ptr_to_args(ptr_count);
+ uint64_t rand_seed = uint64_t(-1);
+ uint64_t rand_offset = uint64_t(-1);
+ for (size_t i = 0; i < buffer_args.size(); i++) {
+ auto const& bufferArg = buffer_args[i];
+ if (bufferArg.isVar()) {
+ auto stype = bufferArg.dtype().scalar_type();
+ switch (stype) {
+#define TYPE_CASE(Type, Name) \
+ case ScalarType::Name: \
+ ptr_to_args[i] = args[i].Name##Ptr(); \
+ break;
+ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+#undef TYPE_CASE
+ default:
+ LOG(FATAL) << "Unhandled dtype in argument";
+ }
+ } else {
+ args_data[i] = args[i].data();
+ ptr_to_args[i] = &args_data[i];
+ }
+ }
+
+ if (has_random_) {
+ auto gen = at::cuda::detail::getDefaultCUDAGenerator();
+ // TODO: total hack. Switch to numel when it is available.
+ int64_t total_elements_per_thread = (1LL << 28);
+ {
+ std::lock_guard<std::mutex> lock(gen->mutex_);
+ auto philox_engine_inputs =
+ gen->philox_engine_inputs(total_elements_per_thread);
+ rand_seed = philox_engine_inputs.first;
+ rand_offset = philox_engine_inputs.second;
+ }
+ ptr_to_args[buffer_args.size()] = &rand_seed;
+ ptr_to_args[buffer_args.size() + 1] = &rand_offset;
+ }
+
+ // Launch the kernel
+ auto stream = at::cuda::getCurrentCUDAStream();
+ AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
+ function_,
+ gpu_block_extents_v[0],
+ gpu_block_extents_v[1],
+ gpu_block_extents_v[2],
+ gpu_thread_extents_v[0],
+ gpu_thread_extents_v[1],
+ gpu_thread_extents_v[2],
+ 0,
+ stream,
+ ptr_to_args.data(),
+ nullptr));
+ USE_TRIGGER(cuda_codegen_executed);
+}
+
+void CudaCodeGen::CompileToNVRTC(
+ const std::string& code,
+ const std::string& func_name) {
+ // Initializes driver's API context (if necessary)
+ CUdevice device = 0;
+ CUcontext pctx = 0;
+ AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
+ if (!pctx) {
+ std::unique_lock<std::mutex> cudaFreeMutexLock(
+ *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
+ cudaFree(0);
+ }
+
+ // Note: hacked around at::DeviceGuard since it was failing to work
+ // properly in some scenarios
+ const auto prior_device = at::cuda::current_device();
+ at::cuda::set_device(device);
+
+ // Acquires device and NVRTC properties (for compile arch and occupancy
+ // calculations)
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+ int major, minor;
+ getMajorMinor(prop, major, minor);
+
+#if DEBUG_PRINT
+ std::cout << "major: " << major << ", "
+ << "minor: " << minor << std::endl;
+#endif
+
+ // Creates the NVRTC program
+ nvrtcProgram program;
+ AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
+ &program, code.c_str(), nullptr, 0, nullptr, nullptr));
+
+#ifdef __HIP_PLATFORM_HCC__
+ std::vector<const char*> args = {};
+#else
+ const std::string compute = "--gpu-architecture=compute_" +
+ std::to_string(major) + std::to_string(minor);
+ const std::vector<const char*> args = {
+ "--std=c++14", compute.c_str(), "-default-device"};
+#endif
+
+ const auto result =
+ nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
+ if (result != NVRTC_SUCCESS) {
+ size_t logsize;
+ AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
+ std::vector<char> log(logsize);
+ AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
+ std::stringstream cu;
+ cu << log.data() << std::endl;
+ cu << "nvrtc compilation failed: " << std::endl;
+ cu << code << std::endl;
+ throw std::runtime_error(cu.str());
+ }
+ ResourceGuard holdProgram(
+ [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
+ AT_CUDA_NVRTC_CHECK(result);
+ size_t ptx_size;
+ AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size));
+ std::vector<char> ptx;
+ ptx.resize(ptx_size);
+ AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx.data()));
+
+ CUmodule module;
+ AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data()));
+ AT_CUDA_DRIVER_CHECK(
+ nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str()));
+}
+
+RegisterCodeGen<CudaCodeGen> cuda_codegen_reg("cuda_codegen");
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.h b/torch/csrc/jit/tensorexpr/cuda_codegen.h
new file mode 100644
index 0000000..7afa9a0
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.h
@@ -0,0 +1,123 @@
+#pragma once
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "ATen/ATen.h"
+#include "ATen/cuda/CUDAContext.h"
+#include "ATen/cuda/nvrtc_stub/ATenNVRTC.h"
+#include "c10/cuda/CUDACachingAllocator.h"
+#include "c10/cuda/CUDAGuard.h"
+#include "torch/csrc/jit/resource_guard.h"
+#include "torch/csrc/jit/tensorexpr/codegen.h"
+#include "torch/csrc/jit/tensorexpr/ir.h"
+#include "torch/csrc/jit/tensorexpr/ir_printer.h"
+#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
+#include "torch/csrc/jit/tensorexpr/unique_name_manager.h"
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+// A class that overrides the underlying IRPrinter to produce Cuda C.
+class CudaPrinter : public IRPrinter {
+ public:
+ explicit CudaPrinter(std::ostream* os, bool has_random) : IRPrinter(*os) {
+ if (has_random) {
+ rand_func_ = new Var("rand", kHandle);
+ }
+ }
+
+ void visit(const Cast* v) override {
+ auto dtype = v->dtype();
+ if (dtype == kHalf) {
+ os() << "half";
+ } else {
+ os() << dtype;
+ }
+ os() << "(";
+ v->src_value()->accept(this);
+ os() << ")";
+ }
+
+ void visit(const Intrinsics* v) override;
+ void visit(const For* v) override;
+
+ void visit(const Load* v) override;
+ void visit(const Store* v) override;
+ void visit(const Max* v) override;
+ void visit(const Min* v) override;
+ void visit(const LetStmt* v) override;
+ void visit(const IfThenElse* v) override;
+
+ const std::vector<const Expr*>& gpu_block_extents() const {
+ return gpu_block_extents_;
+ }
+
+ const std::vector<const Expr*>& gpu_thread_extents() const {
+ return gpu_thread_extents_;
+ }
+
+ const Var* rand_func() const {
+ return rand_func_;
+ }
+
+ using IRPrinter::name_manager;
+ using IRPrinter::visit;
+
+ private:
+ std::vector<const Expr*> gpu_block_extents_;
+ std::vector<const Expr*> gpu_thread_extents_;
+ const Var* rand_func_ = nullptr;
+};
+
+// Construct Cuda C from the buffer and tensor input, and invoke the kernel
+// when real arguments are provided.
+class TORCH_CUDA_API CudaCodeGen : public CodeGen {
+ public:
+ template <typename... Ts>
+ CudaCodeGen(Stmt* stmt, Ts... ts) : CodeGen(stmt, std::forward<Ts>(ts)...) {
+ Initialize();
+ }
+
+ CudaCodeGen(Stmt* stmt, const std::vector<BufferArg>& buffer_args)
+ : CodeGen(stmt, buffer_args) {
+ Initialize();
+ }
+
+ ~CudaCodeGen() override {}
+
+ void call(const std::vector<CallArg>& args) override;
+
+ template <typename... Ts>
+ void operator()(const Ts&... ts) {
+ call(std::vector<CallArg>({CallArg(ts)...}));
+ }
+
+ private:
+ void Initialize();
+
+ void CompileToNVRTC(const std::string& code, const std::string& func_name);
+
+ UniqueNameManager* name_manager() {
+ if (!printer_) {
+ throw std::runtime_error("Null IRPrinter is not expected");
+ }
+ return printer_->name_manager();
+ }
+
+ std::ostream& os() {
+ return printer_->os();
+ }
+
+ std::ostringstream oss_;
+ std::unique_ptr<CudaPrinter> printer_;
+ CUfunction function_;
+ bool has_random_ = false;
+
+ std::string GetUniqueFuncName(const std::string& func_prefix);
+};
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/cuda_half_support.h b/torch/csrc/jit/tensorexpr/cuda_half_support.h
new file mode 100644
index 0000000..91d4a5f
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/cuda_half_support.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include "torch/csrc/jit/codegen/fuser/cuda/resource_strings.h"
+#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+// Walk the statement looking for Half-sized loads/stores.
+class CudaHalfChecker : public IRVisitor {
+ public:
+ bool hasHalf() {
+ return hasHalf_;
+ }
+
+ void visit(const Load* v) override {
+ hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
+ }
+ void visit(const Store* v) override {
+ hasHalf_ |= v->value()->dtype().scalar_type() == ScalarType::Half;
+ }
+
+ private:
+ bool hasHalf_{false};
+};
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/cuda_random.h b/torch/csrc/jit/tensorexpr/cuda_random.h
new file mode 100644
index 0000000..987ac52
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/cuda_random.h
@@ -0,0 +1,104 @@
+#pragma once
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+constexpr auto philox_random_string = R"(
+
+class Philox {
+public:
+ __device__ inline Philox(unsigned long long seed,
+ unsigned long long subsequence,
+ unsigned long long offset) {
+ key.x = (unsigned int)seed;
+ key.y = (unsigned int)(seed >> 32);
+ counter = make_uint4(0, 0, 0, 0);
+ counter.z = (unsigned int)(subsequence);
+ counter.w = (unsigned int)(subsequence >> 32);
+ STATE = 0;
+ incr_n(offset / 4);
+ }
+
+ __device__ inline unsigned long operator()() {
+ if(STATE == 0) {
+ uint4 counter_ = counter;
+ uint2 key_ = key;
+ for(int i = 0; i < 9; i++) {
+ counter_ = single_round(counter_, key_);
+ key_.x += (kPhilox10A); key_.y += (kPhilox10B);
+ }
+ output = single_round(counter_, key_);
+ incr();
+ }
+ unsigned long ret;
+ switch(STATE) {
+ case 0: ret = output.x; break;
+ case 1: ret = output.y; break;
+ case 2: ret = output.z; break;
+ case 3: ret = output.w; break;
+ }
+ STATE = (STATE + 1) % 4;
+ return ret;
+ }
+
+private:
+ uint4 counter;
+ uint4 output;
+ uint2 key;
+ unsigned int STATE;
+ __device__ inline void incr_n(unsigned long long n) {
+ unsigned int nlo = (unsigned int)(n);
+ unsigned int nhi = (unsigned int)(n >> 32);
+ counter.x += nlo;
+ if (counter.x < nlo)
+ nhi++;
+ counter.y += nhi;
+ if (nhi <= counter.y)
+ return;
+ if (++counter.z)
+ return;
+ ++counter.w;
+ }
+ __device__ inline void incr() {
+ if (++counter.x)
+ return;
+ if (++counter.y)
+ return;
+ if (++counter.z)
+ return;
+ ++counter.w;
+ }
+ __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
+ unsigned int *result_high) {
+ *result_high = __umulhi(a, b);
+ return a*b;
+ }
+
+ __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
+ unsigned int hi0;
+ unsigned int hi1;
+ unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
+ unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
+
+ uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
+ return ret;
+ }
+
+ static const unsigned long kPhilox10A = 0x9E3779B9;
+ static const unsigned long kPhilox10B = 0xBB67AE85;
+ static const unsigned long kPhiloxSA = 0xD2511F53;
+ static const unsigned long kPhiloxSB = 0xCD9E8D57;
+};
+
+// Inverse of 2^32.
+#define M_RAN_INVM32 2.3283064e-10f
+__device__ __inline__ float Uint32ToFloat(unsigned int x) {
+ return x * M_RAN_INVM32;
+}
+
+)";
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index f4c5adc..f26ba0b 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -5,6 +5,30 @@
using namespace torch::jit;
using namespace torch::jit::tensorexpr;
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+static int te_cuda_pointwise_loop_levels = -1;
+static int te_cuda_pointwise_block_count = -1;
+static int te_cuda_pointwise_block_size = -1;
+
+int& GetTECudaPointwiseLoopLevels() {
+ return te_cuda_pointwise_loop_levels;
+}
+
+int& GetTECudaPointwiseBlockCount() {
+ return te_cuda_pointwise_block_count;
+}
+
+int& GetTECudaPointwiseBlockSize() {
+ return te_cuda_pointwise_block_size;
+}
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
+
static at::ScalarType tensorType(Tensor* t) {
return static_cast<at::ScalarType>(t->body()->dtype().scalar_type());
}
@@ -883,12 +907,96 @@
void TensorExprKernel::LowerToBackend(BackendType backend_type) {
std::vector<Tensor*> tensor_outputs(tensor_outputs_);
+ if (backend_type == BackendType::kCudaCodeGen) {
+ for (size_t tensor_idx = 0; tensor_idx < tensor_outputs_.size();
+ tensor_idx++) {
+ Tensor* tensor = tensor_outputs_[tensor_idx];
+ ExprHandle total_count = ExprHandle(tensor->dim(0));
+ for (int i = 1; i < tensor->ndim(); i++) {
+ const IntImm* total_count_i = total_count.AsNode<IntImm>();
+ const IntImm* tensor_dim_i =
+ dynamic_cast<const IntImm*>(tensor->dim(i));
+ if (total_count_i && tensor_dim_i) {
+ // TODO: switch to real constant folding when it is available.
+ total_count =
+ ExprHandle(total_count_i->value() * tensor_dim_i->value());
+ } else {
+ total_count = total_count * ExprHandle(tensor->dim(i));
+ }
+ }
+ // Flatten the index for GPU kernels.
+ // TODO: move this to axis fusion when it is ready.
+ Tensor* new_out = Compute(
+ tensor->func_var()->name_hint() + "_flat",
+ {total_count},
+ [tensor](const VarHandle& index) -> ExprHandle {
+ std::vector<ExprHandle> dims;
+ ExprHandle value = index;
+ for (int i = tensor->ndim() - 1; i >= 0; i--) {
+ ExprHandle idx = value;
+ if (i > 0) {
+ idx = Mod::make(value, ExprHandle(tensor->dim(i)));
+ }
+ dims.push_back(idx);
+ value = value / ExprHandle(tensor->dim(i));
+ }
+ std::reverse(dims.begin(), dims.end());
+ return tensor->call(dims);
+ });
+ tensor_outputs[tensor_idx] = new_out;
+ }
+ }
+
torch::jit::tensorexpr::schedule::LoopNest l(tensor_outputs);
// Compute non-output tensors_ inline
for (auto& p : tensors_) {
l.ComputeInline(l.getLoopBodyFor(p.second));
}
+ if (backend_type == kCudaCodeGen) {
+ for (size_t i = 0; i < tensor_outputs_.size(); i++) {
+ l.ComputeInline(l.getLoopBodyFor(tensor_outputs_[i]));
+
+ Tensor* tensor = tensor_outputs[i];
+ const Var* index = tensor->arg(0);
+ int loop_levels = GetTECudaPointwiseLoopLevels();
+ const int kDefaultLoopLevels = 2;
+ loop_levels = (loop_levels > 0) ? loop_levels : kDefaultLoopLevels;
+ int block_count = GetTECudaPointwiseBlockCount();
+ int block_size = GetTECudaPointwiseBlockSize();
+
+ if (loop_levels == 2) {
+ Stmt* outer;
+ Stmt* inner;
+ const int kDefaultBlockSize = 512;
+ if (block_size < 0) {
+ block_size = kDefaultBlockSize;
+ }
+ std::vector<Stmt*> loops = l.getLoopStmtsFor(tensor);
+ l.SplitWithMask(loops[0], block_size, &outer, &inner);
+ l.SetGPUBlockIndex(outer, 0);
+ l.SetGPUThreadIndex(inner, 0);
+ } else if (loop_levels == 3) {
+ Stmt* outer;
+ Stmt* inner;
+ Stmt* inner_1;
+ Stmt* inner_2;
+ // TODO: base the default block count on the number of multiprocessors
+ const int kDefaultBlockCount = 1280;
+ const int kDefaultBlockSize = 256;
+ block_count = (block_count > 0) ? block_count : kDefaultBlockCount;
+ block_size = (block_size > 0) ? block_size : kDefaultBlockSize;
+ std::vector<Stmt*> loops = l.getLoopStmtsFor(tensor);
+ l.SplitWithMask(loops[0], block_count * block_size, &outer, &inner);
+ l.SplitWithMask(inner, block_size, &inner_1, &inner_2);
+ l.SetGPUBlockIndex(inner_1, 0);
+ l.SetGPUThreadIndex(inner_2, 0);
+ } else {
+ throw std::runtime_error(
+ "Invalid loop-level: " + std::to_string(loop_levels));
+ }
+ }
+ }
l.ApplyInlines();
Stmt* stmt = l.root_stmt();
@@ -911,6 +1019,9 @@
// Generate code.
std::string codegen_name;
switch (backend_type_) {
+ case kCudaCodeGen:
+ codegen_name = "cuda_codegen";
+ break;
case kSimpleIREval:
codegen_name = "simple_ir_eval";
break;
@@ -933,7 +1044,9 @@
throw std::runtime_error("No tensor inputs");
}();
BackendType backend_type = BackendType::kUninitialized;
- if (device.type() == at::kCPU) {
+ if (device.type() == at::kCUDA) {
+ backend_type = kCudaCodeGen;
+ } else if (device.type() == at::kCPU) {
backend_type = kSimpleIREval;
} else {
throw std::runtime_error("Invalid device type");
@@ -956,6 +1069,7 @@
const std::vector<CodeGen::CallArg>& run_args) {
switch (backend_type_) {
case kSimpleIREval:
+ case kCudaCodeGen:
codegen_->call(run_args);
break;
default:
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index bbaf212..f3dcf65 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -52,6 +52,7 @@
enum BackendType {
kUninitialized,
kSimpleIREval,
+ kCudaCodeGen,
};
ExprHandle constant(const torch::jit::Value* v);
@@ -205,6 +206,10 @@
at::Device device_ = at::kCPU;
};
+TORCH_API int& GetTECudaPointwiseLoopLevels();
+TORCH_API int& GetTECudaPointwiseBlockCount();
+TORCH_API int& GetTECudaPointwiseBlockSize();
+
} // namespace tensorexpr
} // namespace jit
} // namespace torch