[TensorExpr] Add CUDA codegen. (#34227)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34227

This PR adds CUDA support to tensor expressions.
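
A minimal sketch of how the new CudaCodeGen is driven from C++, mirroring the tests added in test/cpp/tensorexpr/test_cuda.cpp (the 1024/128 sizes and the a_dev/b_dev/c_dev device pointers are illustrative; the pointers are assumed to come from cudaMalloc/cudaMemcpy as in those tests):

```
// Assumes using namespace torch::jit::tensorexpr (and ::schedule), as in the tests.
KernelScope kernel_scope;
Buffer a_buf("a", kFloat, {1024});
Buffer b_buf("b", kFloat, {1024});
Tensor* c = Compute("c", {{1024, "i"}}, [&](const VarHandle& i) {
  return a_buf(i) + b_buf(i);
});

LoopNest l({c});
std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
Stmt* outer;
Stmt* inner;
l.SplitWithMask(loops[0], 128, &outer, &inner);
l.SetGPUBlockIndex(outer, 0);   // outer loop -> blockIdx.x
l.SetGPUThreadIndex(inner, 0);  // inner loop -> threadIdx.x

// Compiles the statement to CUDA C through NVRTC and launches the kernel on call.
CudaCodeGen cuda_cg(l.root_stmt(), c, a_buf, b_buf);
cuda_cg(c_dev, a_dev, b_dev);
```

On the Python side, traced graphs running on CUDA tensors are now lowered through this backend, and the new cuda_codegen_created / cuda_codegen_executed counters (see test/test_tensorexpr.py) can be used to check that it kicked in.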

Differential Revision: D20251836

Test Plan: Imported from OSS

Pulled By: ZolotukhinM

fbshipit-source-id: ab36a55834cceff30c8371fef6cca1054a32f017
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index e7b8e7a..776ea8d 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -557,6 +557,7 @@
       ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp
       ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp
     )
     add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})
     target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB})
@@ -574,6 +575,7 @@
       ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp
       ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp
     )
     if (USE_NCCL)
       list(APPEND Caffe2_HIP_SRCS
diff --git a/test/cpp/tensorexpr/gtest.cpp b/test/cpp/tensorexpr/gtest.cpp
index 7c660ee..62e2998 100644
--- a/test/cpp/tensorexpr/gtest.cpp
+++ b/test/cpp/tensorexpr/gtest.cpp
@@ -12,5 +12,14 @@
 TH_FORALL_TESTS(TENSOREXPR_GTEST)
 #undef TENSOREXPR_GTEST
 
+#ifdef USE_CUDA
+#define TENSOREXPR_GTEST_CUDA(name)   \
+  TEST(TensorExprTest, name##_CUDA) { \
+    test##name();                     \
+  }
+TH_FORALL_TESTS_CUDA(TENSOREXPR_GTEST_CUDA)
+#undef TENSOREXPR_GTEST_CUDA
+#endif
+
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp
new file mode 100644
index 0000000..7854d02
--- /dev/null
+++ b/test/cpp/tensorexpr/test_cuda.cpp
@@ -0,0 +1,333 @@
+#ifdef USE_CUDA
+
+#include <sstream>
+#include <stdexcept>
+#include "test/cpp/tensorexpr/test_base.h"
+
+#include <cmath>
+
+#include "test/cpp/tensorexpr/padded_buffer.h"
+#include "torch/csrc/jit/tensorexpr/buffer.h"
+#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
+#include "torch/csrc/jit/tensorexpr/schedule.h"
+#include "torch/csrc/jit/tensorexpr/tensor.h"
+
+#include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/util/Half.h>
+
+namespace torch {
+namespace jit {
+using namespace torch::jit::tensorexpr;
+using namespace torch::jit::tensorexpr::schedule;
+
+template <typename ctype>
+void testCudaTestVectorAdd01_impl() {
+  KernelScope kernel_scope;
+  const int num_iter = 3;
+  const int block_count = 16;
+  const int block_size = 128;
+  Dtype dtype = ToDtype<ctype>();
+  Buffer a_buf("a", dtype, {num_iter, block_count, block_size});
+  Buffer b_buf("b", dtype, {num_iter, block_count, block_size});
+  Tensor* c = Compute(
+      "c",
+      {
+          {num_iter, "n"},
+          {block_count, "b_id"},
+          {block_size, "t_id"},
+      },
+      [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
+        return a_buf(n, b_id, t_id) + b_buf(n, b_id, t_id);
+      });
+  LoopNest l({c});
+  std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
+  l.SetGPUBlockIndex(loops[1], 0);
+  l.SetGPUThreadIndex(loops[2], 0);
+  Stmt* stmt = l.root_stmt();
+  CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
+  const int N = block_count * block_size * num_iter;
+  PaddedBuffer<ctype> a_v(N);
+  PaddedBuffer<ctype> b_v(N);
+  PaddedBuffer<ctype> c_v(N);
+  PaddedBuffer<ctype> c_ref(N);
+
+  for (int i = 0; i < N; i++) {
+    a_v(i) = ctype(i);
+    b_v(i) = ctype(i * 3 + 7);
+    c_ref(i) = a_v(i) + b_v(i);
+  }
+
+  // TODO: move gpu support into PaddedBuffer
+  ctype* a_dev = nullptr;
+  cudaMalloc(&a_dev, N * sizeof(ctype));
+  ctype* b_dev = nullptr;
+  cudaMalloc(&b_dev, N * sizeof(ctype));
+  ctype* c_dev = nullptr;
+  cudaMalloc(&c_dev, N * sizeof(ctype));
+  cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice);
+  cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice);
+  cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice);
+  cudaDeviceSynchronize();
+
+  cuda_cg(c_dev, a_dev, b_dev);
+
+  cudaDeviceSynchronize();
+  cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+
+  ExpectAllNear(c_v, c_ref, 1e-5);
+
+  cudaFree(a_dev);
+  cudaFree(b_dev);
+  cudaFree(c_dev);
+}
+
+void testCudaTestVectorAdd01() {
+  // floating types.
+  testCudaTestVectorAdd01_impl<float>();
+  testCudaTestVectorAdd01_impl<at::Half>();
+  testCudaTestVectorAdd01_impl<double>();
+
+  // integer types.
+  testCudaTestVectorAdd01_impl<int8_t>();
+  testCudaTestVectorAdd01_impl<uint8_t>();
+  testCudaTestVectorAdd01_impl<int16_t>();
+  testCudaTestVectorAdd01_impl<int32_t>();
+  testCudaTestVectorAdd01_impl<int64_t>();
+}
+
+static void testCudaTestVectorAdd02_impl(int N, int block_size) {
+  KernelScope kernel_scope;
+  Buffer a_buf("a", kFloat, {N});
+  Buffer b_buf("b", kFloat, {N});
+  Tensor* c = Compute(
+      "c",
+      {
+          {N, "N"},
+      },
+      [&](const VarHandle& n) { return a_buf(n) + b_buf(n); });
+  LoopNest l({c});
+  Stmt* n_outer;
+  Stmt* n_inner;
+  std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
+  l.SplitWithMask(loops[0], block_size, &n_outer, &n_inner);
+  l.SetGPUBlockIndex(n_outer, 0);
+  l.SetGPUThreadIndex(n_inner, 0);
+  Stmt* stmt = l.root_stmt();
+  CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
+  PaddedBuffer<float> a_v(N);
+  PaddedBuffer<float> b_v(N);
+  PaddedBuffer<float> c_v(N);
+  PaddedBuffer<float> c_ref(N);
+
+  for (int i = 0; i < N; i++) {
+    a_v(i) = i;
+    b_v(i) = i * 3 + 7;
+    c_ref(i) = a_v(i) + b_v(i);
+  }
+
+  // TODO: move gpu support into PaddedBuffer
+  float* a_dev = nullptr;
+  cudaMalloc(&a_dev, N * sizeof(float));
+  float* b_dev = nullptr;
+  cudaMalloc(&b_dev, N * sizeof(float));
+  float* c_dev = nullptr;
+  cudaMalloc(&c_dev, N * sizeof(float));
+  cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
+  cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
+  cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
+  cudaDeviceSynchronize();
+
+  cuda_cg(c_dev, a_dev, b_dev);
+
+  cudaDeviceSynchronize();
+  cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+
+  ExpectAllNear(c_v, c_ref, 1e-5);
+
+  cudaFree(a_dev);
+  cudaFree(b_dev);
+  cudaFree(c_dev);
+}
+
+void testCudaTestVectorAdd02() {
+  testCudaTestVectorAdd02_impl(1024, 128);
+  testCudaTestVectorAdd02_impl(1030, 128);
+}
+
+void testCudaDynamicShape2D() {
+  KernelScope kernel_scope;
+  auto testWithSize = [](int32_t M, int32_t N) {
+    VarHandle m("m", kInt);
+    VarHandle n("n", kInt);
+    Buffer a(VarHandle("a", kHandle), kFloat, {m, n});
+    Buffer b(VarHandle("b", kHandle), kFloat, {m, n});
+    Tensor* c = Compute(
+        "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) {
+          return a(i, j) + b(i, j);
+        });
+    LoopNest l({c});
+    Stmt* s = l.root_stmt();
+    CudaCodeGen cg(s, {a, b, c, m, n});
+
+    std::vector<float> aData(M * N, 1.0f);
+    std::vector<float> bData(M * N, 2.0f);
+    std::vector<float> cData(M * N, 0.0f);
+    float* aDev = nullptr;
+    float* bDev = nullptr;
+    float* cDev = nullptr;
+    cudaMalloc(&aDev, aData.size() * sizeof(aData[0]));
+    cudaMalloc(&bDev, bData.size() * sizeof(bData[0]));
+    cudaMalloc(&cDev, cData.size() * sizeof(cData[0]));
+    cudaMemcpy(
+        aDev,
+        aData.data(),
+        aData.size() * sizeof(aData[0]),
+        cudaMemcpyHostToDevice);
+    cudaMemcpy(
+        bDev,
+        bData.data(),
+        bData.size() * sizeof(bData[0]),
+        cudaMemcpyHostToDevice);
+    cudaMemcpy(
+        cDev,
+        cData.data(),
+        cData.size() * sizeof(cData[0]),
+        cudaMemcpyHostToDevice);
+    cudaDeviceSynchronize();
+
+    cg.call({aDev, bDev, cDev, M, N});
+    cudaDeviceSynchronize();
+
+    cudaMemcpy(
+        cData.data(),
+        cDev,
+        cData.size() * sizeof(cData[0]),
+        cudaMemcpyDeviceToHost);
+    cudaDeviceSynchronize();
+
+    ExpectAllNear(cData, std::vector<float>(M * N, 3.0f), 1e-7);
+
+    cudaFree(aDev);
+    cudaFree(bDev);
+    cudaFree(cDev);
+  };
+  testWithSize(32, 32);
+  testWithSize(1, 16);
+  testWithSize(27, 13);
+}
+
+void testCudaTestRand01() {
+  KernelScope kernel_scope;
+  const int num_iter = 3;
+  const int block_count = 16;
+  const int block_size = 128;
+  Tensor* c = Compute(
+      "c",
+      {
+          {num_iter, "n"},
+          {block_count, "b_id"},
+          {block_size, "t_id"},
+      },
+      [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) {
+        return Intrinsics::make(IntrinsicsOp::kRand, kFloat);
+      });
+  LoopNest l({c});
+  std::vector<Stmt*> loops = l.getLoopStmtsFor(c);
+  l.SetGPUBlockIndex(loops[1], 0);
+  l.SetGPUThreadIndex(loops[2], 0);
+  Stmt* stmt = l.root_stmt();
+  CudaCodeGen cuda_cg(stmt, c);
+  const int N = block_count * block_size * num_iter;
+  PaddedBuffer<float> c_v(N);
+
+  // TODO: move gpu support into PaddedBuffer
+  float* c_dev = nullptr;
+  cudaMalloc(&c_dev, N * sizeof(float));
+  cudaDeviceSynchronize();
+
+  cuda_cg(c_dev);
+
+  cudaDeviceSynchronize();
+  cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+
+  float sum1 = 0;
+  float sum2 = 0;
+  float sum3 = 0;
+  for (int i = 0; i < N; i++) {
+    float v = c_v.data()[i];
+    sum1 += v;
+    sum2 += v * v;
+    sum3 += v * v * v;
+    EXPECT_TRUE(v >= 0 && v < 1) << "invalid value: " << i << ", " << v;
+  }
+  sum1 /= N;
+  sum2 /= N;
+  sum3 /= N;
+  float sum1_mean = 1.f / 2;
+  float sum2_mean = 1.f / 3;
+  float sum3_mean = 1.f / 4;
+
+  EXPECT_NEAR(sum1, sum1_mean, 2e-2);
+  EXPECT_NEAR(sum2, sum2_mean, 2e-2);
+  EXPECT_NEAR(sum3, sum3_mean, 2e-2);
+  cudaFree(c_dev);
+}
+
+void testCudaDynamicShapeSplit() {
+  KernelScope ks;
+  constexpr int N = 4096;
+  VarHandle n("n", kInt);
+  Buffer a(VarHandle("a", kHandle), kFloat, {n});
+  Tensor* b =
+      Compute("b", {{n, "n"}}, [&](const VarHandle& i) { return a(i) * 2.0f; });
+  LoopNest l({b});
+  Stmt* outer;
+  Stmt* inner;
+  std::vector<Stmt*> loops = l.getLoopStmtsFor(b);
+  l.SplitWithMask(loops[0], 1024, &outer, &inner);
+  l.SetGPUBlockIndex(outer, 0);
+  l.SetGPUThreadIndex(inner, 0);
+  Stmt* s = l.root_stmt();
+  CudaCodeGen cg(s, {a, b, n});
+
+  std::vector<float> aData(N, 1.0f);
+  std::vector<float> bData(N, 1.0f);
+  float* aDev = nullptr;
+  float* bDev = nullptr;
+  cudaMalloc(&aDev, aData.size() * sizeof(aData[0]));
+  cudaMalloc(&bDev, bData.size() * sizeof(bData[0]));
+  cudaMemcpy(
+      aDev,
+      aData.data(),
+      aData.size() * sizeof(aData[0]),
+      cudaMemcpyHostToDevice);
+  cudaMemcpy(
+      bDev,
+      bData.data(),
+      bData.size() * sizeof(aData[0]),
+      cudaMemcpyHostToDevice);
+  cudaDeviceSynchronize();
+
+  cg.call({aDev, bDev, N});
+  cudaDeviceSynchronize();
+
+  cudaMemcpy(
+      bData.data(),
+      bDev,
+      bData.size() * sizeof(aData[0]),
+      cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+
+  ExpectAllNear(bData, std::vector<float>(N, 2.0f), 1e-7);
+
+  cudaFree(aDev);
+  cudaFree(bDev);
+}
+
+} // namespace jit
+} // namespace torch
+
+#endif
diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h
index 17aff26..19642d6 100644
--- a/test/cpp/tensorexpr/tests.h
+++ b/test/cpp/tensorexpr/tests.h
@@ -87,6 +87,11 @@
   _(ATenltInt)
 
 #define TH_FORALL_TESTS_CUDA(_) \
+  _(CudaTestVectorAdd01)        \
+  _(CudaTestVectorAdd02)        \
+  _(CudaDynamicShape2D)         \
+  _(CudaTestRand01)             \
+  _(CudaDynamicShapeSplit)
 
 #define DECLARE_TENSOREXPR_TEST(name) void test##name();
 TH_FORALL_TESTS(DECLARE_TENSOREXPR_TEST)
diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py
index 381d68a..361340d 100644
--- a/test/test_tensorexpr.py
+++ b/test/test_tensorexpr.py
@@ -34,6 +34,16 @@
         return value - self.start_value
 
 
+class CudaCodeGenCreated(ExecutionCounter):
+    def __init__(self):
+        super(CudaCodeGenCreated, self).__init__("cuda_codegen_created")
+
+
+class CudaCodeGenExecuted(ExecutionCounter):
+    def __init__(self):
+        super(CudaCodeGenExecuted, self).__init__("cuda_codegen_executed")
+
+
 class SimpleIREvalExecuted(ExecutionCounter):
     def __init__(self):
         super(SimpleIREvalExecuted, self).__init__("simple_ir_eval_executed")
@@ -80,7 +90,7 @@
             c = torch.addcmul(torch.add(x, y), z, w)
             return c
 
-        device_options = ["cpu"]
+        device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
         for dev in device_options:
             rand_a = torch.rand(1024, dtype=torch.float, device=dev)
             rand_b = torch.rand(1024, dtype=torch.float, device=dev)
@@ -102,6 +112,79 @@
             np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6)
 
 
+    def test_three_arg_cuda(self):
+        if not torch.cuda.is_available():
+            return
+        cuda_cg_executed = CudaCodeGenExecuted()
+        cuda_cg_created = CudaCodeGenCreated()
+
+        def test(x, y, z):
+            aaa = torch.add(x, y)
+            bbb = torch.add(aaa, z)
+            return bbb
+
+        M = 32
+        N = 32
+        traced = torch.jit.trace(
+            test,
+            (
+                torch.rand(M, N, device="cuda"),
+                torch.rand(M, N, device="cuda"),
+                torch.rand(M, N, device="cuda"),
+            ),
+        )
+
+        a = torch.rand(M, N, device="cuda")
+        b = torch.rand(M, N, device="cuda")
+        c = torch.rand(M, N, device="cuda")
+        x = traced(a, b, c)
+        npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
+        np.testing.assert_allclose(npr, x.cpu().numpy())
+        assert cuda_cg_executed.elapsed_value() >= 1
+        assert cuda_cg_created.elapsed_value() >= 1
+
+
+    def test_broadcast_cuda(self):
+        if not torch.cuda.is_available():
+            return
+
+        def test_body(M, N, L, K):
+            if not torch.cuda.is_available():
+                return
+            cuda_cg_executed = CudaCodeGenExecuted()
+            cuda_cg_created = CudaCodeGenCreated()
+
+            def test(x, y, z):
+                v1 = torch.add(x, y)
+                v2 = torch.add(v1, z)
+                return v2
+
+            a_shape = [M, N]
+            b_shape = [L, M, 1]
+            c_shape = [K, L, 1, 1]
+            traced = torch.jit.trace(
+                test,
+                (
+                    torch.rand(*a_shape, device="cuda"),
+                    torch.rand(*b_shape, device="cuda"),
+                    torch.rand(*c_shape, device="cuda"),
+                ),
+            )
+
+            a = torch.rand(*a_shape, device="cuda")
+            b = torch.rand(*b_shape, device="cuda")
+            c = torch.rand(*c_shape, device="cuda")
+            x = traced(a, b, c)
+            npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
+            np.testing.assert_allclose(npr, x.cpu().numpy())
+            assert cuda_cg_executed.elapsed_value() >= 1
+            assert cuda_cg_created.elapsed_value() >= 1
+
+        test_configs = [[36, 17, 63, 33], [32, 32, 32, 32]]
+        for test_config in test_configs:
+            test_body(*test_config)
+
+
     def test_all_combos(self):
         def easy(x, y, z):
             a = torch.add(x, y)
@@ -426,7 +509,7 @@
             c = torch.lt(x, y)
             return c
 
-        device_options = ["cpu"]
+        device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
         for dev in device_options:
             traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev)))
             a = torch.ones(1024, dtype=torch.int32, device=dev)
@@ -451,7 +534,7 @@
         def test(x):
             return torch.clamp(x + 3.0, 0.0, 6.0)
 
-        device_options = ["cpu"]
+        device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
 
         for dev in device_options:
             traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
@@ -463,7 +546,7 @@
         def test(x):
             return torch.clamp(F.relu(x), 0, 0.5)
 
-        device_options = ["cpu"]
+        device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
         for dev in device_options:
             traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
             a = 20.0 * torch.rand(1024, device=dev) - 10.0
@@ -598,7 +681,7 @@
             # test_tanh_backward,
             test_type_as,
         }
-        device_options = ["cpu"]
+        device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
         for torch_fn in fns:
             for dev in device_options:
                 rand_a = torch.rand(1024, device=dev)
@@ -776,7 +859,7 @@
             test_neg,
             test_relu,
         }
-        device_options = ["cpu"]
+        device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
 
         for torch_fn in fns:
             for dev in device_options:
@@ -797,6 +880,26 @@
                 np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
 
 
+    def test_rand_like(self):
+        devices = ["cuda"] if torch.cuda.is_available() else []
+        N = 1 << 16
+
+        def run_rand_like(x, y):
+            return torch.rand_like(torch.add(x, y))
+
+        for device in devices:
+            x = torch.rand(N, device=device)
+            traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)
+            x_v = traced(x, x)
+            x_np = x.cpu().numpy()
+            x1_mean = np.mean(x_np)
+            x2_mean = np.mean(x_np ** 2)
+            x3_mean = np.mean(x_np ** 3)
+            np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2)
+            np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2)
+            np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2)
+
+
     def test_nans(self):
         def test_max(x, y):
             return torch.max(2 * x, 2 * y)
@@ -898,6 +1001,10 @@
     def test_cat_cpu(self):
         self._test_cat('cpu')
 
+    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
+    def test_cat_cuda(self):
+        self._test_cat('cuda')
+
     def test_scalar(self):
         @torch.jit.script
         def test_float(x, y, z, a, b):
@@ -1001,8 +1108,66 @@
         assert interp.elapsed_value() == 1
 
 
+    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
+    @unittest.skip("dynamic shapes are not quite there yet")
+    def test_dynamic_shape(self):
+        with num_profiled_runs(2):
+            @torch.jit.script
+            def test(x, y, z):
+                return x * y * z
+            cuda = CudaCodeGenCreated()
+            x, y, z = [torch.rand(4, 8).cuda() for _ in range(3)]
+            ref = test(x, y, z)
+            _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
+            res = test(x, y, z)
+            np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())
+            assert cuda.elapsed_value() == 1
+
+            # A wild broadcast appears.
+            x = torch.rand(4, 8).cuda()
+            y = torch.rand(1, 8).cuda()
+            z = torch.rand(4, 1).cuda()
+            res = test(x, y, z)
+            xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
+            np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
+            assert cuda.elapsed_value() == 1
+
+            # Mismatched shapes shouldn't reach codegen.
+            x = torch.rand(4, 8).cuda()
+            y = torch.rand(4, 8).cuda()
+            z = torch.rand(5, 8).cuda()
+            try:
+                res = test(x, y, z)
+            except RuntimeError as e:
+                assert "The size of tensor a (4) must match" in e.args[0]
+            assert cuda.elapsed_value() == 1
+
+            # Changing a static dimension fails guards.
+            # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
+            # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
+            # res = test(x, y, z)
+            # print(test.graph_for(x, y, z))
+            # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
+            # assert cuda.elapsed_value() == 1
+
+    @unittest.skip("guarding on static shapes is not working")
+    def test_guard_fails(self):
+        @torch.jit.script
+        def test(x, y, z):
+            return x * y * z
+        cuda = CudaCodeGenExecuted()
+        r1 = test(*[torch.rand(4).cuda() for _ in range(3)])
+        assert cuda.elapsed_value() == 0
+        r2 = test(*[torch.rand(4).cuda() for _ in range(3)])
+        assert cuda.elapsed_value() == 1
+        r3 = test(*[torch.rand(4).cuda() for _ in range(3)])
+        assert cuda.elapsed_value() == 2
+        r4 = test(*[torch.rand(7).cuda() for _ in range(3)])
+        print(test.graph_for(*[torch.rand(7).cuda() for _ in range(3)]))
+        assert cuda.elapsed_value() == 2
+
     def test_bitwise_ops(self):
-        devices = ["cpu"]
+        devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]
 
         def run_and(x, y):
             return x & (x & y)
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 11ecc78..d548a35 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -219,6 +219,7 @@
     "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
     "torch/csrc/autograd/profiler_cuda.cpp",
     "torch/csrc/autograd/functions/comm.cpp",
+    "torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
 ]
 
 torch_cpp_srcs = [
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 02a652d..722c2c7 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -60,6 +60,7 @@
 #include <torch/csrc/jit/python/python_tree_views.h>
 #include <torch/csrc/jit/frontend/tracer.h>
 #include <torch/csrc/jit/tensorexpr/execution_counter.h>
+#include <torch/csrc/jit/tensorexpr/kernel.h>
 
 #include <c10/macros/Export.h>
 #include <caffe2/serialize/inline_container.h>
@@ -414,6 +415,42 @@
                 ExecutionTriggerList::GetInstance().FindByName(trigger_name);
             return trigger->value();
           })
+      .def(
+          "_jit_get_te_cuda_pointwise_loop_levels",
+          []() -> int {
+            using namespace torch::jit::tensorexpr;
+            return GetTECudaPointwiseLoopLevels();
+          })
+      .def(
+          "_jit_set_te_cuda_pointwise_loop_levels",
+          [](int level) {
+            using namespace torch::jit::tensorexpr;
+            return GetTECudaPointwiseLoopLevels() = level;
+          })
+      .def(
+          "_jit_get_te_cuda_pointwise_block_count",
+          []() -> int {
+            using namespace torch::jit::tensorexpr;
+            return GetTECudaPointwiseBlockCount();
+          })
+      .def(
+          "_jit_set_te_cuda_pointwise_block_count",
+          [](int block_count) {
+            using namespace torch::jit::tensorexpr;
+            return GetTECudaPointwiseBlockCount() = block_count;
+          })
+      .def(
+          "_jit_get_te_cuda_pointwise_block_size",
+          []() -> int {
+            using namespace torch::jit::tensorexpr;
+            return GetTECudaPointwiseBlockSize();
+          })
+      .def(
+          "_jit_set_te_cuda_pointwise_block_size",
+          [](int block_size) {
+            using namespace torch::jit::tensorexpr;
+            return GetTECudaPointwiseBlockSize() = block_size;
+          })
       .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
       .def(
           "_jit_fuser_get_fused_kernel_code",
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
new file mode 100644
index 0000000..c3f1332
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -0,0 +1,695 @@
+#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
+#include "torch/csrc/jit/tensorexpr/cuda_half_support.h"
+
+#include "ATen/CUDAGenerator.h"
+#include "c10/cuda/CUDAFunctions.h"
+#include "torch/csrc/jit/tensorexpr/cuda_random.h"
+#include "torch/csrc/jit/tensorexpr/eval.h"
+#include "torch/csrc/jit/tensorexpr/execution_counter.h"
+
+#define DEBUG_PRINT 0
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+DEFINE_TRIGGER(cuda_codegen_created);
+DEFINE_TRIGGER(cuda_codegen_executed);
+
+// A RAII wrapper to manage a variable and name pair in the look-up table.
+// TODO: move this to a more shared place.
+class ScopedVarName {
+ public:
+  ScopedVarName(VarNameMap* mapping, const Var* var, const std::string& name)
+      : mapping_(mapping), var_(var) {
+    auto iter = mapping->find(var);
+    if (iter != mapping->end()) {
+      throw std::runtime_error("Duplicate var entry: " + var->name_hint());
+    }
+    mapping->insert(std::make_pair(var, name));
+  }
+
+  ScopedVarName(
+      UniqueNameManager* manager,
+      const Var* var,
+      const std::string& name)
+      : ScopedVarName(&manager->unique_name_mapping_, var, name) {}
+
+  ScopedVarName(const ScopedVarName&) = delete;
+  ScopedVarName& operator=(const ScopedVarName&) = delete;
+
+  ~ScopedVarName() noexcept(false) {
+    mapping_->erase(var_);
+  }
+
+ private:
+  VarNameMap* mapping_ = nullptr;
+  const Var* var_ = nullptr;
+};
+
+static int as_int(const Expr* expr) {
+  auto v = dynamic_cast<const IntImm*>(expr);
+  TORCH_CHECK(v, "Expression is not an integer constant");
+  return v->value();
+}
+
+static bool is_zero(const Expr* expr) {
+  return as_int(expr) == 0;
+}
+
+static const at::cuda::NVRTC& nvrtc() {
+  return at::globalContext().getNVRTC();
+}
+
+static void getMajorMinor(
+    const cudaDeviceProp* const prop,
+    int& major,
+    int& minor) {
+  using CudaVersion = std::pair<int, int>;
+  CudaVersion nvrtc_version;
+  AT_CUDA_NVRTC_CHECK(
+      nvrtc().nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second));
+
+  AT_ASSERT(nvrtc_version.first >= 6);
+
+  CudaVersion dev_version = CudaVersion(prop->major, prop->minor);
+  CudaVersion max_dev_version(dev_version);
+  if (nvrtc_version.first <= 7) { // 7 supports 2-5.x
+    max_dev_version = CudaVersion(5, 0);
+  } else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x
+    max_dev_version = CudaVersion(6, 0);
+  } else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2
+    max_dev_version = CudaVersion(7, 2);
+  } else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5
+    max_dev_version = CudaVersion(7, 5);
+  }
+  if (dev_version > max_dev_version) {
+    dev_version = max_dev_version;
+  }
+  major = dev_version.first;
+  minor = dev_version.second;
+}
+
+void CudaPrinter::visit(const For* v) {
+  const LoopOptions& loop_options = v->loop_options();
+  if (loop_options.is_gpu_block_index()) {
+    ScopedVarName var_name(
+        name_manager(), v->var(), loop_options.gpu_block_index_str());
+    v->body()->accept(this);
+    int gpu_block_index = loop_options.gpu_block_index();
+    if (gpu_block_extents_.size() <= gpu_block_index) {
+      gpu_block_extents_.resize(gpu_block_index + 1);
+    }
+    if (!is_zero(v->start())) {
+      throw std::runtime_error(
+          "start must be zero for gpu_block_index: " +
+          std::to_string(ExprHandle(v->start())));
+    }
+    gpu_block_extents_[gpu_block_index] = v->stop();
+  } else if (loop_options.is_gpu_thread_index()) {
+    ScopedVarName var_name(
+        name_manager(), v->var(), loop_options.gpu_thread_index_str());
+    v->body()->accept(this);
+    int gpu_thread_index = loop_options.gpu_thread_index();
+    if (gpu_thread_extents_.size() <= gpu_thread_index) {
+      gpu_thread_extents_.resize(gpu_thread_index + 1);
+    }
+    if (!is_zero(v->start())) {
+      throw std::runtime_error(
+          "start must be zero for gpu_block_index: " +
+          std::to_string(ExprHandle(v->start())));
+    }
+    gpu_thread_extents_[gpu_thread_index] = v->stop();
+  } else {
+    IRPrinter::visit(v);
+  }
+}
+
+void CudaPrinter::visit(const Intrinsics* v) {
+  if (v->op_type() == IntrinsicsOp::kRand) {
+    os() << "Uint32ToFloat(" << *rand_func_ << "())";
+    return;
+  }
+
+  std::string func_name = v->func_name();
+
+  // get type of resulting expression.
+  ScalarType returnType = v->param(0)->dtype().scalar_type();
+  for (int i = 1; i < v->nparams(); ++i) {
+    returnType = promoteTypes(returnType, v->param(i)->dtype().scalar_type());
+  }
+
+  if (returnType == ScalarType::Half || returnType == ScalarType::Float) {
+    func_name = func_name + "f";
+  }
+
+  os() << func_name << "(";
+  for (int i = 0; i < v->nparams(); i++) {
+    if (i > 0) {
+      os() << ", ";
+    }
+    os() << *v->param(i);
+  }
+  os() << ")";
+}
+
+void CudaPrinter::visit(const Load* v) {
+  // TODO: find a better heuristic for whether to use __ldg. Support different dtypes.
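+  // For example, a float load "a[i]" is emitted as "__ldg(a + i)", while a
+  // Half load is emitted as "__half2float(a[i])" so the arithmetic is done in
+  // float.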
+  if (v->dtype().scalar_type() == ScalarType::Half) {
+    os() << "__half2float(" << *v->base_handle() << "[" << *v->index() << "])";
+  } else {
+    os() << "__ldg(" << *v->base_handle() << " + " << *v->index() << ")";
+  }
+}
+
+void CudaPrinter::visit(const Store* v) {
+  os() << *v->base_handle() << "[" << *v->index() << "] = ";
+  if (v->value()->dtype().scalar_type() == ScalarType::Half) {
+    os() << "__float2half(" << *v->value() << ");";
+  } else {
+    os() << *v->value() << ";";
+  }
+}
+
+void CudaPrinter::visit(const Max* v) {
+  auto dtype = v->dtype().scalar_type();
+  switch (dtype) {
+    case ScalarType::Half:
+      // doing Half math in float.
+    case ScalarType::Float:
+      os() << "fmaxf";
+      break;
+    case ScalarType::Double:
+      os() << "fmax";
+      break;
+    default:
+      os() << "max";
+      break;
+  }
+  os() << "(";
+  v->lhs()->accept(this);
+  os() << ",";
+  v->rhs()->accept(this);
+  os() << ")";
+}
+
+void CudaPrinter::visit(const Min* v) {
+  auto dtype = v->dtype().scalar_type();
+  switch (dtype) {
+    case ScalarType::Half:
+      // doing Half math in float.
+    case ScalarType::Float:
+      os() << "fminf";
+      break;
+    case ScalarType::Double:
+      os() << "fmin";
+      break;
+    default:
+      os() << "min";
+      break;
+  }
+  os() << "(";
+  v->lhs()->accept(this);
+  os() << ",";
+  v->rhs()->accept(this);
+  os() << ")";
+}
+
+std::string cudaDtypeCppString(const Dtype& dtype) {
+  switch (dtype.scalar_type()) {
+    case ScalarType::Half:
+      return "half";
+    case ScalarType::Char:
+      return "char";
+    case ScalarType::Byte:
+      return "unsigned char";
+    case ScalarType::Short:
+      return "short";
+    case ScalarType::Long:
+      return "long";
+    default:; /* nothing */
+  }
+  return dtype.ToCppString();
+}
+
+void CudaPrinter::visit(const LetStmt* v) {
+  const Var* var = v->var();
+  if (var->dtype().scalar_type() == ScalarType::Half) {
+    // we do math in floats so use that.
+    os() << "float";
+  } else {
+    os() << cudaDtypeCppString(var->dtype());
+  }
+  os() << " " << *var << " = " << *v->value() << "; " << std::endl;
+  v->body()->accept(this);
+}
+
+void CudaPrinter::visit(const IfThenElse* v) {
+  os() << "((";
+  v->condition()->accept(this);
+  os() << ") ? ";
+  v->true_value()->accept(this);
+  os() << " : ";
+  v->false_value()->accept(this);
+  os() << ")";
+}
+
+class PrioritizeLoad : public IRMutator {
+ public:
+  const Expr* mutate(const Load* v) override {
+    // See the comment on nested_if_then_else_ below for details.
+    if (nested_if_then_else_ > 0) {
+      return IRMutator::mutate(v);
+    }
+    MemLoadList& load_list = load_stack_.back();
+    const Var* load_new_var = new Var("v", v->dtype());
+    const Expr* new_value = IRMutator::mutate(v);
+    load_list.push_back(std::make_pair(load_new_var, new_value));
+    return load_new_var;
+  }
+
+  // TODO: merge this with the IRMutator::mutate version.
+  Stmt* mutate(const For* v) override {
+    const Var* var = v->var();
+    const Expr* start = v->start();
+    const Expr* stop = v->stop();
+    Stmt* body = v->body();
+    LoopOptions loop_options = v->loop_options();
+    const Var* var_new = dynamic_cast<const Var*>(var->accept_mutator(this));
+    const Expr* start_new = start->accept_mutator(this);
+    const Expr* stop_new = stop->accept_mutator(this);
+    PushList();
+    Stmt* body_new = body->accept_mutator(this);
+    if (!body_new) {
+      return nullptr;
+    }
+    Stmt* body_with_loads = AddMemLoadsFromList(body_new);
+    PopList();
+    if (var == var_new && start == start_new && stop == stop_new &&
+        body == body_with_loads) {
+      return (Stmt*)v;
+    }
+    return new For(var_new, start_new, stop_new, body_with_loads, loop_options);
+  }
+
+  Stmt* mutate(const LetStmt* v) override {
+    const Var* var = v->var();
+    const Expr* value = v->value();
+    Stmt* body = v->body();
+    const Var* var_new = dynamic_cast<const Var*>(var->accept_mutator(this));
+    if (var_new == nullptr) {
+      throw std::runtime_error("LetStmt var must be variable");
+    }
+    const Expr* value_new = value->accept_mutator(this);
+    PushList();
+    Stmt* body_new = body->accept_mutator(this);
+    Stmt* body_with_loads = AddMemLoadsFromList(body_new);
+    PopList();
+    if (var == var_new && value == value_new && body == body_with_loads) {
+      return (Stmt*)v;
+    }
+    return new LetStmt(var_new, value_new, body_with_loads);
+  }
+
+  Stmt* mutate(const Cond* v) override {
+    const Expr* cond_old = v->condition();
+    Stmt* true_old = v->true_stmt();
+    Stmt* false_old = v->false_stmt();
+
+    const Expr* cond_new = cond_old->accept_mutator(this);
+    PushList();
+    Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old;
+    Stmt* true_with_loads = AddMemLoadsFromList(true_new);
+    PopList();
+    PushList();
+    Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old;
+    Stmt* false_with_loads = AddMemLoadsFromList(false_new);
+    PopList();
+
+    if (cond_old == cond_new && true_old == true_with_loads &&
+        false_old == false_with_loads) {
+      return (Stmt*)v;
+    }
+    return new Cond(cond_new, true_with_loads, false_with_loads);
+  }
+
+  const Expr* mutate(const IfThenElse* v) override {
+    nested_if_then_else_++;
+    const Expr* new_v = IRMutator::mutate(v);
+    nested_if_then_else_--;
+    return new_v;
+  }
+
+  Stmt* Process(Stmt* stmt) {
+    this->PushList();
+    Stmt* stmt_v = stmt;
+    Stmt* stmt_new = stmt_v->accept_mutator(this);
+    Stmt* stmt_with_loads = AddMemLoadsFromList(stmt_new);
+    this->PopList();
+    return stmt_with_loads;
+  }
+
+ private:
+  using MemLoadEntry = std::pair<const Var*, const Expr*>;
+  using MemLoadList = std::vector<MemLoadEntry>;
+  using MemoryLoadStack = std::vector<MemLoadList>;
+
+  void PushList() {
+    load_stack_.push_back(MemLoadList());
+  }
+
+  void PopList() {
+    load_stack_.pop_back();
+  }
+
+  Stmt* AddMemLoadsFromList(Stmt* stmt) {
+    MemLoadList& load_list = load_stack_.back();
+    Stmt* stmt_v = stmt;
+    for (auto iter = load_list.rbegin(); iter != load_list.rend(); iter++) {
+      const MemLoadEntry& entry = *iter;
+      const Var* var_ptr = entry.first;
+      stmt_v = new LetStmt(var_ptr, entry.second, stmt_v);
+    }
+    return stmt_v;
+  }
+
+  MemoryLoadStack load_stack_;
+  // TODO: For now, we do not hoist loads out of IfThenElse expressions.
+  // Eventually, we should switch to a more generic structure like:
+  // int v2 = IfThenElse(cond, true_v, false_v) + 2 ->
+  //
+  // int v;
+  // if (cond) {
+  //   v = true_v;
+  // } else {
+  //   v = false_v;
+  // }
+  // int v2 = v + 2;
+  int nested_if_then_else_ = 0;
+};
+
+class HasRand : public IRVisitor {
+ public:
+  HasRand(Stmt* stmt) : stmt_(stmt) {
+    stmt_->accept(this);
+  }
+
+  bool has_rand() const {
+    return has_rand_;
+  }
+
+ private:
+  void visit(const Intrinsics* v) override {
+    if (v->op_type() == IntrinsicsOp::kRand) {
+      has_rand_ = true;
+    } else {
+      IRVisitor::visit(v);
+    }
+  }
+  Stmt* stmt_;
+  bool has_rand_ = false;
+};
+
+std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
+  // We use a global counter here to make sure different instances of
+  // CudaCodeGen get different kernel names.
+  static int64_t counter = 0;
+  ++counter;
+  int64_t value = counter;
+  return func_prefix + "_" + std::to_string(value);
+}
+
+void CudaCodeGen::Initialize() {
+  // TODO: handle multiple kernels.
+  // TODO: handle dynamic dimension.
+  // TODO: call nvrtc.
+  HasRand has_rand_func(stmt());
+  has_random_ = has_rand_func.has_rand();
+  printer_ = std::make_unique<CudaPrinter>(&oss_, has_random_);
+
+  os() << "#define NAN __int_as_float(0x7fffffff)\n"
+          "#define POS_INFINITY __int_as_float(0x7f800000)\n"
+          "#define NEG_INFINITY __int_as_float(0xff800000)\n";
+  if (has_random_) {
+    os() << philox_random_string << std::endl;
+  }
+
+  // Check whether the statement uses the Half type; if so, add the
+  // half_support_literal.
+  CudaHalfChecker halfChecker;
+  stmt()->accept(&halfChecker);
+  if (halfChecker.hasHalf()) {
+    os() << fuser::cuda::half_support_literal << std::endl;
+  }
+
+  std::string func_name = GetUniqueFuncName("func");
+  os() << "extern \"C\" __global__" << std::endl << "void " << func_name << "(";
+  const std::vector<BufferArg> buffer_args = this->buffer_args();
+  for (size_t i = 0; i < buffer_args.size(); i++) {
+    if (i > 0) {
+      os() << ", ";
+    }
+    const BufferArg& buffer_arg = buffer_args[i];
+    const Var* var = buffer_arg.var();
+    Dtype dtype = buffer_arg.dtype();
+
+    os() << cudaDtypeCppString(dtype) << (buffer_arg.isVar() ? " " : "* ")
+         << name_manager()->get_unique_name(var);
+  }
+  const Var* rand_seed;
+  const Var* rand_offset;
+  if (has_random_) {
+    // TODO: switch to kUint64 when it is available.
+    rand_seed = new Var("rand_seed", kInt);
+    rand_offset = new Var("rand_offset", kInt);
+    std::string uint64_str = "unsigned long long";
+    os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " "
+         << *rand_offset;
+  }
+  os() << ") {";
+  os() << std::endl;
+
+  if (has_random_) {
+    const Var* idx = new Var("idx", kInt);
+    os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;"
+         << std::endl;
+    const Var* rand_func = printer_->rand_func();
+    os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
+         << *rand_offset << ");" << std::endl;
+    os() << std::endl;
+  }
+
+  Stmt* stmt_v = stmt();
+  PrioritizeLoad prioritize_load;
+  stmt_v = prioritize_load.Process(stmt_v);
+  stmt_v->accept(printer_.get());
+  os() << std::endl;
+  os() << "}";
+
+  // Check that all block extents have been set.
+  const std::vector<const Expr*>& gpu_block_extents =
+      printer_->gpu_block_extents();
+  const std::vector<const Expr*>& gpu_thread_extents =
+      printer_->gpu_thread_extents();
+  for (size_t i = 0; i < gpu_block_extents.size(); i++) {
+    if (!gpu_block_extents[i]) {
+      throw std::runtime_error("Missing gpu_block_index: " + std::to_string(i));
+    }
+  }
+
+#if DEBUG_PRINT
+  std::cout << "stmt: " << std::endl;
+  std::cout << oss_.str() << std::endl;
+  std::cout << "block(";
+  for (size_t i = 0; i < gpu_block_extents.size(); i++) {
+    if (i > 0) {
+      std::cout << ", ";
+    }
+    std::cout << *gpu_block_extents[i];
+  }
+  std::cout << "), thread(";
+  for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
+    if (i > 0) {
+      std::cout << ", ";
+    }
+    std::cout << *gpu_thread_extents[i];
+  }
+  std::cout << ")" << std::endl;
+#endif
+
+  CompileToNVRTC(oss_.str(), func_name);
+  USE_TRIGGER(cuda_codegen_created);
+}
+
+void CudaCodeGen::call(const std::vector<CallArg>& args) {
+  CHECK_EQ(args.size(), buffer_args().size());
+
+  // TODO: move as much of this into the constructors.
+  const std::vector<const Expr*>& gpu_block_extents =
+      printer_->gpu_block_extents();
+  const std::vector<const Expr*>& gpu_thread_extents =
+      printer_->gpu_thread_extents();
+  CHECK(gpu_block_extents.size() <= 3);
+  CHECK(gpu_thread_extents.size() <= 3);
+  std::vector<int> gpu_block_extents_v(3, 1);
+  std::vector<int> gpu_thread_extents_v(3, 1);
+  // evaluate all the block/thread extents into values
+  // TODO: eventually, codegen these calculations and make them part of the
+  // module.
+  for (size_t i = 0; i < gpu_block_extents.size(); i++) {
+    ExprEval<SimpleIREvaluator> eval(
+        ExprHandle(gpu_block_extents[i]), buffer_args());
+    gpu_block_extents_v[i] = eval.value<int>(args);
+  }
+  for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
+    ExprEval<SimpleIREvaluator> eval(
+        ExprHandle(gpu_thread_extents[i]), buffer_args());
+    gpu_thread_extents_v[i] = eval.value<int>(args);
+  }
+
+  // Skip launching the kernel if there are no elements to process.
+  for (int extent : gpu_block_extents_v) {
+    if (extent == 0) {
+      return;
+    }
+  }
+
+  // Bind the buffer addresses into arguments
+  auto const& buffer_args = this->buffer_args();
+  int ptr_count = buffer_args.size();
+  if (has_random_) {
+    ptr_count += 2;
+  }
+  std::vector<void*> args_data(buffer_args.size());
+  std::vector<void*> ptr_to_args(ptr_count);
+  uint64_t rand_seed = uint64_t(-1);
+  uint64_t rand_offset = uint64_t(-1);
+  for (size_t i = 0; i < buffer_args.size(); i++) {
+    auto const& bufferArg = buffer_args[i];
+    if (bufferArg.isVar()) {
+      auto stype = bufferArg.dtype().scalar_type();
+      switch (stype) {
+#define TYPE_CASE(Type, Name)             \
+  case ScalarType::Name:                  \
+    ptr_to_args[i] = args[i].Name##Ptr(); \
+    break;
+        AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
+#undef TYPE_CASE
+        default:
+          LOG(FATAL) << "Unhandled dtype in argument";
+      }
+    } else {
+      args_data[i] = args[i].data();
+      ptr_to_args[i] = &args_data[i];
+    }
+  }
+
+  if (has_random_) {
+    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
+    // TODO: total hack. Switch to numel when it is available.
+    int64_t total_elements_per_thread = (1LL << 28);
+    {
+      std::lock_guard<std::mutex> lock(gen->mutex_);
+      auto philox_engine_inputs =
+          gen->philox_engine_inputs(total_elements_per_thread);
+      rand_seed = philox_engine_inputs.first;
+      rand_offset = philox_engine_inputs.second;
+    }
+    ptr_to_args[buffer_args.size()] = &rand_seed;
+    ptr_to_args[buffer_args.size() + 1] = &rand_offset;
+  }
+
+  // Launch the kernels
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
+      function_,
+      gpu_block_extents_v[0],
+      gpu_block_extents_v[1],
+      gpu_block_extents_v[2],
+      gpu_thread_extents_v[0],
+      gpu_thread_extents_v[1],
+      gpu_thread_extents_v[2],
+      0,
+      stream,
+      ptr_to_args.data(),
+      nullptr));
+  USE_TRIGGER(cuda_codegen_executed);
+}
+
+void CudaCodeGen::CompileToNVRTC(
+    const std::string& code,
+    const std::string& func_name) {
+  // Initializes driver's API context (if necessary)
+  CUdevice device = 0;
+  CUcontext pctx = 0;
+  AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
+  if (!pctx) {
+    std::unique_lock<std::mutex> cudaFreeMutexLock(
+        *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
+    cudaFree(0);
+  }
+
+  // Note: we set the device manually instead of using at::DeviceGuard, since
+  // at::DeviceGuard was failing to work properly in some scenarios.
+  const auto prior_device = at::cuda::current_device();
+  at::cuda::set_device(device);
+
+  // Acquires device and NVRTC properties (for compile arch and occupancy
+  // calculations)
+  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+  int major, minor;
+  getMajorMinor(prop, major, minor);
+
+#if DEBUG_PRINT
+  std::cout << "major: " << major << ", "
+            << "minor: " << minor << std::endl;
+#endif
+
+  // Creates the NVRTC program
+  nvrtcProgram program;
+  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
+      &program, code.c_str(), nullptr, 0, nullptr, nullptr));
+
+#ifdef __HIP_PLATFORM_HCC__
+  std::vector<const char*> args = {};
+#else
+  const std::string compute = "--gpu-architecture=compute_" +
+      std::to_string(major) + std::to_string(minor);
+  const std::vector<const char*> args = {
+      "--std=c++14", compute.c_str(), "-default-device"};
+#endif
+
+  const auto result =
+      nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
+  if (result != NVRTC_SUCCESS) {
+    size_t logsize;
+    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
+    std::vector<char> log(logsize);
+    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
+    std::stringstream cu;
+    cu << log.data() << std::endl;
+    cu << "nvrtc compilation failed: " << std::endl;
+    cu << code << std::endl;
+    throw std::runtime_error(cu.str());
+  }
+  ResourceGuard holdProgram(
+      [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
+  AT_CUDA_NVRTC_CHECK(result);
+  size_t ptx_size;
+  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size));
+  std::vector<char> ptx;
+  ptx.resize(ptx_size);
+  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx.data()));
+
+  CUmodule module;
+  AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data()));
+  AT_CUDA_DRIVER_CHECK(
+      nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str()));
+}
+
+RegisterCodeGen<CudaCodeGen> cuda_codegen_reg("cuda_codegen");
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.h b/torch/csrc/jit/tensorexpr/cuda_codegen.h
new file mode 100644
index 0000000..7afa9a0
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.h
@@ -0,0 +1,123 @@
+#pragma once
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "ATen/ATen.h"
+#include "ATen/cuda/CUDAContext.h"
+#include "ATen/cuda/nvrtc_stub/ATenNVRTC.h"
+#include "c10/cuda/CUDACachingAllocator.h"
+#include "c10/cuda/CUDAGuard.h"
+#include "torch/csrc/jit/resource_guard.h"
+#include "torch/csrc/jit/tensorexpr/codegen.h"
+#include "torch/csrc/jit/tensorexpr/ir.h"
+#include "torch/csrc/jit/tensorexpr/ir_printer.h"
+#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
+#include "torch/csrc/jit/tensorexpr/unique_name_manager.h"
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+// A class that overrides the underlying IRPrinter to produce Cuda C.
+class CudaPrinter : public IRPrinter {
+ public:
+  explicit CudaPrinter(std::ostream* os, bool has_random) : IRPrinter(*os) {
+    if (has_random) {
+      rand_func_ = new Var("rand", kHandle);
+    }
+  }
+
+  void visit(const Cast* v) override {
+    auto dtype = v->dtype();
+    if (dtype == kHalf) {
+      os() << "half";
+    } else {
+      os() << dtype;
+    }
+    os() << "(";
+    v->src_value()->accept(this);
+    os() << ")";
+  }
+
+  void visit(const Intrinsics* v) override;
+  void visit(const For* v) override;
+
+  void visit(const Load* v) override;
+  void visit(const Store* v) override;
+  void visit(const Max* v) override;
+  void visit(const Min* v) override;
+  void visit(const LetStmt* v) override;
+  void visit(const IfThenElse* v) override;
+
+  const std::vector<const Expr*>& gpu_block_extents() const {
+    return gpu_block_extents_;
+  }
+
+  const std::vector<const Expr*>& gpu_thread_extents() const {
+    return gpu_thread_extents_;
+  }
+
+  const Var* rand_func() const {
+    return rand_func_;
+  }
+
+  using IRPrinter::name_manager;
+  using IRPrinter::visit;
+
+ private:
+  std::vector<const Expr*> gpu_block_extents_;
+  std::vector<const Expr*> gpu_thread_extents_;
+  const Var* rand_func_;
+};
+
+// Construct Cuda C from the buffer and tensor input, and invoke the kernel
+// when real arguments are provided.
+class TORCH_CUDA_API CudaCodeGen : public CodeGen {
+ public:
+  template <typename... Ts>
+  CudaCodeGen(Stmt* stmt, Ts... ts) : CodeGen(stmt, std::forward<Ts>(ts)...) {
+    Initialize();
+  }
+
+  CudaCodeGen(Stmt* stmt, const std::vector<BufferArg>& buffer_args)
+      : CodeGen(stmt, buffer_args) {
+    Initialize();
+  }
+
+  ~CudaCodeGen() override {}
+
+  void call(const std::vector<CallArg>& args) override;
+
+  template <typename... Ts>
+  void operator()(const Ts&... ts) {
+    call(std::vector<CallArg>({CallArg(ts)...}));
+  }
+
+ private:
+  void Initialize();
+
+  void CompileToNVRTC(const std::string& code, const std::string& func_name);
+
+  UniqueNameManager* name_manager() {
+    if (!printer_) {
+      throw std::runtime_error("Null IRPrinter is not expected");
+    }
+    return printer_->name_manager();
+  }
+
+  std::ostream& os() {
+    return printer_->os();
+  }
+
+  std::ostringstream oss_;
+  std::unique_ptr<CudaPrinter> printer_;
+  CUfunction function_;
+  bool has_random_ = false;
+
+  std::string GetUniqueFuncName(const std::string& func_prefix);
+};
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/cuda_half_support.h b/torch/csrc/jit/tensorexpr/cuda_half_support.h
new file mode 100644
index 0000000..91d4a5f
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/cuda_half_support.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include "torch/csrc/jit/codegen/fuser/cuda/resource_strings.h"
+#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+// Walk the statement looking for Half-sized loads/stores.
+class CudaHalfChecker : public IRVisitor {
+ public:
+  bool hasHalf() {
+    return hasHalf_;
+  }
+
+  void visit(const Load* v) override {
+    hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
+  }
+  void visit(const Store* v) override {
+    hasHalf_ |= v->value()->dtype().scalar_type() == ScalarType::Half;
+  }
+
+ private:
+  bool hasHalf_{false};
+};
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/cuda_random.h b/torch/csrc/jit/tensorexpr/cuda_random.h
new file mode 100644
index 0000000..987ac52
--- /dev/null
+++ b/torch/csrc/jit/tensorexpr/cuda_random.h
@@ -0,0 +1,104 @@
+#pragma once
+
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+constexpr auto philox_random_string = R"(
+
+class Philox {
+public:
+  __device__ inline Philox(unsigned long long seed,
+                           unsigned long long subsequence,
+                           unsigned long long offset) {
+    key.x = (unsigned int)seed;
+    key.y = (unsigned int)(seed >> 32);
+    counter = make_uint4(0, 0, 0, 0);
+    counter.z = (unsigned int)(subsequence);
+    counter.w = (unsigned int)(subsequence >> 32);
+    STATE = 0;
+    incr_n(offset / 4);
+  }
+
+  __device__ inline unsigned long operator()() {
+    if(STATE == 0) {
+      uint4 counter_ = counter;
+      uint2 key_ = key;
+      for(int i = 0; i < 9; i++) {
+        counter_ = single_round(counter_, key_);
+        key_.x += (kPhilox10A); key_.y += (kPhilox10B);
+      }
+      output = single_round(counter_, key_);
+      incr();
+    }
+    unsigned long ret;
+    switch(STATE) {
+      case 0: ret = output.x; break;
+      case 1: ret = output.y; break;
+      case 2: ret = output.z; break;
+      case 3: ret = output.w; break;
+    }
+    STATE = (STATE + 1) % 4;
+    return ret;
+  }
+
+private:
+  uint4 counter;
+  uint4 output;
+  uint2 key;
+  unsigned int STATE;
+  __device__ inline void incr_n(unsigned long long n) {
+    unsigned int nlo = (unsigned int)(n);
+    unsigned int nhi = (unsigned int)(n >> 32);
+    counter.x += nlo;
+    if (counter.x < nlo)
+      nhi++;
+    counter.y += nhi;
+    if (nhi <= counter.y)
+      return;
+    if (++counter.z)
+      return;
+    ++counter.w;
+  }
+  __device__ inline void incr() {
+    if (++counter.x)
+      return;
+    if (++counter.y)
+      return;
+    if (++counter.z)
+      return;
+    ++counter.w;
+  }
+  __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
+                                    unsigned int *result_high) {
+    *result_high = __umulhi(a, b);
+    return a*b;
+  }
+
+  __device__ inline uint4 single_round(uint4 ctr, uint2 key) {
+    unsigned int hi0;
+    unsigned int hi1;
+    unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
+    unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
+
+    uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
+    return ret;
+  }
+
+  static const unsigned long kPhilox10A = 0x9E3779B9;
+  static const unsigned long kPhilox10B = 0xBB67AE85;
+  static const unsigned long kPhiloxSA = 0xD2511F53;
+  static const unsigned long kPhiloxSB = 0xCD9E8D57;
+};
+
+// Inverse of 2^32.
+#define M_RAN_INVM32 2.3283064e-10f
+__device__  __inline__ float Uint32ToFloat(unsigned int x) {
+  return x * M_RAN_INVM32;
+}
+
+)";
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index f4c5adc..f26ba0b 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -5,6 +5,30 @@
 using namespace torch::jit;
 using namespace torch::jit::tensorexpr;
 
+namespace torch {
+namespace jit {
+namespace tensorexpr {
+
+static int te_cuda_pointwise_loop_levels = -1;
+static int te_cuda_pointwise_block_count = -1;
+static int te_cuda_pointwise_block_size = -1;
+
+int& GetTECudaPointwiseLoopLevels() {
+  return te_cuda_pointwise_loop_levels;
+}
+
+int& GetTECudaPointwiseBlockCount() {
+  return te_cuda_pointwise_block_count;
+}
+
+int& GetTECudaPointwiseBlockSize() {
+  return te_cuda_pointwise_block_size;
+}
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch
+
 static at::ScalarType tensorType(Tensor* t) {
   return static_cast<at::ScalarType>(t->body()->dtype().scalar_type());
 }
@@ -883,12 +907,96 @@
 void TensorExprKernel::LowerToBackend(BackendType backend_type) {
   std::vector<Tensor*> tensor_outputs(tensor_outputs_);
 
+  if (backend_type == BackendType::kCudaCodeGen) {
+    for (size_t tensor_idx = 0; tensor_idx < tensor_outputs_.size();
+         tensor_idx++) {
+      Tensor* tensor = tensor_outputs_[tensor_idx];
+      ExprHandle total_count = ExprHandle(tensor->dim(0));
+      for (int i = 1; i < tensor->ndim(); i++) {
+        const IntImm* total_count_i = total_count.AsNode<IntImm>();
+        const IntImm* tensor_dim_i =
+            dynamic_cast<const IntImm*>(tensor->dim(i));
+        if (total_count_i && tensor_dim_i) {
+          // TODO: switch to real constant folding when it is available.
+          total_count =
+              ExprHandle(total_count_i->value() * tensor_dim_i->value());
+        } else {
+          total_count = total_count * ExprHandle(tensor->dim(i));
+        }
+      }
+      // Flatten the index for GPU kernels.
+      // TODO: move this to axis fusion when it is ready.
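+      // For example (hypothetical shapes): a [3, 4, 5] output is rewritten as
+      // a 60-element tensor whose element i reads the original at
+      // (i / 20, (i / 5) % 4, i % 5).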
+      Tensor* new_out = Compute(
+          tensor->func_var()->name_hint() + "_flat",
+          {total_count},
+          [tensor](const VarHandle& index) -> ExprHandle {
+            std::vector<ExprHandle> dims;
+            ExprHandle value = index;
+            for (int i = tensor->ndim() - 1; i >= 0; i--) {
+              ExprHandle idx = value;
+              if (i > 0) {
+                idx = Mod::make(value, ExprHandle(tensor->dim(i)));
+              }
+              dims.push_back(idx);
+              value = value / ExprHandle(tensor->dim(i));
+            }
+            std::reverse(dims.begin(), dims.end());
+            return tensor->call(dims);
+          });
+      tensor_outputs[tensor_idx] = new_out;
+    }
+  }
+
   torch::jit::tensorexpr::schedule::LoopNest l(tensor_outputs);
 
   // Compute non-output tensors_ inline
   for (auto& p : tensors_) {
     l.ComputeInline(l.getLoopBodyFor(p.second));
   }
+  if (backend_type == kCudaCodeGen) {
+    for (size_t i = 0; i < tensor_outputs_.size(); i++) {
+      l.ComputeInline(l.getLoopBodyFor(tensor_outputs_[i]));
+
+      Tensor* tensor = tensor_outputs[i];
+      const Var* index = tensor->arg(0);
+      int loop_levels = GetTECudaPointwiseLoopLevels();
+      const int kDefaultLoopLevels = 2;
+      loop_levels = (loop_levels > 0) ? loop_levels : kDefaultLoopLevels;
+      int block_count = GetTECudaPointwiseBlockCount();
+      int block_size = GetTECudaPointwiseBlockSize();
+
+      if (loop_levels == 2) {
+        Stmt* outer;
+        Stmt* inner;
+        const int kDefaultBlockSize = 512;
+        if (block_size < 0) {
+          block_size = kDefaultBlockSize;
+        }
+        std::vector<Stmt*> loops = l.getLoopStmtsFor(tensor);
+        l.SplitWithMask(loops[0], block_size, &outer, &inner);
+        l.SetGPUBlockIndex(outer, 0);
+        l.SetGPUThreadIndex(inner, 0);
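+        // With the default block size this maps the flattened loop to
+        // ceil(total_count / 512) blocks of 512 threads each.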
+      } else if (loop_levels == 3) {
+        Stmt* outer;
+        Stmt* inner;
+        Stmt* inner_1;
+        Stmt* inner_2;
+        // TODO: derive the default block count from the device's multiprocessor count.
+        const int kDefaultBlockCount = 1280;
+        const int kDefaultBlockSize = 256;
+        block_count = (block_count > 0) ? block_count : kDefaultBlockCount;
+        block_size = (block_size > 0) ? block_size : kDefaultBlockSize;
+        std::vector<Stmt*> loops = l.getLoopStmtsFor(tensor);
+        l.SplitWithMask(loops[0], block_count * block_size, &outer, &inner);
+        l.SplitWithMask(inner, block_size, &inner_1, &inner_2);
+        l.SetGPUBlockIndex(inner_1, 0);
+        l.SetGPUThreadIndex(inner_2, 0);
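+        // With the defaults this becomes an outer serial loop of
+        // ceil(total_count / (1280 * 256)) iterations over 1280 blocks of
+        // 256 threads each.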
+      } else {
+        throw std::runtime_error(
+            "Invalid loop-level: " + std::to_string(loop_levels));
+      }
+    }
+  }
 
   l.ApplyInlines();
   Stmt* stmt = l.root_stmt();
@@ -911,6 +1019,9 @@
   // Generate code.
   std::string codegen_name;
   switch (backend_type_) {
+    case kCudaCodeGen:
+      codegen_name = "cuda_codegen";
+      break;
     case kSimpleIREval:
       codegen_name = "simple_ir_eval";
       break;
@@ -933,7 +1044,9 @@
     throw std::runtime_error("No tensor inputs");
   }();
   BackendType backend_type = BackendType::kUninitialized;
-  if (device.type() == at::kCPU) {
+  if (device.type() == at::kCUDA) {
+    backend_type = kCudaCodeGen;
+  } else if (device.type() == at::kCPU) {
     backend_type = kSimpleIREval;
   } else {
     throw std::runtime_error("Invalid device type");
@@ -956,6 +1069,7 @@
     const std::vector<CodeGen::CallArg>& run_args) {
   switch (backend_type_) {
     case kSimpleIREval:
+    case kCudaCodeGen:
       codegen_->call(run_args);
       break;
     default:
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index bbaf212..f3dcf65 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -52,6 +52,7 @@
   enum BackendType {
     kUninitialized,
     kSimpleIREval,
+    kCudaCodeGen,
   };
 
   ExprHandle constant(const torch::jit::Value* v);
@@ -205,6 +206,10 @@
   at::Device device_ = at::kCPU;
 };
 
+TORCH_API int& GetTECudaPointwiseLoopLevels();
+TORCH_API int& GetTECudaPointwiseBlockCount();
+TORCH_API int& GetTECudaPointwiseBlockSize();
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch