[nvFuser] Working towards reductions, codegen improvements (#40864)

Summary:
Basic reduction fusion is working, and the code generator has been improved to approach the performance of eager-mode reductions. Coming soon are pointwise-reduction fusions, done in a way that should prevent the possibility of hitting regressions. We are also working on performant softmax kernels in the code generator, which may be our next fusion target.
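
As a reference for reviewers, here is a minimal sketch of the reduction scheduling pattern the new C++ tests exercise (it mirrors test/cpp/jit/test_gpu.cpp; makeDummyTensor is a test-only helper, and the headers and namespaces used in that file are assumed; this is illustrative, not a canonical scheduling recipe):

    // Build a fusion that sums tv0[I0, I1] over I1: tv1[I0, R1] = sum(tv0)
    Fusion fusion;
    FusionGuard fg(&fusion);

    TensorView* tv0 = makeDummyTensor(2);
    fusion.addInput(tv0);
    TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
    fusion.addOutput(tv1);

    // Split the reduction axis and rFactor the outer part so the inner part
    // can be reduced by threads and the remainder accumulated afterwards.
    tv1->split(1, 128);
    // tv1[I0, R1o, R1i{128}]
    TensorView* tv2 = tv1->rFactor({1});
    // tv2[I0, R1o, I1i{128}] = tv0[I0, I1]
    // tv1[I0,      R1i{128}] = tv2[I0, R1o, I1i{128}]

    tv0->computeAt(tv1, 1);

    // One block per output element; threads cooperate on the inner reduction.
    tv1->axis(0)->parallelize(ParallelType::BIDx);
    tv1->axis(-1)->parallelize(ParallelType::TIDx);
    tv2->axis(-1)->parallelize(ParallelType::TIDx);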

Pull Request resolved: https://github.com/pytorch/pytorch/pull/40864

Reviewed By: ngimel

Differential Revision: D22392877

Pulled By: soumith

fbshipit-source-id: 457448a807d628b1035f6d90bc0abe8a87bf8447
diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh
index 1fffbe2..7ec76d2 100755
--- a/.jenkins/pytorch/macos-test.sh
+++ b/.jenkins/pytorch/macos-test.sh
@@ -63,7 +63,7 @@
   # Increase default limit on open file handles from 256 to 1024
   ulimit -n 1024
 
-  python test/run_test.py --verbose --exclude test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --determine-from="$DETERMINE_FROM"
+  python test/run_test.py --verbose --exclude test_jit_cuda_fuser_profiling test_jit_cuda_fuser_legacy test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --determine-from="$DETERMINE_FROM"
 
   assert_git_not_dirty
 }
diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh
index c1e67b8..72d1cd4 100755
--- a/.jenkins/pytorch/test.sh
+++ b/.jenkins/pytorch/test.sh
@@ -150,17 +150,17 @@
 }
 
 test_python_ge_config_profiling() {
-  time python test/run_test.py --include test_jit_profiling test_jit_fuser_te --verbose --determine-from="$DETERMINE_FROM"
+  time python test/run_test.py --include test_jit_cuda_fuser_profiling test_jit_profiling test_jit_fuser_te --verbose --determine-from="$DETERMINE_FROM"
   assert_git_not_dirty
 }
 
 test_python_ge_config_legacy() {
-  time python test/run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose --determine-from="$DETERMINE_FROM"
+  time python test/run_test.py --include test_jit_cuda_fuser_legacy test_jit_legacy test_jit_fuser_legacy --verbose --determine-from="$DETERMINE_FROM"
   assert_git_not_dirty
 }
 
 test_python_all_except_nn_and_cpp_extensions() {
-  time python test/run_test.py --exclude test_nn test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="$DETERMINE_FROM"
+  time python test/run_test.py --exclude test_jit_cuda_fuser_profiling test_jit_cuda_fuser_legacy test_nn test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="$DETERMINE_FROM"
   assert_git_not_dirty
 }
 
diff --git a/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat b/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat
index dd4339b..4bfb5bc 100644
--- a/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat
+++ b/.jenkins/pytorch/win-test-helpers/test_python_all_except_nn.bat
@@ -1,3 +1,3 @@
 call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
-cd test && python run_test.py --exclude test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="%1" && cd ..
+cd test && python run_test.py --exclude test_jit_cuda_fuser_profiling test_jit_cuda_fuser_legacy test_jit_profiling test_jit_legacy test_jit_fuser_legacy test_jit_fuser_te test_tensorexpr --verbose --determine-from="%1" && cd ..
 if ERRORLEVEL 1 exit /b 1
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 7693f62..2e36a97 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -448,12 +448,14 @@
       ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/arith.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/compute_at.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/dispatch.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/expr_evaluator.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/fusion.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/graph_fuser.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/index_compute.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_cloner.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_graphviz.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_nodes.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -463,13 +465,16 @@
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/shape_inference.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_index.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_unroll.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_validation.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/predicate_compute.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tensor_meta.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tensor_view.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_iter.cpp
       ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_replay.cpp
diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index 84fa69d..0edcac5 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -10,7 +10,6 @@
 #include <torch/csrc/jit/codegen/cuda/kernel.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 #include <torch/csrc/jit/codegen/cuda/mutator.h>
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
 #include <torch/csrc/jit/codegen/cuda/transform_replay.h>
 #include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
 
@@ -38,7 +37,7 @@
 
 static void checkIntValue(
     const EvaluationContext* eval_context,
-    const Val* val,
+    Val* val,
     Int::ScalarType expected_value) {
   TORCH_CHECK(val->isAnInt());
   const auto actual_value = ExpressionEvaluator::evaluate(val, eval_context);
@@ -66,27 +65,24 @@
   TensorView* tv0 = makeDummyTensor(2);
   fusion.addInput(tv0);
 
-  TensorView* tv1 = mul(tv0, new Float(-1.0));
-  TensorView* tv2 = add(tv0, new Float(3.0));
-  TensorView* tv3 = mul(tv0, new Float(2.5));
-  TensorView* tv4 = add(tv2, tv1);
-  TensorView* tv5 = add(tv4, tv3);
-  TensorView* tv6 = add(tv0, tv3);
+  TensorView* tv2 = add(tv0, new Float(3.141));
+  TensorView* tv3 = broadcast(tv0, {false, true, false, true});
+  TensorView* tv4 = reductionOp(BinaryOpType::Add, {2}, new Float(0), tv3);
+  TensorView* tv5 = clamp(tv4, new Float(0.f), new Float(1.f));
+  TensorView* tv6 = add(tv2, tv2);
 
   // Another checkpoint before adding outputs
   TORCH_CHECK(!IrGraphGenerator::toGraphviz(
                    &fusion, IrGraphGenerator::DetailLevel::Explicit)
                    .empty());
 
-  fusion.addOutput(tv5);
   fusion.addOutput(tv6);
 
   tv6->merge(0);
   tv6->split(0, 4);
   tv6->axis(0)->parallelize(ParallelType::BIDx);
   tv5->reorder({{-1, 0}});
-  tv0->computeAt(tv3, 1);
-  tv0->computeAt(tv6, 1);
+  tv2->computeAt(tv6, 1);
 
   // Another checkpoint with more node types
   TORCH_CHECK(!IrGraphGenerator::toGraphviz(
@@ -148,11 +144,22 @@
   auto* a = new Int();
   auto* b = new Int();
   auto* c = add(a, b);
-  auto* d = neg(ceilDiv(add(a, b), b));
+  auto* d = neg(ceilDiv(c, b));
+  auto* e = new Int(0);
+
+  // trying to evaluate before binding should give empty results
+  TORCH_CHECK(!ExpressionEvaluator::evaluate(a, &eval_context).has_value());
+  TORCH_CHECK(!ExpressionEvaluator::evaluate(d, &eval_context).has_value());
 
   eval_context.bind(a, 7);
   eval_context.bind(b, 3);
 
+  // can't bind to the results of expressions
+  ASSERT_ANY_THROW(eval_context.bind(c, 100));
+
+  // can't bind to concrete values
+  ASSERT_ANY_THROW(eval_context.bind(e, 100));
+
   checkIntValue(&eval_context, c, 10);
   checkIntValue(&eval_context, sub(a, b), 4);
   checkIntValue(&eval_context, mod(a, b), 1);
@@ -277,6 +284,290 @@
   checkIntValue(&eval_context, tv6->axis(2)->rawExtent(), 127);
 }
 
+// Evaluate expressions post lowering
+void testGPU_FusionExprEvalPostLower() {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // Create a non-trivial IR
+  TensorView* tv0 = makeDummyTensor(2);
+  TensorView* tv1 = makeDummyTensor(2);
+
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  TensorView* tv2 = add(tv1, new Float(2.0));
+  TensorView* tv3 = add(tv0, tv2);
+
+  fusion.addOutput(tv3);
+
+  tv3->split(0, 4);
+
+  tv0->computeAt(tv3, 1);
+  tv1->computeAt(tv3, 1);
+
+  tv3->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::Unroll);
+  tv3->axis(1)->parallelize(ParallelType::Unroll);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+  auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0));
+  auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0));
+
+  // Lower
+  GPULower gpulw(&fusion);
+  std::stringstream kernel;
+  gpulw.printKernel(kernel);
+
+  // 1. Create an evaluation context
+  EvaluationContext eval_context(&fusion);
+
+  // 2. Bind values
+  eval_context.bind(tv0->getRootDomain()[0]->extent(), 6);
+  eval_context.bind(tv0->getRootDomain()[1]->extent(), 128);
+  eval_context.bind(tv1->getRootDomain()[0]->extent(), 6);
+  eval_context.bind(tv1->getRootDomain()[1]->extent(), 128);
+
+  // 3. Evaluate and check result values
+  TORCH_CHECK(tv2->domain()->nDims() == 3);
+  checkIntValue(&eval_context, tv2->axis(0)->rawExtent(), 2);
+  checkIntValue(&eval_context, tv2->axis(1)->rawExtent(), 4);
+  checkIntValue(&eval_context, tv2->axis(2)->rawExtent(), 128);
+
+  TORCH_CHECK(tv3->domain()->nDims() == 3);
+  checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 2);
+  checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4);
+  checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128);
+
+  checkIntValue(&eval_context, bid_x, 2);
+  checkIntValue(&eval_context, tid_x, 128);
+}
+
+void testGPU_FusionClear() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // 1. Create a dummy IR
+
+  {
+    TensorView* tv0 = makeDummyTensor(2);
+    TensorView* tv1 = makeDummyTensor(2);
+
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+
+    TensorView* tv2 = add(tv1, new Float(2.0));
+    TensorView* tv3 = add(tv0, tv2);
+
+    fusion.addOutput(tv3);
+
+    tv3->split(0, 4);
+    tv0->computeAt(tv3, 1);
+    tv1->computeAt(tv3, 1);
+
+    tv3->axis(0)->parallelize(ParallelType::BIDx);
+    tv2->axis(1)->parallelize(ParallelType::Unroll);
+    tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  }
+
+  // 2. Clear the IR
+
+  fusion.clear();
+
+  TORCH_CHECK(fusion.exprs().empty());
+  TORCH_CHECK(fusion.vals().empty());
+
+  TORCH_CHECK(fusion.inputs().empty());
+  TORCH_CHECK(fusion.outputs().empty());
+
+  TORCH_CHECK(!fusion.hasReduction());
+  TORCH_CHECK(!fusion.hasBlockReduction());
+  TORCH_CHECK(!fusion.hasGridReduction());
+
+  // 3. Rebuild the IR
+
+  {
+    TensorView* tv0 = makeDummyTensor(3);
+    TensorView* tv1 = makeDummyTensor(3);
+    TensorView* tv2 = add(tv1, new Float(2.0));
+    TensorView* tv3 = add(tv0, tv2);
+
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+    fusion.addOutput(tv3);
+
+    tv3->reorder({{0, 2}, {2, 0}});
+    tv3->split(-1, 4);
+    tv3->reorder({{2, 0}, {3, 1}, {0, 3}});
+    tv0->computeAt(tv3, -1);
+    tv1->computeAt(tv3, -1);
+  }
+
+  prog.device_ = 0;
+  prog.grid(4);
+  prog.block(8);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+  at::Tensor input1 = at::randn({16, 8, 8}, options);
+  at::Tensor input2 = at::randn_like(input1);
+  at::Tensor output = at::empty_like(input1);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input1, input2}, {output});
+
+  at::Tensor tv2_ref = input2 + 2.0;
+  at::Tensor output_ref = input1 + tv2_ref;
+
+  TORCH_CHECK(output_ref.equal(output));
+}
+
+void testGPU_FusionCopy() {
+  Fusion original_fusion;
+
+  // Create the test IR
+  {
+    FusionGuard fg(&original_fusion);
+
+    auto tv0 = makeDummyTensor(3);
+    auto tv1 = makeDummyTensor(3);
+    auto tv2 = add(tv1, new Float(2.0));
+    auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);
+
+    original_fusion.addInput(tv0);
+    original_fusion.addInput(tv1);
+    original_fusion.addOutput(tv3);
+
+    tv3->reorder({{0, 2}, {2, 0}});
+    tv3->split(-1, 4);
+    tv3->reorder({{2, 0}, {3, 1}, {0, 3}});
+
+    tv0->computeAt(tv3, -1);
+    tv1->computeAt(tv3, -1);
+
+    tv3->axis(0)->parallelize(ParallelType::BIDx);
+    tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  }
+
+  // Test copy before lowering
+  Fusion clone = original_fusion;
+
+  // Compare IR dumps
+  std::stringstream original_ir;
+  std::stringstream clone_ir;
+  original_ir << original_fusion;
+  clone_ir << clone;
+  ASSERT_EQ(original_ir.str(), clone_ir.str());
+
+  // Lower original fusion
+  std::stringstream original_kernel;
+  {
+    GPULower lower(&original_fusion);
+    lower.printKernel(original_kernel);
+  }
+
+  // Make sure the "before lowering" clone was not mutated
+  // while lowering the original fusion IR
+  std::stringstream before_lowering_ir;
+  before_lowering_ir << clone;
+  ASSERT_EQ(original_ir.str(), before_lowering_ir.str());
+
+  // Test copy after lowering (including assignment operator)
+  Fusion before_lowering = clone;
+  clone = original_fusion;
+
+  // Compare IR dumps
+  std::stringstream original_lowered_ir;
+  std::stringstream clone_lowered_ir;
+  original_lowered_ir << original_fusion;
+  clone_lowered_ir << clone;
+  ASSERT_EQ(original_lowered_ir.str(), clone_lowered_ir.str());
+
+  // Lower the "before lowering" and compare kernels
+  std::stringstream clone_kernel;
+  {
+    GPULower lower(&before_lowering);
+    lower.printKernel(clone_kernel);
+  }
+  ASSERT_EQ(original_kernel.str(), clone_kernel.str());
+}
+
+void testGPU_FusionMove() {
+  Fusion fusion;
+
+  // Create the test IR
+  {
+    FusionGuard fg(&fusion);
+
+    auto tv0 = makeDummyTensor(3);
+    auto tv1 = makeDummyTensor(3);
+    auto tv2 = add(tv1, new Float(2.0));
+    auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2);
+
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+    fusion.addOutput(tv3);
+
+    tv3->reorder({{0, 2}, {2, 0}});
+    tv3->split(-1, 4);
+    tv3->reorder({{2, 0}, {3, 1}, {0, 3}});
+
+    tv0->computeAt(tv3, -1);
+    tv1->computeAt(tv3, -1);
+
+    tv3->axis(0)->parallelize(ParallelType::BIDx);
+    tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  }
+
+  std::stringstream original_ir;
+  original_ir << fusion;
+
+  // Test move before lowering
+  Fusion another_fusion = std::move(fusion);
+
+  // Check that the original fusion is "empty"
+  //
+  // IMPORTANT: these checks assume knowledge of the internal
+  //    implementation of the move operations. General users
+  //    should only assume that the moved-from object is in
+  //    a valid, but unspecified state. This is similar to the
+  //    standard library containers:
+  //    https://en.cppreference.com/w/cpp/utility/move
+  //
+  TORCH_CHECK(fusion.exprs().empty());
+  TORCH_CHECK(fusion.vals().empty());
+  TORCH_CHECK(fusion.inputs().empty());
+  TORCH_CHECK(fusion.outputs().empty());
+
+  // clear() has no pre-conditions so it's valid to call on a moved-from object
+  fusion.clear();
+
+  // Compare IR dumps
+  std::stringstream another_ir;
+  another_ir << another_fusion;
+  ASSERT_EQ(original_ir.str(), another_ir.str());
+
+  // Lower the fusion IR
+  std::stringstream kernel;
+  {
+    GPULower lower(&another_fusion);
+    lower.printKernel(kernel);
+  }
+
+  std::stringstream lowered_ir;
+  lowered_ir << another_fusion;
+
+  // Test move assignment after lowering
+  fusion = std::move(another_fusion);
+
+  // Compare IR dumps
+  std::stringstream moved_lowered_ir;
+  moved_lowered_ir << fusion;
+  ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str());
+}
+
 void testGPU_FusionSimpleArith() {
   std::stringstream ss1, ss2;
 
@@ -460,139 +751,6 @@
   TORCH_CHECK(fuser_tensor->domain() != nullptr);
 }
 
-void testGPU_FusionTensorContiguity() {
-  {
-    // NCHW memory layout
-    auto tensor = at::randn({2, 3, 4, 5});
-    auto sizes = tensor.sizes().vec();
-    auto strides = tensor.strides().vec();
-    TensorContiguity t_c(sizes, strides);
-    TORCH_CHECK(t_c.rank() == 4);
-    TORCH_CHECK(t_c.getBroadcastDims().size() == 0);
-    for (int i = 0; i < 4; i++) {
-      TORCH_CHECK(!t_c.isBroadcastDim(i));
-      if (i < 3) {
-        TORCH_CHECK(t_c.canCollapseToHigher(i));
-      }
-    }
-  }
-
-  {
-    // NHWC memory layout
-    TensorContiguity t_c({2, 3, 4, 5}, {60, 1, 15, 3});
-    TORCH_CHECK(t_c.rank() == 4);
-    TORCH_CHECK(t_c.getBroadcastDims().size() == 0);
-    for (int i = 0; i < 4; i++) {
-      TORCH_CHECK(!t_c.isBroadcastDim(i));
-      if (i < 3) {
-        TORCH_CHECK((t_c.canCollapseToHigher(i) ^ (i != 2)));
-      }
-    }
-  }
-
-  {
-    // NHWC memory layout with broadcast
-    TensorContiguity t_c({2, 3, 4, 5}, {120, 0, 30, 3});
-    TORCH_CHECK(t_c.rank() == 4);
-    auto b_dims = t_c.getBroadcastDims();
-    TORCH_CHECK(b_dims.size() == 1 && b_dims[0] == 1);
-    for (int i = 0; i < 4; i++) {
-      TORCH_CHECK(!(t_c.isBroadcastDim(i)) ^ (i == 1));
-      if (i < 3) {
-        TORCH_CHECK(!(t_c.canCollapseToHigher(i)));
-      }
-    }
-  }
-
-  {
-    // contiguity across size-1 dimension
-    auto tensor = at::randn({4, 1, 4});
-    auto sizes = tensor.sizes().vec();
-    auto strides = tensor.strides().vec();
-    auto dim = sizes.size();
-    TensorContiguity t_c(sizes, strides);
-    TORCH_CHECK(t_c.rank() == (int)sizes.size());
-    auto b_dims = t_c.getBroadcastDims();
-    TORCH_CHECK(b_dims.size() == 0);
-    TORCH_CHECK(t_c.getFCD() == 2);
-    TORCH_CHECK(t_c.hasContiguousFCD());
-    for (decltype(dim) i = 0; i < dim; i++) {
-      TORCH_CHECK(!t_c.isBroadcastDim(i));
-      if (i < dim - 1) {
-        TORCH_CHECK(t_c.canCollapseToHigher(i));
-      }
-    }
-  }
-
-  {
-    // no contiguity across size-1 dimension
-    auto tensor = at::randn({4, 4, 4}).split(1, 1)[0];
-    auto sizes = tensor.sizes().vec();
-    auto strides = tensor.strides().vec();
-    TensorContiguity t_c(sizes, strides);
-    TORCH_CHECK(!(t_c.canCollapseToHigher(0)));
-    TORCH_CHECK((t_c.canCollapseToHigher(1)));
-  }
-
-  {
-    // no contiguity across size-1 dimension
-    auto tensor = at::randn({4, 1, 8}).split(4, 2)[0];
-    auto sizes = tensor.sizes().vec();
-    auto strides = tensor.strides().vec();
-    TensorContiguity t_c(sizes, strides);
-    TORCH_CHECK((t_c.canCollapseToHigher(0)));
-    TORCH_CHECK((!t_c.canCollapseToHigher(1)));
-  }
-
-  {
-    // no contiguity across size-1 dimension
-    auto tensor = at::randn({8, 1, 4}).split(4, 0)[0];
-    auto sizes = tensor.sizes().vec();
-    auto strides = tensor.strides().vec();
-    TensorContiguity t_c(sizes, strides);
-    TORCH_CHECK((t_c.canCollapseToHigher(0)));
-    TORCH_CHECK((t_c.canCollapseToHigher(1)));
-  }
-
-  {
-    // test merge
-    TensorContiguity t_c_l({4, 4, 4}, {16, 4, 1});
-    TensorContiguity t_c_r({4, 4, 4}, {16, 4, 1});
-    t_c_l.merge(t_c_r);
-    TORCH_CHECK((t_c_l.isIdentical(t_c_r)));
-  }
-
-  {
-    TensorContiguity t_c_l({4, 4, 4, 4}, {16, 0, 4, 1});
-    TensorContiguity t_c_r({4, 4, 4, 4}, {64, 16, 4, 1});
-    t_c_l.merge(t_c_r);
-    TORCH_CHECK(t_c_l.getFCD() == 3);
-    TORCH_CHECK(t_c_l.getAxisByStride(0) == 0);
-  }
-
-  {
-    // NHWC + NCHW
-    TensorContiguity t_c_l({4, 4, 4, 4}, {64, 16, 4, 1});
-    TensorContiguity t_c_r({4, 4, 4, 4}, {64, 1, 16, 4});
-    t_c_l.merge(t_c_r);
-    TORCH_CHECK(!t_c_l.hasContiguousFCD());
-    TORCH_CHECK(t_c_l.getFCD() == -1);
-    TORCH_CHECK(t_c_l.getAxisByStride(0) == 0);
-    TORCH_CHECK(t_c_l.getAxisByStride(1) == -1);
-    TORCH_CHECK(t_c_l.getAxisByStride(2) == -1);
-    TORCH_CHECK(t_c_l.getAxisByStride(3) == -1);
-  }
-
-  {
-    // NCHW + NCHW with broadcasting
-    TensorContiguity t_c_l({4, 4, 4, 4}, {4, 1, 4, 0});
-    TensorContiguity t_c_r({4, 4, 4, 4}, {64, 1, 16, 4});
-    t_c_l.merge(t_c_r);
-    TORCH_CHECK(t_c_l.getFCD() == 1);
-    TORCH_CHECK(t_c_l.getAxisByStride(0) == 0);
-  }
-}
-
 void testGPU_FusionTVSplit() {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -853,51 +1011,56 @@
   prog.device_ = 0;
   fuser::cuda::parseJitIR(g, &prog);
 
-  std::stringstream ref;
-  ref << "__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3){\n"
-      << "  float T2[4];\n"
-      << "  if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
-      << "    for(size_t i60 = 0; i60 < 4; ++i60 ) {\n"
-      << "      T2[ i60 ]\n"
-      << "         = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n"
-      << "         * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n"
-      << "    }\n"
-      << "  } else { \n"
-      << "    for(size_t i60 = 0; i60 < 4; ++i60 ) {\n"
-      << "      if ( ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
-      << "        T2[ i60 ]\n"
-      << "           = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n"
-      << "           * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n"
-      << "      }\n"
-      << "    }\n"
-      << "  }\n"
-      << "  if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
-      << "    for(size_t i61 = 0; i61 < 4; ++i61 ) {\n"
-      << "      T3[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n"
-      << "         = T2[ i61 ]\n"
-      << "         * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n"
-      << "    }\n"
-      << "  } else { \n"
-      << "    for(size_t i61 = 0; i61 < 4; ++i61 ) {\n"
-      << "      if ( ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
-      << "        T3[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n"
-      << "           = T2[ i61 ]\n"
-      << "           * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n"
-      << "      }\n"
-      << "    }\n"
-      << "  }\n"
-      << "}\n";
+  // CONSIDER:
+  // 1. this can be moved to a dedicated "golden" file
+  // 2. use a fuzzy compare (e.g., ignore non-significant whitespace)
+  const std::string expected_kernel = R"(
+__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3){
+  float T2[4];
+  if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { 
+    for(size_t i40 = 0; i40 < 4; ++i40 ) {
+      T2[ i40 ]
+         = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
+         * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
+    }
+  } else { 
+    for(size_t i40 = 0; i40 < 4; ++i40 ) {
+      if ( ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { 
+        T2[ i40 ]
+           = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
+           * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
+      }
+    }
+  }
+  if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { 
+    for(size_t i41 = 0; i41 < 4; ++i41 ) {
+      T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
+         = T2[ i41 ]
+         * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
+    }
+  } else { 
+    for(size_t i41 = 0; i41 < 4; ++i41 ) {
+      if ( ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { 
+        T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
+           = T2[ i41 ]
+           * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
+      }
+    }
+  }
+}
+)";
 
   GPULower gpulw(&fusion);
-  std::stringstream cdg;
-  gpulw.printKernel(cdg);
-  if (ref.str().size() != cdg.str().size() ||
-      ref.str().compare(cdg.str()) != 0) {
+  std::stringstream actual_kernel;
+  actual_kernel << "\n";
+  gpulw.printKernel(actual_kernel);
+  if (expected_kernel.size() != actual_kernel.str().size() ||
+      expected_kernel.compare(actual_kernel.str()) != 0) {
     std::cerr
         << " Codegen mismatch, codegen possibly changed, or is incorrect. "
-        << " \n ========= REF ========= \n"
-        << ref.str() << "\n========= RESULT ========== \n"
-        << cdg.str() << "\n=================" << std::endl;
+        << " \n ========= EXPECTED ========= \n"
+        << expected_kernel << "\n========= ACTUAL ========== \n"
+        << actual_kernel.str() << "\n=================" << std::endl;
     TORCH_CHECK(false);
   }
 }
@@ -1198,7 +1361,7 @@
 
     tv0->computeAt(tv6, 1);
 
-    TORCH_CHECK(tv0->getComputeAtView() == tv3 && tv0->nDims() == 3);
+    TORCH_CHECK(tv0->getComputeAtView() == tv6 && tv0->nDims() == 3);
     TORCH_CHECK(tv1->getComputeAtView() == tv4 && tv1->nDims() == 3);
     TORCH_CHECK(tv2->getComputeAtView() == tv4 && tv2->nDims() == 3);
     TORCH_CHECK(tv3->getComputeAtView() == tv6 && tv3->nDims() == 3);
@@ -1272,6 +1435,7 @@
     fusion.addOutput(tv6);
 
     tv2->computeAt(tv4, 1);
+
     TORCH_CHECK(!tv0->hasComputeAt());
     TORCH_CHECK(!tv1->hasComputeAt());
     TORCH_CHECK(tv2->getComputeAtView() == tv4);
@@ -1323,10 +1487,10 @@
         &prog, {t0}, {kernel_tv5, kernel_tv6});
 
     GPULower gpulw(&fusion);
-    std::stringstream cdg;
-    gpulw.printKernel(cdg);
+    std::stringstream actual_kernel;
+    gpulw.printKernel(actual_kernel);
 
-    TORCH_CHECK(at::allclose(kernel_tv5, t5), cdg.str());
+    TORCH_CHECK(at::allclose(kernel_tv5, t5), actual_kernel.str());
     TORCH_CHECK(at::allclose(kernel_tv6, t6));
   }
 
@@ -1390,10 +1554,10 @@
     torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {kernel_tv3});
 
     GPULower gpulw(&fusion);
-    std::stringstream cdg;
-    gpulw.printKernel(cdg);
+    std::stringstream actual_kernel;
+    gpulw.printKernel(actual_kernel);
 
-    TORCH_CHECK(at::allclose(kernel_tv3, t3), cdg.str());
+    TORCH_CHECK(at::allclose(kernel_tv3, t3), actual_kernel.str());
   }
 
   // Case 4
@@ -1470,10 +1634,10 @@
         &prog, {t0, t1, t2, t3}, {kernel_tv6});
 
     GPULower gpulw(&fusion);
-    std::stringstream cdg;
-    gpulw.printKernel(cdg);
+    std::stringstream actual_kernel;
+    gpulw.printKernel(actual_kernel);
 
-    TORCH_CHECK(at::allclose(kernel_tv6, t6), cdg.str());
+    TORCH_CHECK(at::allclose(kernel_tv6, t6), actual_kernel.str());
   }
 }
 
@@ -1570,10 +1734,10 @@
       {kernel_tv4});
 
   GPULower gpulw(&fusion);
-  std::stringstream cdg;
-  gpulw.printKernel(cdg);
+  std::stringstream actual_kernel;
+  gpulw.printKernel(actual_kernel);
 
-  TORCH_CHECK(at::allclose(kernel_tv4, t4), cdg.str());
+  TORCH_CHECK(at::allclose(kernel_tv4, t4), actual_kernel.str());
 }
 
 void testGPU_FusionLoopUnroll() {
@@ -1619,9 +1783,6 @@
 
   int inp_size = 129 * 13 * 3;
 
-  // GPULower lower(&fusion);
-  // lower.printKernel(std::cout);
-
   prog.device_ = 0;
   prog.grid((inp_size + 63) / 64);
   prog.block(block_size);
@@ -2192,7 +2353,8 @@
 
   // Replay casp, replay new_domain2 as new_domain
   // reordered_new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf]
-  TensorDomain* casp = TransformReplay::replayCasP(new_domain2, new_domain, 2);
+  auto replay_casp = TransformReplay::replayCasP(new_domain2, new_domain, 2);
+  TensorDomain* casp = replay_casp.first;
   // new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf]
   //       casp[I0oi{16}, I0oo*I0i{32},  R1oi{4}]
 
@@ -2201,7 +2363,8 @@
   // new_domain[I0oi{16},  I0oo*I0i{32}  ,                 ir1oi{4}rf,
   // R(R1oo*R1i{8})rf]
 
-  TensorDomain* pasc = TransformReplay::replayPasC(new_domain, casp, 2);
+  auto replay_pasc = TransformReplay::replayPasC(new_domain, casp, 2);
+  TensorDomain* pasc = replay_pasc.first;
   // pasc      [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}rf,
   // R(R1oo*R1i{8})rf]
 
@@ -2287,11 +2450,6 @@
   tv2->axis(-1)->parallelize(ParallelType::TIDx);
   tv3->axis(-1)->parallelize(ParallelType::TIDx);
 
-  // for(auto expr : fusion.exprs(true))
-  //   std::cout<<expr<<std::endl;
-  // GPULower lower(&fusion);
-  // lower.printKernel(std::cout);
-
   int numel_x = 65000;
   int numel_y = 1025;
 
@@ -2527,28 +2685,189 @@
   }
 }
 
-void testGPU_FusionSimpleBCast() {
-  {
-    Fusion fusion;
-    FusionGuard fg(&fusion);
+void testGPU_FusionReduction4() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
 
-    // Set up your input tensor views
-    TensorView* tv0 = makeDummyTensor(2);
-    TensorView* tv1 = makeDummyTensor(2);
-    fusion.addInput(tv0);
-    fusion.addInput(tv1);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(3);
 
-    TensorView* tv2 = add(tv0, tv1);
+  fusion.addInput(tv0);
 
-    // tv1[I0, R1] = tv0[I0, I1]
-    TensorView* tv3 = broadcast(tv2, {false, true, true, false});
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
 
-    Val* tv4 = mul(tv3, makeDummyTensor(4));
-    fusion.addOutput(tv4);
+  fusion.addOutput(tv1);
+
+  int bidy = 2;
+  int tidy = 4;
+  int tidx = 5;
+
+  int dim1 = 11;
+
+  tv1->split(-2, tidy);
+
+  TensorView* tv2 = tv1->rFactor({-3});
+
+  tv0->computeAt(tv1, 1);
+
+  tv1->axis(0)->parallelize(ParallelType::BIDy);
+
+  for (auto* val : fusion.vals()) {
+    if (val->getValType().value() == ValType::TensorView)
+      val->as<TensorView>()->axis(-1)->parallelize(ParallelType::TIDx);
   }
 
+  tv2->axis(-2)->parallelize(ParallelType::TIDy);
+  tv1->axis(-2)->parallelize(ParallelType::TIDy);
+
+  prog.device_ = 0;
+  prog.grid(1, bidy);
+  prog.block(tidx, tidy);
+  torch::jit::fuser::cuda::compileKernel(&prog);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::randn({bidy, dim1, tidx}, options);
+
+  at::Tensor cg_output = at::empty({bidy, tidx}, options);
+
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+void testGPU_FusionReduction5() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  const int bdimx = 64;
+  const int bdimy = 8;
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(3);
+  fusion.addInput(tv0);
+
+  // tv1[I0, R1, R2] = tv0[I0, I1, I2]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+  tv1->split(2, bdimx);
+  // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
+  tv1->split(1, bdimy);
+  // tv1[I0, R1o, R1i{8}, R2o, R2i{128}] = tv0[I0, I1, I2]
+
+  TensorView* tv2 = tv1->rFactor({3});
+  // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
+  // tv1[I0, R1o, R1i{8},      R2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
+
+  TensorView* tv3 = tv1->rFactor({1});
+  // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2]
+  // tv3[I0, R1o, I1i{8},      I2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}]
+  // tv1[I0,      R1i{8},      R2i{128}] = tv3[I0, R1o, I1i{8},      I2i{128}]
+
+  tv3->computeAt(tv1, 1);
+  tv2->computeAt(tv3, 2);
+
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv3->axis(0)->parallelize(ParallelType::BIDx);
+
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+  tv1->axis(-2)->parallelize(ParallelType::TIDy);
+  tv3->axis(-2)->parallelize(ParallelType::TIDy);
+  tv2->axis(-3)->parallelize(ParallelType::TIDy);
+
+  int numel_x = 650;
+  int numel_y = 1000;
+  int numel_z = 1000;
+
+  prog.device_ = 0;
+  prog.grid(numel_x);
+  prog.block(bdimx, bdimy);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({1, 2});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+void testGPU_FusionReductionTFT() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+
+  fusion.addOutput(tv1);
+
+  int numel_x = 1025;
+  int numel_y = 129;
+  int tidx = 16;
+  int tidy = 8;
+  int tidz = 8;
+
+  tv1->split(1, tidx);
+  // tv1[I0, R1o, R1i{tidx}]
+
+  tv1->split(1, tidz);
+  // tv1[I0, R1oo, R1Oi{tidz}, R1R1i{tidx}]
+
+  tv1->split(0, tidy);
+  // tv1[I0o, I0i, R1oo, R1Oi{tidz}, R1R1i{tidx}]
+
+  TensorView* tv2 = tv1->rFactor({2});
+  // tv2[I0o, I0i, R1oo, I1Oi{tidz}, I11i{tidx}]
+  // tv1[I0o, I0i,       R1Oi{tidz}, R1R1i{tidx}]
+
+  tv2->computeAt(tv1, 2);
+
+  tv1->axis(1)->parallelize(ParallelType::TIDy);
+
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+
+  tv1->axis(-2)->parallelize(ParallelType::TIDz);
+  tv2->axis(-2)->parallelize(ParallelType::TIDz);
+
+  prog.device_ = 0;
+  prog.grid(1);
+  prog.block(tidx, tidy, tidz);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
+  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+void testGPU_FusionSimpleBCast() {
   {
-    Fusion fusion;
+    torch::jit::fuser::cuda::CudaKernel prog;
+    Fusion& fusion = *prog.fusion_;
     FusionGuard fg(&fusion);
 
     // Set up your input tensor views
@@ -2557,18 +2876,1008 @@
     fusion.addInput(tv0);
     fusion.addInput(tv1);
 
-    TensorView* tv2 = broadcast(tv0, {true, false, false});
-    TensorView* tv3 = broadcast(tv1, {false, false, true});
+    TensorView* tv2 = broadcast(tv0, {false, false, true});
+    TensorView* tv3 = broadcast(tv1, {true, false, false});
 
-    TensorView* tv4 = mul(tv3, tv2);
+    TensorView* tv4 = add(tv2, tv3);
+    tv4->split(-1, 4);
+    tv4->split(0, 8);
     fusion.addOutput(tv4);
 
     tv0->computeAt(tv4, -1);
     tv1->computeAt(tv4, -1);
 
-    // GPULower lower(&fusion);
-    // lower.printKernel(std::cout);
+    tv4->axis(0)->parallelize(ParallelType::BIDx);
+    tv4->axis(-1)->parallelize(ParallelType::TIDx);
+
+    constexpr int x = 63, y = 33, z = 15;
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+    at::Tensor t0 = at::randn({x, y}, options);
+    at::Tensor t1 = at::randn({y, z}, options);
+
+    at::Tensor cg_output = at::empty({x, y, z}, options);
+
+    prog.device_ = 0;
+    prog.grid(ceilDiv_(x, 8));
+    prog.block(4);
+    torch::jit::fuser::cuda::compileKernel(&prog);
+    torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+
+    auto t2 = t0.unsqueeze(-1).expand({x, y, z});
+    auto t3 = t1.expand({x, y, z});
+    auto t4 = t2.add(t3);
+
+    TORCH_CHECK(t4.allclose(cg_output));
   }
+
+  {
+    torch::jit::fuser::cuda::CudaKernel prog;
+    Fusion& fusion = *prog.fusion_;
+    FusionGuard fg(&fusion);
+
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(2);
+    TensorView* tv1 = makeDummyTensor(2);
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+
+    // TODO: add pointwise ops at the beginning, before the bcast.
+
+    TensorView* tv2 = broadcast(tv0, {false, false, true});
+    TensorView* tv3 = broadcast(tv1, {true, false, false});
+
+    TensorView* tv4 = add(tv2, tv3);
+
+    tv4->merge(0, 1);
+
+    fusion.addOutput(tv4);
+
+    tv0->computeAt(tv4, -1);
+    tv1->computeAt(tv4, -1);
+
+    tv4->axis(0)->parallelize(ParallelType::BIDx);
+
+    constexpr int x = 63, y = 33, z = 15;
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+    at::Tensor t0 = at::randn({x, y}, options);
+    at::Tensor t1 = at::randn({y, z}, options);
+
+    at::Tensor cg_output = at::empty({x, y, z}, options);
+
+    prog.device_ = 0;
+    prog.grid(x * y);
+    prog.block(1);
+    torch::jit::fuser::cuda::compileKernel(&prog);
+    torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+
+    auto t2 = t0.unsqueeze(-1).expand({x, y, z});
+    auto t3 = t1.expand({x, y, z});
+    auto t4 = t2.add(t3);
+
+    TORCH_CHECK(t4.allclose(cg_output));
+  }
+}
+
+void testGPU_FusionSimpleGemm() {
+  {
+    torch::jit::fuser::cuda::CudaKernel prog;
+    Fusion& fusion = *prog.fusion_;
+    FusionGuard fg(&fusion);
+
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(2); // M, K
+    TensorView* tv1 = makeDummyTensor(2); // K, N
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+
+    TensorView* tv2 = broadcast(tv0, {false, false, true});
+    // tv2[I0, I1, B] = tv0[I0, I1]
+
+    TensorView* tv3 = broadcast(tv1, {true, false, false});
+    // tv3[B, I1, I2] = tv1[I1, I2]
+
+    // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
+    TensorView* tv4 = mul(tv2, tv3);
+    // tv5[I0, R1, I2] = tv4[I0, I1, I2]
+    TensorView* tv5 = sum(tv4, {1});
+    fusion.addOutput(tv5);
+
+    tv5->split(1, 32);
+    // tv5[I0, R1o, R1i{32}, I2]
+
+    auto tv6 = tv5->rFactor({1});
+    // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
+    // tv5[I0,    , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]
+
+    tv5->split(0, 4);
+    tv5->split(-1, 4);
+    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+
+    tv0->computeAt(tv5, -1);
+    tv1->computeAt(tv5, -1);
+
+    // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
+    // tv5[I0o, I0i{4},    , R1i{32}, I2o, I2i{4}]
+    // --> (the "|" marks the computeAt position)
+    // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
+    // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
+    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+
+    tv0->computeAt(tv6, -1);
+    tv1->computeAt(tv6, -1);
+    // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
+    // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
+    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+
+    tv5->axis(0)->parallelize(ParallelType::BIDz);
+    tv5->axis(1)->parallelize(ParallelType::TIDz);
+
+    tv5->axis(-2)->parallelize(ParallelType::BIDy);
+    tv5->axis(-1)->parallelize(ParallelType::TIDy);
+
+    tv5->axis(2)->parallelize(ParallelType::TIDx);
+    tv6->axis(2)->parallelize(ParallelType::TIDx);
+
+    constexpr int M = 65, K = 33, N = 17;
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+    at::Tensor t0 = at::randn({M, K}, options);
+    at::Tensor t1 = at::randn({K, N}, options);
+
+    at::Tensor cg_output = at::empty({M, N}, options);
+
+    prog.device_ = 0;
+    prog.grid(1, ceilDiv_(N, 4), ceilDiv_(M, 4));
+
+    prog.block(32, 4, 4);
+    torch::jit::fuser::cuda::compileKernel(&prog);
+    torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+
+    auto t2 = t0.matmul(t1);
+    TORCH_CHECK(
+        t2.allclose(cg_output, 1e-5, 1e-5),
+        "Error of: ",
+        t2.sub(cg_output).abs().max());
+  }
+}
+
+// This test uses a combination of broadcast and reduction operations and a
+// parallelization strategy that is currently not supported. Getting this
+// example working is a goal, and this test is added so we can continue
+// working on getting it fixed. Right now it produces an incorrect result.
+// We either need to error coherently on the optimization strategy we don't
+// support (and change this test to one we do support), or we need to get
+// this schedule working correctly.
+void testGPU_FusionSoftmax() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* input_tv0 = makeDummyTensor(3);
+  fusion.addInput(input_tv0);
+
+  TensorView* max_val_tv1 =
+      reductionOp(BinaryOpType::Max, {2}, new Float(0), input_tv0);
+  TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true});
+  TensorView* exp_tv3 = sub(input_tv0, bcast_max_tv2);
+  TensorView* sum_exp_tv4 =
+      reductionOp(BinaryOpType::Add, {2}, new Float(0), exp_tv3);
+  TensorView* bcast_sum_tv5 = broadcast(sum_exp_tv4, {false, false, true});
+  TensorView* output_tv6 = div(exp_tv3, bcast_sum_tv5);
+
+  max_val_tv1->split(-1, 32);
+  TensorView* max_val_rf_tv7 = max_val_tv1->rFactor({-2});
+  sum_exp_tv4->split(-1, 32);
+  TensorView* sum_exp_rf_tv8 = sum_exp_tv4->rFactor({-2});
+
+  exp_tv3->computeAt(sum_exp_rf_tv8, 2);
+
+  max_val_rf_tv7->axis(0)->parallelize(ParallelType::BIDx);
+  max_val_tv1->axis(0)->parallelize(ParallelType::BIDx);
+  bcast_max_tv2->axis(0)->parallelize(ParallelType::BIDx);
+  sum_exp_rf_tv8->axis(0)->parallelize(ParallelType::BIDx);
+  sum_exp_tv4->axis(0)->parallelize(ParallelType::BIDx);
+  bcast_sum_tv5->axis(0)->parallelize(ParallelType::BIDx);
+  output_tv6->axis(0)->parallelize(ParallelType::BIDx);
+
+  max_val_rf_tv7->axis(1)->parallelize(ParallelType::BIDy);
+  max_val_tv1->axis(1)->parallelize(ParallelType::BIDy);
+  bcast_max_tv2->axis(1)->parallelize(ParallelType::BIDy);
+  sum_exp_rf_tv8->axis(1)->parallelize(ParallelType::BIDy);
+  sum_exp_tv4->axis(1)->parallelize(ParallelType::BIDy);
+  bcast_sum_tv5->axis(1)->parallelize(ParallelType::BIDy);
+  output_tv6->axis(1)->parallelize(ParallelType::BIDy);
+
+  max_val_rf_tv7->axis(-1)->parallelize(ParallelType::TIDx);
+  max_val_tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  bcast_max_tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  exp_tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  sum_exp_rf_tv8->axis(-1)->parallelize(ParallelType::TIDx);
+  sum_exp_tv4->axis(-1)->parallelize(ParallelType::TIDx);
+  bcast_sum_tv5->axis(-1)->parallelize(ParallelType::TIDx);
+  output_tv6->axis(-1)->parallelize(ParallelType::TIDx);
+
+  fusion.addOutput(output_tv6);
+
+  prog.device_ = 0;
+  prog.grid(32, 32);
+  prog.block(32);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({32, 32, 128}, options);
+  at::Tensor cg_output = at::empty({32, 32, 128}, options);
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {t0}, {cg_output});
+
+  auto t2 = at::_softmax(t0, -1, false);
+  // TORCH_CHECK(
+  //     t2.allclose(cg_output, 1e-5, 1e-5),
+  //     "Error of: ",
+  //     t2.sub(cg_output).abs().max());
+}
+// Similar to FusionReduction but uses grid reduction
+void testGPU_FusionGridReduction1() {
+  const int gdimx = 32;
+  const int bdimx = 128;
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+  tv1->split(1, bdimx);
+  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+  tv1->split(1, gdimx);
+  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
+
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{32},  R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
+
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv1, 1);
+
+  // Redo it all at once, because why not.
+  tv0->computeAt(tv1, 1);
+
+  tv1->axis(0)->parallelize(ParallelType::BIDy);
+  tv1->axis(1)->parallelize(ParallelType::BIDx);
+  tv2->axis(2)->parallelize(ParallelType::BIDx);
+
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+
+  int numel_x = 10000;
+  int numel_y = 65000;
+
+  prog.device_ = 0;
+  prog.grid(gdimx, numel_x);
+  prog.block(bdimx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Same test as the above but uses BIDy and TIDx for reduction
+void testGPU_FusionGridReduction2() {
+  const int gdimy = 32;
+  const int bdimx = 128;
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+  tv1->split(1, bdimx);
+  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+  tv1->split(1, gdimy);
+  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
+
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{32},  R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
+
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv1, 1);
+
+  // Redo it all at once, because why not.
+  tv0->computeAt(tv1, 1);
+
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->axis(1)->parallelize(ParallelType::BIDy);
+  tv2->axis(2)->parallelize(ParallelType::BIDy);
+
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+
+  int numel_x = 10000;
+  int numel_y = 65000;
+
+  prog.device_ = 0;
+  prog.grid(numel_x, gdimy);
+  prog.block(bdimx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Same test but uses BIDy and BIDz for reduction. No TID used.
+void testGPU_FusionGridReduction3dim1() {
+  const int gdimz = 32;
+  const int gdimy = 128;
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+  tv1->split(1, gdimy);
+  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
+  tv1->split(1, gdimz);
+  // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1]
+
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{32},  R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}]
+
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv1, 1);
+
+  // Redo it all at once, because why not.
+  tv0->computeAt(tv1, 1);
+
+  tv1->axis(0)->parallelize(ParallelType::BIDx);
+  tv1->axis(1)->parallelize(ParallelType::BIDz);
+  tv2->axis(2)->parallelize(ParallelType::BIDz);
+
+  tv1->axis(-1)->parallelize(ParallelType::BIDy);
+  tv2->axis(-1)->parallelize(ParallelType::BIDy);
+
+  int numel_x = 100;
+  int numel_y = 6500;
+
+  prog.device_ = 0;
+  prog.grid(numel_x, gdimy, gdimz);
+  // This number should not affect the output as TIDx is not
+  // used. All threads in a thread block redundantly compute the
+  // same value.
+  prog.block(128);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0
+void testGPU_FusionGridReduction3dim0() {
+  const int rdim = 0;
+  const int gdimy = 128;
+  const int gdimz = 32;
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  // tv1[R0, I1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {rdim}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+  tv1->split(rdim, gdimy);
+  // tv1[R0o, R0i{128}, I1] = tv0[I0, I1]
+  tv1->split(rdim, gdimz);
+  // tv1[R0oo, R0oi{32}, R0i{128}, I1] = tv0[I0, I1]
+
+  TensorView* tv2 = tv1->rFactor({rdim});
+  // tv2[R0oo, I0oi{32}, I0i{128}, I1] = tv0[I0, I1]
+  // tv1[      R0oi{32}, R0i{128}, I1] = tv2[R0oo, I0oi{32}, I0i{128}, I1]
+
+  // Note that computeAt isn't going to make anything better as there
+  // is no dynamically sized dimension.
+
+  // Map parallelism as [Serial, BIDz, BIDy, BIDx]
+  tv1->axis(-1)->parallelize(ParallelType::BIDx);
+  tv2->axis(-1)->parallelize(ParallelType::BIDx);
+  tv1->axis(-2)->parallelize(ParallelType::BIDy);
+  tv2->axis(-2)->parallelize(ParallelType::BIDy);
+  tv1->axis(-3)->parallelize(ParallelType::BIDz);
+  tv2->axis(-3)->parallelize(ParallelType::BIDz);
+
+  int numel_x = 6500;
+  int numel_y = 100;
+
+  prog.device_ = 0;
+  prog.grid(numel_y, gdimy, gdimz);
+  // This number should not affect the output as TIDx is not
+  // used. All threads in a thread block redundantly compute the
+  // same value.
+  prog.block(1);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_y}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({0});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// This is similar to FusionReduction, but swaps BIDx and TIDx
+void testGPU_FusionGridReduction4() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  const int bdimx = 128;
+  const int gdimx = 1024;
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+  tv1->split(1, gdimx);
+  // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1]
+  tv1->split(1, 4);
+  // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]
+
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{4},  R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
+
+  TensorView* tv3 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1]
+  // tv3[I0,        R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}]
+  // tv1[I0,                  R1i{1024}] = tv3[I0,        R1oi{4}, Ir1i{1024}]
+
+  // Incrementally, can print in between for debugging
+  tv0->computeAt(tv2, 1);
+  tv2->computeAt(tv3, 1);
+  tv3->computeAt(tv1, 1);
+
+  // Redo it all at once, because why not.
+  tv0->computeAt(tv1, 1);
+
+  tv2->axis(2)->parallelize(ParallelType::Unroll);
+  tv1->axis(0)->parallelize(ParallelType::TIDx);
+
+  tv1->axis(-1)->parallelize(ParallelType::BIDx);
+  tv2->axis(-1)->parallelize(ParallelType::BIDx);
+  tv3->axis(-1)->parallelize(ParallelType::BIDx);
+
+  int numel_x = bdimx;
+  int numel_y = 65000;
+
+  prog.device_ = 0;
+  prog.grid(gdimx);
+  prog.block(bdimx);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Grid reduction with 2D thread blocks but only TIDx and BIDx are
+// mapped to a reduction dim
+void testGPU_FusionGridReduction5() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  const int bdimx = 64;
+  const int bdimy = 16;
+  const int gdimx = 4;
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  // tv1[I0, R1] = tv0[I0, I1]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+  tv1->split(1, bdimx);
+  // tv1[I0, R1o, R1i{64}] = tv0[I0, I1]
+  tv1->split(1, gdimx);
+  // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1]
+
+  TensorView* tv2 = tv1->rFactor({1});
+  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1]
+  // tv1[I0,        R1oi{4},  R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}]
+
+  tv0->computeAt(tv1, 1);
+
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+
+  tv1->axis(-2)->parallelize(ParallelType::BIDx);
+  tv2->axis(-2)->parallelize(ParallelType::BIDx);
+
+  tv1->axis(0)->parallelize(ParallelType::TIDy);
+
+  int numel_x = bdimy;
+  int numel_y = 6500;
+
+  prog.device_ = 0;
+  prog.grid(gdimx);
+  prog.block(bdimx, bdimy);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({1});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+// Similar to FusionGridReduction1 but with 3D tensors
+void testGPU_FusionGridReduction6() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(3);
+  fusion.addInput(tv0);
+
+  // tv1[I0, R1, R2] = tv0[I0, I1, I2]
+  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");
+
+  // Splitting for TID
+  tv1->split(2, 128);
+  // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2]
+
+  // Splitting for BID
+  tv1->split(1, 128);
+
+  // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2]
+
+  TensorView* tv2 = tv1->rFactor({3});
+  // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
+  // tv1[I0, R1o, R1i{128},      R2i{128}]
+
+  TensorView* tv3 = tv1->rFactor({1});
+  // tv2[I0, I1o, I1i{128}, R2o, I2i{128}]
+  // tv3[I0, R1o, I1i{128},      I2i{128}]
+  // tv1[I0,      R1i{128},      R2i{128}]
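+  // tv2 performs the serial R2o reduction, tv3 the serial R1o reduction, and
+  // tv1 finishes with the R1i/R2i reductions that are parallelized below.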
+
+  tv3->computeAt(tv1, 1);
+  tv2->computeAt(tv3, 3);
+
+  tv1->axis(0)->parallelize(ParallelType::BIDy);
+
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+  tv1->axis(-2)->parallelize(ParallelType::BIDx);
+  tv2->axis(-3)->parallelize(ParallelType::BIDx);
+  tv3->axis(-2)->parallelize(ParallelType::BIDx);
+
+  int numel_x = 6500;
+  int numel_y = 200;
+  int numel_z = numel_y;
+
+  prog.device_ = 0;
+  prog.grid(128, numel_x);
+  prog.block(128);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options);
+  at::Tensor cg_output = at::empty({numel_x}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({1, 2});
+  TORCH_CHECK(aten_output.allclose(cg_output));
+}
+
+void testGPU_FusionNonRedAxisBind() {
+  int bid_x = 3;
+  int tid_x = 2;
+  int red_dim = 0;
+
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  TensorView* tv1 =
+      reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0);
+  fusion.addOutput(tv1);
+
+  tv1->split(-1, tid_x);
+  tv1->axis(-2)->parallelize(ParallelType::BIDx);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+
+  prog.device_ = 0;
+  prog.grid(bid_x);
+  prog.block(tid_x);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({16, bid_x * tid_x}, options);
+  at::Tensor cg_output = at::empty({bid_x * tid_x}, options);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {cg_output});
+
+  auto aten_output = input.sum({red_dim});
+
+  TORCH_CHECK(
+      aten_output.allclose(cg_output),
+      "Error of: ",
+      aten_output.sub(cg_output).abs().max());
+}
+
+void testGPU_FusionSplitBCast() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* input_tv0 = makeDummyTensor(3);
+  TensorView* input_tv1 = makeDummyTensor(3);
+  fusion.addInput(input_tv0);
+  fusion.addInput(input_tv1);
+
+  TensorView* sum_tv2 =
+      reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0);
+  TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true});
+  TensorView* output_tv4 = div(input_tv1, bcast_tv3);
+
+  sum_tv2->split(-1, 32);
+  TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2});
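+  // sum_rf_tv5 carries the serial R2o reduction; sum_tv2 keeps the R2i{32}
+  // reduction that is bound to TIDx below.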
+
+  bcast_tv3->split(-1, 32);
+  output_tv4->split(-1, 32);
+
+  sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx);
+  sum_tv2->axis(0)->parallelize(ParallelType::BIDx);
+  bcast_tv3->axis(0)->parallelize(ParallelType::BIDx);
+  output_tv4->axis(0)->parallelize(ParallelType::BIDx);
+
+  sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy);
+  sum_tv2->axis(1)->parallelize(ParallelType::BIDy);
+  bcast_tv3->axis(1)->parallelize(ParallelType::BIDy);
+  output_tv4->axis(1)->parallelize(ParallelType::BIDy);
+
+  sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx);
+  sum_tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  output_tv4->axis(-1)->parallelize(ParallelType::TIDx);
+
+  fusion.addOutput(output_tv4);
+
+  prog.device_ = 0;
+  prog.grid(32, 32);
+  prog.block(32);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({32, 32, 128}, options);
+  at::Tensor t1 = at::randn({32, 32, 128}, options);
+  at::Tensor cg_output = at::empty({32, 32, 128}, options);
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+}
+
+void testGPU_FusionBCastInnerDim() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+
+  // reduce then broadcast
+  auto tv1 = sum(tv0, {0});
+  auto tv2 = broadcast(tv1, {false, true});
+
+  TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast());
+}
+
+void testGPU_FusionBCastReduce() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2);
+
+  auto tv1 = broadcast(tv0, {true, false, false});
+  auto tv2 = sum(tv1, {1});
+  TORCH_CHECK(
+      tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() &&
+      !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction());
+}
+
+// Multiple consumer reduction with computeAt
+// https://github.com/csarofeen/pytorch/issues/110
+void testGPU_FusionReductionMultiConsumer() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+  TensorView* tv0 = makeDummyTensor(2);
+  fusion.addInput(tv0);
+  auto tv1 = unaryOp(UnaryOpType::Exp, tv0);
+  auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Float(0), tv1);
+  auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Float(0), tv1);
+  auto tv4 = add(tv2, tv3);
+  fusion.addOutput(tv4);
+  tv1->computeAt(tv2, -1);
+
+  TORCH_CHECK(
+      (tv1->getComputeAtView() == tv2 || tv1->getComputeAtView() == tv3) &&
+      tv1->getThisComputeAtAxis() == 2 && tv1->getRelativeComputeAtAxis() == 2);
+}
+
+void testGPU_FusionComputeAtExprOrder() {
+  {
+    for (int i = 0; i < 2; ++i) {
+      torch::jit::fuser::cuda::CudaKernel prog;
+      Fusion& fusion = *prog.fusion_;
+      FusionGuard fg(&fusion);
+
+      // Set up your input tensor views
+      TensorView* tv0 = makeDummyTensor(1);
+      fusion.addInput(tv0);
+
+      auto tv1 = add(tv0, new Float(1));
+      auto tv2 = add(tv0, new Float(1));
+      TensorView* tv3 = add(tv1, tv2);
+      if (i == 0) {
+        tv1->computeAt(tv3, -1);
+        fusion.addOutput(tv2);
+      } else {
+        tv2->computeAt(tv3, -1);
+        fusion.addOutput(tv1);
+      }
+      fusion.addOutput(tv3);
+
+      prog.device_ = 0;
+      prog.grid(1);
+      prog.block(1);
+
+      torch::jit::fuser::cuda::compileKernel(&prog);
+
+      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+      at::Tensor input = at::rand({100}, options);
+      at::Tensor output2 = at::empty_like(input, options);
+      at::Tensor output3 = at::empty_like(input, options);
+      torch::jit::fuser::cuda::runTestKernel(
+          &prog, {input}, {output2, output3});
+      auto aten_output = (input + 1) * 2;
+      TORCH_CHECK(
+          aten_output.allclose(output3),
+          "Error of: ",
+          aten_output.sub(output3).abs().max());
+    }
+  }
+  {
+    torch::jit::fuser::cuda::CudaKernel prog;
+    Fusion& fusion = *prog.fusion_;
+    FusionGuard fg(&fusion);
+
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(2);
+    fusion.addInput(tv0);
+
+    auto tv1 = add(tv0, new Float(1));
+    auto tv2 = add(tv0, new Float(1));
+    TensorView* tv3 = add(tv1, tv2);
+    fusion.addOutput(tv3);
+
+    tv3->split(-1, 32);
+
+    tv1->computeAt(tv3, -1);
+    tv2->computeAt(tv3, -2);
+
+    prog.device_ = 0;
+    prog.grid(1);
+    prog.block(1);
+
+    torch::jit::fuser::cuda::compileKernel(&prog);
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor input = at::rand({100, 100}, options);
+    at::Tensor output = at::empty_like(input, options);
+    torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {output});
+    auto aten_output = (input + 1) * 2;
+    TORCH_CHECK(
+        aten_output.allclose(output),
+        "Error of: ",
+        aten_output.sub(output).abs().max());
+  }
+}
+
+void testGPU_FusionZeroDimComputeAt() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeDummyTensor(1);
+  fusion.addInput(tv0);
+
+  auto tv1 = sum(tv0, {0});
+  auto tv2 = add(tv1, new Float(1));
+  fusion.addOutput(tv2);
+  TORCH_CHECK(tv2->nDims() == 0);
+  tv1->computeAt(tv2, 0);
+
+  prog.device_ = 0;
+  prog.grid(1);
+  prog.block(1);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({100}, options);
+  at::Tensor output = at::empty({}, options);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {output});
+  auto aten_output = input.sum() + 1;
+  TORCH_CHECK(
+      aten_output.allclose(output),
+      "Error of: ",
+      aten_output.sub(output).abs().max());
+}
+
+void testGPU_FusionZeroDimBroadcast() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeDummyTensor(0);
+  fusion.addInput(tv0);
+
+  auto tv1 = broadcast(tv0, {true, true});
+  TORCH_CHECK(tv1->nDims() == 2);
+
+  TensorView* tv2 = makeDummyTensor(2);
+  fusion.addInput(tv2);
+
+  auto tv3 = add(tv1, tv2);
+  auto tv4 = sum(tv3, {0, 1});
+  fusion.addOutput(tv4);
+
+  tv3->computeAt(tv4, -1);
+
+  prog.device_ = 0;
+  prog.grid(1);
+  prog.block(1);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input1 = at::rand({}, options);
+  at::Tensor input2 = at::rand({10, 10}, options);
+  at::Tensor output = at::empty({}, options);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input1, input2}, {output});
+  auto aten_output =
+      (input1.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + input2).sum();
+  TORCH_CHECK(
+      aten_output.allclose(output),
+      "Error of: ",
+      aten_output.sub(output).abs().max());
+}
+
+void testGPU_FusionZeroDimReduction() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  const int bdimx = 32;
+  const int gdimx = 32;
+
+  TensorView* tv0 = makeDummyTensor(1);
+  fusion.addInput(tv0);
+
+  auto tv1 = sum(tv0, {0});
+  fusion.addOutput(tv1);
+
+  tv1->split(0, bdimx);
+  tv1->split(0, gdimx);
+  auto tv2 = tv1->rFactor({0});
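+  // tv2 carries the serial outer reduction; tv1 reduces the remaining two axes,
+  // which are bound to BIDx and TIDx below, leaving a 0-dim output.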
+
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  tv1->axis(-2)->parallelize(ParallelType::BIDx);
+  tv2->axis(-2)->parallelize(ParallelType::BIDx);
+
+  prog.device_ = 0;
+  prog.grid(gdimx);
+  prog.block(bdimx);
+
+  torch::jit::fuser::cuda::compileKernel(&prog);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand({1000}, options);
+  at::Tensor output = at::empty({}, options);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {input}, {output});
+  auto aten_output = input.sum();
+  TORCH_CHECK(
+      aten_output.allclose(output),
+      "Error of: ",
+      aten_output.sub(output).abs().max());
 }
 
 } // namespace jit
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index 381108f..b6f464f 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -97,50 +97,75 @@
   _(FusionAliasing)
 
 #if defined(USE_CUDA)
-#define TH_FORALL_TESTS_CUDA(_)  \
-  _(ArgumentSpec)                \
-  _(CompleteArgumentSpec)        \
-  _(Fusion)                      \
-  _(GraphExecutor)               \
-  _(ModuleConversion)            \
-  _(Interp)                      \
-  _(GPU_IrGraphGenerator)        \
-  _(GPU_FusionDispatch)          \
-  _(GPU_FusionSimpleArith)       \
-  _(GPU_FusionExprEvalConstants) \
-  _(GPU_FusionExprEvalBindings)  \
-  _(GPU_FusionExprEvalBasic)     \
-  _(GPU_FusionExprEvalComplex)   \
-  _(GPU_FusionSimpleTypePromote) \
-  _(GPU_FusionMutator)           \
-  _(GPU_FusionRegister)          \
-  _(GPU_FusionTopoSort)          \
-  _(GPU_FusionTensor)            \
-  _(GPU_FusionTensorContiguity)  \
-  _(GPU_FusionTVSplit)           \
-  _(GPU_FusionTVMerge)           \
-  _(GPU_FusionTVReorder)         \
-  _(GPU_FusionEquality)          \
-  _(GPU_FusionReplaceAll)        \
-  _(GPU_FusionParser)            \
-  _(GPU_FusionDependency)        \
-  _(GPU_FusionCodeGen)           \
-  _(GPU_FusionCodeGen2)          \
-  _(GPU_FusionSimplePWise)       \
-  _(GPU_FusionExecKernel)        \
-  _(GPU_FusionForLoop)           \
-  _(GPU_FusionLoopUnroll)        \
-  _(GPU_FusionUnaryOps)          \
-  _(GPU_FusionBinaryOps)         \
-  _(GPU_FusionTernaryOps)        \
-  _(GPU_FusionCompoundOps)       \
-  _(GPU_FusionCastOps)           \
-  _(GPU_FusionAdvancedComputeAt) \
-  _(GPU_FusionScalarInputs)      \
-  _(GPU_FusionRFactorReplay)     \
-  _(GPU_FusionReduction)         \
-  _(GPU_FusionReduction2)        \
-  _(GPU_FusionSimpleBCast)
+#define TH_FORALL_TESTS_CUDA(_)   \
+  _(ArgumentSpec)                 \
+  _(CompleteArgumentSpec)         \
+  _(Fusion)                       \
+  _(GraphExecutor)                \
+  _(ModuleConversion)             \
+  _(Interp)                       \
+  _(GPU_IrGraphGenerator)         \
+  _(GPU_FusionDispatch)           \
+  _(GPU_FusionClear)              \
+  _(GPU_FusionCopy)               \
+  _(GPU_FusionMove)               \
+  _(GPU_FusionSimpleArith)        \
+  _(GPU_FusionExprEvalConstants)  \
+  _(GPU_FusionExprEvalBindings)   \
+  _(GPU_FusionExprEvalBasic)      \
+  _(GPU_FusionExprEvalComplex)    \
+  _(GPU_FusionExprEvalPostLower)  \
+  _(GPU_FusionSimpleTypePromote)  \
+  _(GPU_FusionMutator)            \
+  _(GPU_FusionRegister)           \
+  _(GPU_FusionTopoSort)           \
+  _(GPU_FusionTensor)             \
+  _(GPU_FusionTVSplit)            \
+  _(GPU_FusionTVMerge)            \
+  _(GPU_FusionTVReorder)          \
+  _(GPU_FusionEquality)           \
+  _(GPU_FusionReplaceAll)         \
+  _(GPU_FusionParser)             \
+  _(GPU_FusionDependency)         \
+  _(GPU_FusionCodeGen)            \
+  _(GPU_FusionCodeGen2)           \
+  _(GPU_FusionSimplePWise)        \
+  _(GPU_FusionExecKernel)         \
+  _(GPU_FusionForLoop)            \
+  _(GPU_FusionLoopUnroll)         \
+  _(GPU_FusionUnaryOps)           \
+  _(GPU_FusionBinaryOps)          \
+  _(GPU_FusionTernaryOps)         \
+  _(GPU_FusionCompoundOps)        \
+  _(GPU_FusionCastOps)            \
+  _(GPU_FusionAdvancedComputeAt)  \
+  _(GPU_FusionScalarInputs)       \
+  _(GPU_FusionRFactorReplay)      \
+  _(GPU_FusionReduction)          \
+  _(GPU_FusionReduction2)         \
+  _(GPU_FusionReduction3)         \
+  _(GPU_FusionReduction4)         \
+  _(GPU_FusionReduction5)         \
+  _(GPU_FusionReductionTFT)       \
+  _(GPU_FusionSimpleBCast)        \
+  _(GPU_FusionSimpleGemm)         \
+  _(GPU_FusionSoftmax)            \
+  _(GPU_FusionGridReduction1)     \
+  _(GPU_FusionGridReduction2)     \
+  _(GPU_FusionGridReduction3dim1) \
+  _(GPU_FusionGridReduction3dim0) \
+  _(GPU_FusionGridReduction4)     \
+  _(GPU_FusionGridReduction5)     \
+  _(GPU_FusionGridReduction6)     \
+  _(GPU_FusionNonRedAxisBind)     \
+  _(GPU_FusionBCastInnerDim)      \
+  _(GPU_FusionBCastReduce)        \
+  _(GPU_FusionSplitBCast)         \
+  _(GPU_FusionComputeAtExprOrder) \
+  _(GPU_FusionZeroDimComputeAt)   \
+  _(GPU_FusionZeroDimBroadcast)   \
+  _(GPU_FusionZeroDimReduction)   \
+  _(GPU_FusionReductionMultiConsumer)
 #else
 #define TH_FORALL_TESTS_CUDA(_) \
   _(ArgumentSpec)               \
diff --git a/test/run_test.py b/test/run_test.py
index f912c51..53c155b 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -30,6 +30,8 @@
     'distributed/test_c10d_spawn',
     'test_cuda',
     'test_jit_cuda_fuser',
+    'test_jit_cuda_fuser_legacy',
+    'test_jit_cuda_fuser_profiling',
     'test_cuda_primary_ctx',
     'test_dataloader',
     'distributed/test_data_parallel',
diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py
index 2a7421f..cb74a74 100644
--- a/test/test_jit_cuda_fuser.py
+++ b/test/test_jit_cuda_fuser.py
@@ -8,10 +8,12 @@
 
 import torch
 
-from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR
+from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, skipIfRocm
 from torch.testing._internal.codegen.random_topo_test import runDefaultTestWithSeed
 
 from test_jit import JitTestCase, RUN_CUDA
+import itertools
+import numpy as np
 
 if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
     torch._C._jit_set_profiling_executor(True)
@@ -321,14 +323,14 @@
         where_jit = torch.jit.script(where)
         self._run_helper(where_jit, where, x, y, cond)
 
-        def lerp(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+        def lerp(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
             o = torch.rand_like(x)
             o = o * torch.lerp(x, y, z)
             return o
         lerp_jit = torch.jit.script(lerp)
         self._run_helper(lerp_jit, lerp, x, y, z)
 
-        def lerp_scale(x : torch.Tensor, y : torch.Tensor, z: float):
+        def lerp_scale(x: torch.Tensor, y: torch.Tensor, z: float):
             o = torch.rand_like(x)
             o = o * torch.lerp(x, y, z)
             return o
@@ -342,21 +344,21 @@
         y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
         z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
 
-        def addcmul(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, value : float):
+        def addcmul(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, value: float):
             o = torch.add(x, 0.5)
             o = torch.addcmul(o, y, z, value=value)
             return o
         addcmul_jit = torch.jit.script(addcmul)
         self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0)
 
-        def addcmul_no_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+        def addcmul_no_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
             o = torch.add(x, 0.5)
             o = torch.addcmul(o, y, z)
             return o
         addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha)
         self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z)
 
-        def addcmul_const_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+        def addcmul_const_alpha(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
             o = torch.add(x, 0.5)
             o = torch.addcmul(o, y, z, value=0.75)
             return o
@@ -393,6 +395,109 @@
         os.environ["PYTORCH_CUDA_FUSER_DISABLE_FALLBACK"] = "1"
         self.assertTrue(runDefaultTestWithSeed(28449))
 
+    def _compare(self, desc, inp1, inp2, error):
+        a = inp1.clone().detach().cpu().numpy()
+        b = inp2.clone().detach().cpu().numpy()
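+        # np.allclose with both rtol and atol set to `error`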
+        close = np.allclose(a, b, error, error)
+        if not close:
+            print(desc, close)
+            z = a - b
+            index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
+            print("dif    : ", z[index])
+            print("inp1   : ", a[index])
+            print("inp2   : ", b[index])
+        return close
+
+    def _reduction_helper(self, sizes, reduction_axis, dtype, device):
+        class MyReduction(torch.nn.Module):
+            __constants__ = ['reduction_axis']
+
+            def __init__(self):
+                super(MyReduction, self).__init__()
+                self.reduction_axis = reduction_axis
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                o = torch.add(x, y)
+                o = torch.sum(o, dim=self.reduction_axis)
+                return o
+
+        t = MyReduction()
+        x = torch.randn(sizes, dtype=dtype, device=device)
+        y = torch.randn(sizes, dtype=dtype, device=device)
+        t_jit = torch.jit.script(t)
+        jit_o = t_jit(x, y)
+        jit_o = t_jit(x, y)
+        o = t(x, y)
+        for oo, jit_oo in zip(o, jit_o):
+            self.assertEqual(oo.dtype, jit_oo.dtype)
+            # numerical issues here due to our scheduling.
+            # can't use `self.assertEqual(oo, jit_oo)`
+            self.assertTrue(self._compare("comparing output failed", oo, jit_oo, 1e-4))
+        self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP)
+
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
+    @skipIfRocm
+    def test_reduction(self):
+        for x in ([7, 8, 12], [12, 8, 7, 9, 15], [128, 16, 8, 32]):
+            # note that num_reduce_dim ranges over [1, len(x)), so we never reduce
+            # down to a single element (a codegen limitation at this moment)
+            for num_reduce_dim in range(1, len(x)):
+                for axes in itertools.combinations(range(len(x)), num_reduce_dim):
+                    self._reduction_helper((12, 8, 7, 4, 8), axes, torch.float32, "cuda")
+
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
+    @skipIfRocm
+    def test_pw_single_reduction_partition(self):
+        sizes = [8, 8, 8]
+        dtype = torch.float
+        device = "cuda"
+        x = torch.randn(sizes, dtype=dtype, device=device)
+        y = torch.randn(sizes, dtype=dtype, device=device)
+        z = torch.randn(sizes, dtype=dtype, device=device)
+
+        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
+            o = torch.add(x, y)
+            o = torch.sum(o, dim=[0])
+            o = torch.add(o, z)
+            return o
+        t_jit = torch.jit.script(t)
+        jit_o = t_jit(x, y, z)
+        jit_o = t_jit(x, y, z)
+        o = t(x, y, z)
+        for oo, jit_oo in zip(o, jit_o):
+            self.assertEqual(oo.dtype, jit_oo.dtype)
+            self.assertEqual(oo, jit_oo)
+        self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)
+
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
+    @skipIfRocm
+    def test_single_reduction_broadcast(self):
+        dtype = torch.float
+        device = "cuda"
+        x = torch.randn([7, 4, 8], dtype=dtype, device=device)
+        y = torch.randn([4, 8], dtype=dtype, device=device)
+        z = torch.randn([1, 4, 8], dtype=dtype, device=device)
+
+        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
+            o = torch.add(x, y)
+            o = torch.add(o, z)
+            o = torch.sum(o, dim=[0])
+            return o
+        t_jit = torch.jit.script(t)
+        jit_o = t_jit(x, y, z)
+        jit_o = t_jit(x, y, z)
+        o = t(x, y, z)
+        for oo, jit_oo in zip(o, jit_o):
+            self.assertEqual(oo.dtype, jit_oo.dtype)
+            self.assertEqual(oo, jit_oo)
+        self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)
+
 
 class TestPassManagerCudaFuser(JitTestCase):
 
@@ -441,5 +546,6 @@
         self.assertTrue(torch._C._jit_set_nvfuser_enabled(False))
         self.assertFalse(torch._C._jit_nvfuser_enabled())
 
+
 if __name__ == '__main__':
     run_tests()
diff --git a/test/test_jit_cuda_fuser_legacy.py b/test/test_jit_cuda_fuser_legacy.py
new file mode 100644
index 0000000..4b9959c
--- /dev/null
+++ b/test/test_jit_cuda_fuser_legacy.py
@@ -0,0 +1,6 @@
+import sys
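+# Re-run the CUDA fuser test suite under the legacy graph executor configuration.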
+sys.argv.append("--ge_config=legacy")
+from test_jit_cuda_fuser import *
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/test/test_jit_cuda_fuser_profiling.py b/test/test_jit_cuda_fuser_profiling.py
new file mode 100644
index 0000000..e2869ec
--- /dev/null
+++ b/test/test_jit_cuda_fuser_profiling.py
@@ -0,0 +1,6 @@
+import sys
+sys.argv.append("--ge_config=profiling")
+from test_jit_cuda_fuser import *
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index bdc0b1c..b194e4f 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -331,20 +331,26 @@
     "torch/csrc/autograd/profiler_cuda.cpp",
     "torch/csrc/autograd/functions/comm.cpp",
     "torch/csrc/jit/codegen/cuda/arith.cpp",
+    "torch/csrc/jit/codegen/cuda/compute_at.cpp",
     "torch/csrc/jit/codegen/cuda/dispatch.cpp",
     "torch/csrc/jit/codegen/cuda/expr_evaluator.cpp",
     "torch/csrc/jit/codegen/cuda/fusion.cpp",
     "torch/csrc/jit/codegen/cuda/graph_fuser.cpp",
     "torch/csrc/jit/codegen/cuda/index_compute.cpp",
     "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
+    "torch/csrc/jit/codegen/cuda/ir_cloner.cpp",
     "torch/csrc/jit/codegen/cuda/ir_graphviz.cpp",
     "torch/csrc/jit/codegen/cuda/ir_nodes.cpp",
     "torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
     "torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
     "torch/csrc/jit/codegen/cuda/kernel.cpp",
     "torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
+    "torch/csrc/jit/codegen/cuda/lower_index.cpp",
     "torch/csrc/jit/codegen/cuda/lower_loops.cpp",
+    "torch/csrc/jit/codegen/cuda/lower_unroll.cpp",
+    "torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp",
     "torch/csrc/jit/codegen/cuda/lower_utils.cpp",
+    "torch/csrc/jit/codegen/cuda/lower_validation.cpp",
     "torch/csrc/jit/codegen/cuda/lower2device.cpp",
     "torch/csrc/jit/codegen/cuda/manager.cpp",
     "torch/csrc/jit/codegen/cuda/shape_inference.cpp",
@@ -352,7 +358,6 @@
     "torch/csrc/jit/codegen/cuda/parser.cpp",
     "torch/csrc/jit/codegen/cuda/partition.cpp",
     "torch/csrc/jit/codegen/cuda/predicate_compute.cpp",
-    "torch/csrc/jit/codegen/cuda/tensor_meta.cpp",
     "torch/csrc/jit/codegen/cuda/tensor_view.cpp",
     "torch/csrc/jit/codegen/cuda/transform_iter.cpp",
     "torch/csrc/jit/codegen/cuda/transform_replay.cpp",
diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index 0b3785b..8cd7ab9 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -311,6 +311,19 @@
 TORCH_CUDA_API TensorView* lt(TensorView* v1, TensorView* v2) {
   return arithOpOverloads(lt, v1, v2);
 }
+// eq
+TORCH_CUDA_API Val* eq(Val* v1, Val* v2) {
+  return binaryOp(BinaryOpType::Eq, v1, v2);
+}
+TORCH_CUDA_API TensorView* eq(TensorView* v1, Val* v2) {
+  return arithOpOverloads(eq, v1, v2);
+}
+TORCH_CUDA_API TensorView* eq(Val* v1, TensorView* v2) {
+  return arithOpOverloads(eq, v1, v2);
+}
+TORCH_CUDA_API TensorView* eq(TensorView* v1, TensorView* v2) {
+  return arithOpOverloads(eq, v1, v2);
+}
 // ceilDiv
 TORCH_CUDA_API Val* ceilDiv(Val* v1, Val* v2) {
   return binaryOp(BinaryOpType::CeilDiv, v1, v2);
@@ -348,9 +361,10 @@
 
 // REDUCTION OPERATIONS
 
-namespace {
 // TODO: How do we adjust this so we can reduce to a single scalar value?
-TensorView* newForReduction(TensorView* tv, std::vector<unsigned int> axes) {
+static TensorView* newForReduction(
+    TensorView* tv,
+    const std::vector<unsigned int>& axes) {
   auto orig_domain = TensorDomain::noReductions(tv->getRootDomain());
   std::set<unsigned int> axes_set(axes.begin(), axes.end());
 
@@ -363,25 +377,35 @@
       (*(axes_set.rbegin())) < orig_domain.size(),
       "Error setting up reduction, reduction axis is outside nDims. Keep in mind reductions are relative to root domains, not modified views.");
 
-  for (decltype(orig_domain.size()) dim = 0; dim < orig_domain.size(); dim++) {
-    IterDomain* id = orig_domain[dim];
-
+  for (size_t dim = 0; dim < orig_domain.size(); dim++) {
     bool isReduction = false;
-    if ((*axes_set.begin()) == dim) {
+    if (!axes_set.empty() && *axes_set.begin() == dim) {
       isReduction = true;
       axes_set.erase(axes_set.begin());
     }
 
+    const IterDomain* id = orig_domain[dim];
+
+    TORCH_CHECK(
+        !(isReduction && id->isBroadcast()),
+        "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ",
+        id,
+        " of tensor ",
+        tv);
+
     new_domain.push_back(new IterDomain(
-        id->start(), id->extent(), ParallelType::Serial, isReduction));
+        id->start(),
+        id->extent(),
+        ParallelType::Serial,
+        isReduction,
+        false,
+        id->isBroadcast()));
   }
 
   TensorDomain* td = new TensorDomain(new_domain);
   return new TensorView(td, tv->getDataType().value());
 }
 
-} // namespace
-
 TensorView* reductionOp(
     BinaryOpType reduction_op_type,
     const std::vector<int>& axes,
@@ -395,6 +419,8 @@
       TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()),
       "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");
 
+  TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor");
+
   std::vector<unsigned int> uint_axes;
   for (int axis : axes) {
     if (axis < 0)
@@ -447,9 +473,9 @@
     if (ent)
       n_broadcasts++;
   TORCH_CHECK(
-      nBCastDims - n_broadcasts == inp->nDims(),
+      nBCastDims - n_broadcasts == inp->domain()->noReductions().size(),
       "Invalid broadcast, number of false entries in is_broadcast_dim expected to be ",
-      inp->nDims(),
+      inp->domain()->noReductions().size(),
       " but received ",
       nBCastDims - n_broadcasts);
 
@@ -468,7 +494,8 @@
       out_domain.push_back(new IterDomain(
           new Int(0), new Int(1), ParallelType::Serial, false, false, true));
     } else {
-      out_domain.push_back(inp->axis(iinp));
+      // Don't propagate reduction IDs through arith ops.
+      out_domain.push_back(inp->domain()->noReductions()[iinp]);
       iinp++;
     }
     ibdim++;
diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h
index 4961e41..4c61f2d 100644
--- a/torch/csrc/jit/codegen/cuda/arith.h
+++ b/torch/csrc/jit/codegen/cuda/arith.h
@@ -5,7 +5,7 @@
 #include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/type.h>
 
-struct Val;
+class Val;
 
 /*
  * The operations defined in this header is intended as user facing functions.
@@ -60,7 +60,8 @@
     TensorView* inp,
     const std::vector<bool>& is_broadcast_dim);
 
-// BINARY OPAERATIONS
+// BINARY OPERATIONS
+// add
 TORCH_CUDA_API Val* add(Val* v1, Val* v2);
 TORCH_CUDA_API TensorView* add(TensorView* v1, Val* v2);
 TORCH_CUDA_API TensorView* add(Val* v1, TensorView* v2);
@@ -90,6 +91,11 @@
 TORCH_CUDA_API TensorView* lt(TensorView* v1, Val* v2);
 TORCH_CUDA_API TensorView* lt(Val* v1, TensorView* v2);
 TORCH_CUDA_API TensorView* lt(TensorView* v1, TensorView* v2);
+// eq
+TORCH_CUDA_API Val* eq(Val* v1, Val* v2);
+TORCH_CUDA_API TensorView* eq(TensorView* v1, Val* v2);
+TORCH_CUDA_API TensorView* eq(Val* v1, TensorView* v2);
+TORCH_CUDA_API TensorView* eq(TensorView* v1, TensorView* v2);
 // ceilDiv
 TORCH_CUDA_API Val* ceilDiv(Val* v1, Val* v2);
 TORCH_CUDA_API TensorView* ceilDiv(TensorView* v1, Val* v2);
diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp
new file mode 100644
index 0000000..eea3f4a
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -0,0 +1,459 @@
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
+
+#include <torch/csrc/jit/codegen/cuda/compute_at.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+// Actually applies transformation
+void ComputeAt::computeAt_impl(
+    TensorView* producer,
+    TensorView* consumer,
+    unsigned int consumer_compute_at_axis) {
+  // Reset the view, otherwise it will conflict with the replay.
+  producer->clearComputeAt();
+  // replay this as consumer / producer as consumer
+  auto replay = TransformReplay::replayPasC(
+      producer, consumer, (int)consumer_compute_at_axis);
+  producer->setComputeAt(consumer, replay.second);
+}
+
+// Runs the replay and checks the resulting computeAt position. If it is higher
+// than the producer's current position, actually applies it.
+void ComputeAt::maybe_computeAt_impl(
+    TensorView* producer,
+    TensorView* consumer,
+    unsigned int consumer_compute_at_axis) {
+  unsigned int prev_pos = 0;
+  if (producer->hasComputeAt())
+    prev_pos = producer->getThisComputeAtAxis();
+
+  auto replay = TransformReplay::replayPasC(
+      producer->domain(), consumer->domain(), (int)consumer_compute_at_axis);
+
+  if (replay.second > prev_pos) {
+    producer->setDomain(replay.first);
+    producer->setComputeAt(consumer, replay.second);
+  }
+}
+
+// Actually applies transformation
+void ComputeAt::forwardComputeAt_impl(
+    TensorView* producer,
+    TensorView* consumer,
+    unsigned int producer_compute_at_axis) {
+  // Reset the view, otherwise it will conflict with the replay. (This may no
+  // longer be necessary.)
+  producer->clearComputeAt();
+  auto replay = TransformReplay::replayCasP(
+      consumer, producer, (int)producer_compute_at_axis);
+  producer->setComputeAt(consumer, replay.second);
+}
+
+namespace {
+// Wrapper around set_intersection
+template <typename T>
+std::set<T> set_intersection(const std::set<T>& set1, const std::set<T>& set2) {
+  std::set<T> intersection;
+  std::set_intersection(
+      set1.begin(),
+      set1.end(),
+      set2.begin(),
+      set2.end(),
+      std::inserter(intersection, intersection.begin()));
+  return intersection;
+}
+
+// convert an iterable of Val* to be an iterable of TensorView*
+template <typename T1, typename T2>
+T1 tv_iterable(const T2& val_iterable) {
+  T1 tv_iterable = T1();
+  std::transform(
+      val_iterable.begin(),
+      val_iterable.end(),
+      std::back_inserter(tv_iterable),
+      [](Val* v) {
+        TORCH_INTERNAL_ASSERT(
+            v->getValType().value() == ValType::TensorView,
+            "When following the computeAt dependency chain, a non TensorView value was found.");
+        return static_cast<TensorView*>(v);
+      });
+  return tv_iterable;
+}
+
+std::deque<std::deque<TensorView*>> getAllTVUseChains(TensorView* tv) {
+  // Grab all use chains of the producer in the fusion.
+  auto val_all_use_chains = DependencyCheck::getAllUseChains(tv);
+
+  // Convert dep chains to tensor view chains.
+  std::deque<std::deque<TensorView*>> producer_use_chains_;
+  for (const auto& val_dep_chain : val_all_use_chains)
+    producer_use_chains_.push_back(
+        tv_iterable<std::deque<TensorView*>>(val_dep_chain));
+  return producer_use_chains_;
+}
+} // namespace
+
+void ComputeAt::setCommonConsumer() {
+  // Convert the first chain to a set.
+  std::set<TensorView*> common_consumers(
+      producer_use_chains_.front().begin(), producer_use_chains_.front().end());
+
+  // Run through all use chains of producer, and intersect them to find common
+  // TVs
+  for (auto dep_chain : producer_use_chains_)
+    common_consumers = set_intersection(
+        common_consumers,
+        std::set<TensorView*>(dep_chain.begin(), dep_chain.end()));
+
+  auto all_chains =
+      DependencyCheck::getAllDependencyChains(producer_, consumer_);
+
+  // Right now we only support computeAt if, at some point in the graph, the
+  // consumer depends on the producer.
+  TORCH_CHECK(
+      !all_chains.empty(),
+      "Compute At expects ",
+      producer_,
+      " is a dependency of ",
+      consumer_,
+      ", however it is not.");
+
+  // Remove all TVs from producer to consumer as common consumer must be at or
+  // after consumer
+  for (const auto& dep_chain : all_chains) {
+    auto tv_chain = tv_iterable<std::deque<TensorView*>>(dep_chain);
+    for (auto tv : tv_chain) {
+      if (tv != consumer_)
+        common_consumers.erase(tv);
+    }
+  }
+
+  // If there is a common consumer, grab the first one at or after consumer
+  common_consumer_ = nullptr;
+  if (!common_consumers.empty()) {
+    for (TensorView* tv : producer_use_chains_.front())
+      if (common_consumers.find(tv) != common_consumers.end()) {
+        common_consumer_ = tv;
+        break;
+      }
+    TORCH_INTERNAL_ASSERT(
+        common_consumer_ != nullptr,
+        "Hit a logical inconsistency in the computeAt pass.");
+  }
+}
+
+void ComputeAt::traverseAllKnown() {
+  std::deque<std::deque<Val*>> chains;
+
+  // propagate backwards through all dep chains from producer to consumer
+
+  // Grab all dependency chains from producer to consumer
+  chains = DependencyCheck::getAllDependencyChains(producer_, consumer_);
+
+  TORCH_CHECK(
+      !chains.empty(),
+      "Producer and consumer in a computeAt call must have a dependency between them even if indirect.");
+
+  for (const auto& val_chain : chains) {
+    auto tv_chain = tv_iterable<std::deque<TensorView*>>(val_chain);
+    TensorView* running_consumer = nullptr;
+    TensorView* running_producer = tv_chain.back();
+    unsigned int running_consumer_pos = consumer_position_;
+
+    tv_chain.pop_back();
+
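+    // Walk each chain from the consumer end toward the producer, applying
+    // computeAt at every producer/consumer pair and carrying the achieved
+    // position down the chain.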
+    while (!tv_chain.empty()) {
+      running_consumer = running_producer;
+      running_producer = tv_chain.back();
+      tv_chain.pop_back();
+
+      if (compute_at_ed.find(running_producer) != compute_at_ed.end() &&
+          known_positions.find(running_producer) != known_positions.end()) {
+        running_consumer_pos = known_positions.at(running_producer);
+        continue;
+      }
+
+      computeAt_impl(running_producer, running_consumer, running_consumer_pos);
+      running_consumer_pos = running_producer->getThisComputeAtAxis();
+
+      // Update both compute_at_ed and known_positions
+      compute_at_ed.emplace(running_producer);
+
+      if (known_positions.find(running_producer) != known_positions.end()) {
+        TORCH_INTERNAL_ASSERT(
+            known_positions.at(running_producer) ==
+                running_producer->getThisComputeAtAxis(),
+            "Hit a logical inconsistency in the computeAt pass.");
+      } else {
+        known_positions[running_producer] =
+            running_producer->getThisComputeAtAxis();
+      }
+    }
+  }
+
+  // propagate forward through all consumer use_chains, or from consumer to
+  // common_consumer if common_consumer exists, marking visited producers as done.
+
+  if (common_consumer_ == nullptr) {
+    chains = DependencyCheck::getAllUseChains(consumer_);
+  } else if (common_consumer_ != consumer_) {
+    chains =
+        DependencyCheck::getAllDependencyChains(consumer_, common_consumer_);
+  }
+
+  // propagate forward through all chains
+  unsigned int running_producer_compute_at = consumer_position_;
+
+  for (const auto& dep_chain : chains) {
+    TORCH_INTERNAL_ASSERT(
+        !dep_chain.empty(), "Computed an invalid common_consumer.");
+
+    std::deque<TensorView*> tv_dep_chain =
+        tv_iterable<std::deque<TensorView*>>(dep_chain);
+
+    TensorView* running_consumer = tv_dep_chain.front();
+    tv_dep_chain.pop_front();
+
+    TensorView* running_producer = nullptr;
+
+    while (!tv_dep_chain.empty()) {
+      running_producer = running_consumer;
+      running_consumer = tv_dep_chain.front();
+      tv_dep_chain.pop_front();
+
+      if (compute_at_ed.find(running_producer) != compute_at_ed.end() &&
+          known_positions.find(running_consumer) != known_positions.end()) {
+        running_producer_compute_at = known_positions.at(running_consumer);
+        continue;
+      }
+
+      forwardComputeAt_impl(
+          running_producer, running_consumer, running_producer_compute_at);
+
+      compute_at_ed.emplace(running_producer);
+
+      if (known_positions.find(running_consumer) != known_positions.end()) {
+        TORCH_INTERNAL_ASSERT(
+            known_positions.at(running_consumer) ==
+                running_producer->getRelativeComputeAtAxis(),
+            "Hit a logical inconsistency in computeAt pass.");
+      } else {
+        known_positions[running_consumer] =
+            running_producer->getRelativeComputeAtAxis();
+      }
+    }
+  }
+}
+
+// Similar to forward traversal in traverseAllKnown but we don't know if the
+// positions are actually correct
+void ComputeAt::traverseForward() {
+  // propagate forward through all *producer* use_chains or from *producer* to
+  // common_consumer if common_consumer exists.
+  std::deque<std::deque<Val*>> chains;
+  if (common_consumer_ == nullptr) {
+    chains = DependencyCheck::getAllUseChains(producer_);
+  } else if (common_consumer_ != consumer_) {
+    chains =
+        DependencyCheck::getAllDependencyChains(producer_, common_consumer_);
+  }
+
+  // propagate forward through all chains
+  for (const auto& dep_chain : chains) {
+    int running_producer_compute_at = known_positions.at(producer_);
+    TORCH_INTERNAL_ASSERT(
+        !dep_chain.empty(), "Computed an invalid common_consumer.");
+
+    std::deque<TensorView*> tv_dep_chain =
+        tv_iterable<std::deque<TensorView*>>(dep_chain);
+
+    TensorView* running_consumer = tv_dep_chain.front();
+    tv_dep_chain.pop_front();
+
+    TensorView* running_producer = nullptr;
+
+    while (!tv_dep_chain.empty()) {
+      running_producer = running_consumer;
+      running_consumer = tv_dep_chain.front();
+      tv_dep_chain.pop_front();
+
+      if (compute_at_ed.find(running_producer) != compute_at_ed.end() &&
+          known_positions.find(running_consumer) != known_positions.end()) {
+        running_producer_compute_at = known_positions.at(running_consumer);
+        continue;
+      }
+
+      forwardComputeAt_impl(
+          running_producer, running_consumer, running_producer_compute_at);
+
+      compute_at_ed.emplace(running_producer);
+
+      if (known_positions.find(running_consumer) != known_positions.end()) {
+        TORCH_INTERNAL_ASSERT(
+            known_positions.at(running_consumer) ==
+                running_producer->getRelativeComputeAtAxis(),
+            "Hit a logical inconsistency in computeAt pass.");
+      }
+    }
+  }
+}
+
+// Similar to backward traversal in traverseAllKnown but we should only apply
+// computeAt if it will increase computeAt positions.
+void ComputeAt::traverseBackward() {
+  // propagate *backward* through all *producer* use_chains or from *producer*
+  // to common_consumer if common_consumer exists. Only apply transform if
+  // increases computeAt position.
+  std::deque<std::deque<Val*>> chains;
+  if (common_consumer_ == nullptr) {
+    chains = DependencyCheck::getAllUseChains(producer_);
+  } else if (common_consumer_ != consumer_) {
+    chains =
+        DependencyCheck::getAllDependencyChains(producer_, common_consumer_);
+  }
+
+  for (const auto& val_chain : chains) {
+    auto tv_chain = tv_iterable<std::deque<TensorView*>>(val_chain);
+    TensorView* running_consumer = nullptr;
+    TensorView* running_producer = tv_chain.back();
+    auto it = known_positions.find(running_producer);
+
+    if (it == known_positions.end()) {
+      TORCH_INTERNAL_ASSERT(
+          common_consumer_ == nullptr,
+          "Hit a logical inconsistency in computeAt pass.");
+      continue;
+    }
+
+    unsigned int running_consumer_pos = it->second;
+
+    tv_chain.pop_back();
+
+    while (!tv_chain.empty()) {
+      running_consumer = running_producer;
+      running_producer = tv_chain.back();
+      tv_chain.pop_back();
+
+      if (compute_at_ed.find(running_producer) != compute_at_ed.end() &&
+          known_positions.find(running_producer) != known_positions.end()) {
+        running_consumer_pos = known_positions.at(running_producer);
+        continue;
+      }
+
+      // If we're already at consumer_position_ that's the max position we could
+      // hope for, don't bother running again.
+      if (running_producer->getThisComputeAtAxis() != consumer_position_) {
+        maybe_computeAt_impl(
+            running_producer, running_consumer, running_consumer_pos);
+      }
+      running_consumer_pos = running_producer->getThisComputeAtAxis();
+
+      if (known_positions.find(running_producer) != known_positions.end()) {
+        TORCH_INTERNAL_ASSERT(
+            known_positions.at(running_producer) ==
+                running_producer->getThisComputeAtAxis(),
+            "Hit a logical inconsistency in the computeAt pass.");
+      }
+    }
+  }
+}
+
+void ComputeAt::runPass() {
+  // Make sure the correct fusion is setup between this and consumer.
+  TORCH_CHECK(
+      producer_->fusion() == consumer_->fusion(),
+      producer_,
+      " and ",
+      consumer_,
+      " are not in the same fusion.");
+
+  // Make sure Fusion Guard is set appropriately
+  FusionGuard fg(producer_->fusion());
+
+  // Look through all the use chains of producer. Check if there's a single
+  // consumer for all chains at or after the consumer specified in the computeAt
+  // call.
+  setCommonConsumer();
+
+  // Propagate in a way we know the result will be correct: forward from
+  // consumer and backward from consumer to producer
+  traverseAllKnown();
+
+  TORCH_INTERNAL_ASSERT(
+      producer_->hasComputeAt(),
+      "Hit a logical inconsistency in the computeAt pass.");
+
+  // Start at producer and traverse forward
+  traverseForward();
+
+  // Propagate backward from consumer or common consumer; check if it increases
+  // the computeAt position on tensors, and if so take it.
+  traverseBackward();
+}
+
+void ComputeAt::setupOutputs() {
+  if (common_consumer_ != nullptr)
+    return;
+
+  // output and its compute at position
+  std::unordered_map<TensorView*, int> touched_outputs;
+  for (auto tv : compute_at_ed) {
+    TORCH_INTERNAL_ASSERT(
+        tv->hasComputeAt(),
+        "Hit a logical inconsistency in the computeAt pass.");
+    auto ca_view = tv->getComputeAtView();
+    if (FusionGuard::getCurFusion()->hasOutput(ca_view)) {
+      touched_outputs[ca_view] = tv->getRelativeComputeAtAxis();
+    }
+  }
+
+  std::vector<TensorView*> touched_output_order(touched_outputs.size());
+
+  {
+    size_t i = 0;
+    for (auto out : FusionGuard::getCurFusion()->outputs()) {
+      if (out->getValType() == ValType::TensorView) {
+        if (touched_outputs.find(out->as<TensorView>()) !=
+            touched_outputs.end()) {
+          touched_output_order[i++] = out->as<TensorView>();
+        }
+      }
+    }
+    TORCH_INTERNAL_ASSERT(
+        i == touched_output_order.size(),
+        "Hit a logical inconsistency in the computeAt pass.");
+  }
+
+  for (size_t i = 0; i < touched_output_order.size() - 1; i++) {
+    touched_output_order[i]->setComputeAt(
+        touched_output_order[i + 1],
+        touched_outputs.at(touched_output_order[i]),
+        touched_outputs.at(touched_output_order[i + 1]));
+  }
+}
+
+ComputeAt::ComputeAt(
+    TensorView* _producer,
+    TensorView* _consumer,
+    unsigned int _consumer_position)
+    : producer_(_producer),
+      consumer_(_consumer),
+      consumer_position_(_consumer_position) {}
+
+void ComputeAt::run(
+    TensorView* producer,
+    TensorView* consumer,
+    unsigned int consumer_position) {
+  ComputeAt ca(producer, consumer, consumer_position);
+  ca.producer_use_chains_ = getAllTVUseChains(ca.producer_);
+  ca.setCommonConsumer();
+  ca.runPass();
+  ca.setupOutputs();
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h
new file mode 100644
index 0000000..3c2434d
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/compute_at.h
@@ -0,0 +1,98 @@
+#pragma once
+
+#include <c10/util/Exception.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <deque>
+#include <unordered_map>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+class TensorView;
+
+class ComputeAt {
+ public:
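+  // Entry point: runs the full computeAt pass, positioning producer's computation
+  // at the given axis of consumer and propagating through the fusion as needed.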
+  static void run(
+      TensorView* _producer,
+      TensorView* _consumer,
+      unsigned int _consumer_position);
+
+ private:
+  TensorView* producer_;
+  TensorView* consumer_;
+  unsigned int consumer_position_;
+
+  // Only keeping these as member functions because ComputeAt is a friend of
+  // TensorView; we don't want to keep expanding the set of friends of TV.
+  // Runs replayPasC and sets producer computeAt settings
+  void computeAt_impl(
+      TensorView* producer,
+      TensorView* consumer,
+      unsigned int consumer_compute_at_axis);
+
+  // Runs replay, and checks computeAt position of producer. If new position
+  // would be higher, actually runs operation.
+  void maybe_computeAt_impl(
+      TensorView* producer,
+      TensorView* consumer,
+      unsigned int consumer_compute_at_axis);
+
+  // Runs replayCasP and sets producer computeAt settings
+  void forwardComputeAt_impl(
+      TensorView* producer,
+      TensorView* consumer,
+      unsigned int producer_compute_at_axis);
+
+  // Look through all the use chains of producer. Check if there's a single
+  // consumer for all chains at or after the consumer specified in the computeAt
+  // call.
+  void setCommonConsumer();
+
+  // Propagate in a way we know the result will be correct: forward from
+  // consumer and backward from consumer to producer
+  void traverseAllKnown();
+
+  // Traverse from producer to common_consumer if exists or through all uses of
+  // producer
+  void traverseForward();
+
+  // Propagate backward from consumer or common consumer; check if it increases
+  // the computeAt position on tensors, and if so take it.
+  void traverseBackward();
+
+  // Run the computeAt pass
+  void runPass();
+
+  // Set outputs relative to each other if there is not a common consumer
+  void setupOutputs();
+
+  // Common consumer if it exists
+  TensorView* common_consumer_ = nullptr;
+
+  // Producer use chains, set in run(); used in a few spots.
+  std::deque<std::deque<TensorView*>> producer_use_chains_;
+
+  // Order for forward computeAt pass
+  std::vector<std::pair<TensorView*, TensorView*>> forward_compute_at_order;
+
+  // Order for backward computeAt pass
+  std::vector<std::pair<TensorView*, TensorView*>> backward_compute_at_order;
+
+  // TensorViews we've set computeAt of, in this computeAt pass
+  std::unordered_set<TensorView*> compute_at_ed;
+
+  // TensorViews of which we know their correct computeAt position
+  std::unordered_map<TensorView*, unsigned int> known_positions;
+
+  ComputeAt(
+      TensorView* _producer,
+      TensorView* _consumer,
+      unsigned int _consumer_position);
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp
index a3a5522..197c5e6 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.cpp
+++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -36,7 +36,7 @@
  * }
  *
  * And therefore dispatch should never call:
- * ptr(mutator)->handle(static_cast<Statement*>(this));
+ * ptr(mutator)->handle(this->as<Statement>());
  */
 
 template <typename T>
@@ -45,35 +45,35 @@
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          ptr(handler)->handle(static_cast<Bool*>(val));
+          ptr(handler)->handle(val->as<Bool>());
           return;
         case DataType::Float:
-          ptr(handler)->handle(static_cast<Float*>(val));
+          ptr(handler)->handle(val->as<Float>());
           return;
         case DataType::Half:
-          ptr(handler)->handle(static_cast<Half*>(val));
+          ptr(handler)->handle(val->as<Half>());
           return;
         case DataType::Int:
-          ptr(handler)->handle(static_cast<Int*>(val));
+          ptr(handler)->handle(val->as<Int>());
           return;
         default:
           break;
       }
       break;
     case ValType::IterDomain:
-      ptr(handler)->handle(static_cast<IterDomain*>(val));
+      ptr(handler)->handle(val->as<IterDomain>());
       return;
     case ValType::TensorDomain:
-      ptr(handler)->handle(static_cast<TensorDomain*>(val));
+      ptr(handler)->handle(val->as<TensorDomain>());
       return;
     case ValType::TensorView:
-      ptr(handler)->handle(static_cast<TensorView*>(val));
+      ptr(handler)->handle(val->as<TensorView>());
       return;
     case ValType::TensorIndex:
-      ptr(handler)->handle(static_cast<TensorIndex*>(val));
+      ptr(handler)->handle(val->as<TensorIndex>());
       return;
     case ValType::NamedScalar:
-      ptr(handler)->handle(static_cast<NamedScalar*>(val));
+      ptr(handler)->handle(val->as<NamedScalar>());
       return;
     default:
       break;
@@ -85,34 +85,34 @@
 void Expr::dispatch(T handler, Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      ptr(handler)->handle(static_cast<Split*>(expr));
+      ptr(handler)->handle(expr->as<Split>());
       return;
     case ExprType::Merge:
-      ptr(handler)->handle(static_cast<Merge*>(expr));
+      ptr(handler)->handle(expr->as<Merge>());
       return;
     case ExprType::UnaryOp:
-      ptr(handler)->handle(static_cast<UnaryOp*>(expr));
+      ptr(handler)->handle(expr->as<UnaryOp>());
       return;
     case ExprType::BinaryOp:
-      ptr(handler)->handle(static_cast<BinaryOp*>(expr));
+      ptr(handler)->handle(expr->as<BinaryOp>());
       return;
     case ExprType::TernaryOp:
-      ptr(handler)->handle(static_cast<TernaryOp*>(expr));
+      ptr(handler)->handle(expr->as<TernaryOp>());
       return;
     case ExprType::ReductionOp:
-      ptr(handler)->handle(static_cast<ReductionOp*>(expr));
+      ptr(handler)->handle(expr->as<ReductionOp>());
       return;
     case ExprType::BroadcastOp:
-      ptr(handler)->handle(static_cast<BroadcastOp*>(expr));
+      ptr(handler)->handle(expr->as<BroadcastOp>());
       return;
     case ExprType::ForLoop:
-      ptr(handler)->handle(static_cast<ForLoop*>(expr));
+      ptr(handler)->handle(expr->as<ForLoop>());
       return;
     case ExprType::IfThenElse:
-      ptr(handler)->handle(static_cast<IfThenElse*>(expr));
+      ptr(handler)->handle(expr->as<IfThenElse>());
       return;
     case ExprType::Allocate:
-      ptr(handler)->handle(static_cast<Allocate*>(expr));
+      ptr(handler)->handle(expr->as<Allocate>());
       return;
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -122,9 +122,9 @@
 template <typename T>
 void Statement::dispatch(T handler, Statement* stmt) {
   if (stmt->isVal()) {
-    ptr(handler)->handle(static_cast<Val*>(stmt));
+    ptr(handler)->handle(stmt->as<Val>());
   } else if (stmt->isExpr()) {
-    ptr(handler)->handle(static_cast<Expr*>(stmt));
+    ptr(handler)->handle(stmt->as<Expr>());
   } else
     TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -135,35 +135,35 @@
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          ptr(handler)->handle(static_cast<const Bool*>(val));
+          ptr(handler)->handle(val->as<Bool>());
           return;
         case DataType::Float:
-          ptr(handler)->handle(static_cast<const Float*>(val));
+          ptr(handler)->handle(val->as<Float>());
           return;
         case DataType::Half:
-          ptr(handler)->handle(static_cast<const Half*>(val));
+          ptr(handler)->handle(val->as<Half>());
           return;
         case DataType::Int:
-          ptr(handler)->handle(static_cast<const Int*>(val));
+          ptr(handler)->handle(val->as<Int>());
           return;
         default:
           break;
       }
       break;
     case ValType::IterDomain:
-      ptr(handler)->handle(static_cast<const IterDomain*>(val));
+      ptr(handler)->handle(val->as<IterDomain>());
       return;
     case ValType::TensorDomain:
-      ptr(handler)->handle(static_cast<const TensorDomain*>(val));
+      ptr(handler)->handle(val->as<TensorDomain>());
       return;
     case ValType::TensorView:
-      ptr(handler)->handle(static_cast<const TensorView*>(val));
+      ptr(handler)->handle(val->as<TensorView>());
       return;
     case ValType::TensorIndex:
-      ptr(handler)->handle(static_cast<const TensorIndex*>(val));
+      ptr(handler)->handle(val->as<TensorIndex>());
       return;
     case ValType::NamedScalar:
-      ptr(handler)->handle(static_cast<const NamedScalar*>(val));
+      ptr(handler)->handle(val->as<NamedScalar>());
       return;
     default:
       break;
@@ -175,34 +175,34 @@
 void Expr::constDispatch(T handler, const Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      ptr(handler)->handle(static_cast<const Split*>(expr));
+      ptr(handler)->handle(expr->as<Split>());
       return;
     case ExprType::Merge:
-      ptr(handler)->handle(static_cast<const Merge*>(expr));
+      ptr(handler)->handle(expr->as<Merge>());
       return;
     case ExprType::UnaryOp:
-      ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
+      ptr(handler)->handle(expr->as<UnaryOp>());
       return;
     case ExprType::BinaryOp:
-      ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
+      ptr(handler)->handle(expr->as<BinaryOp>());
       return;
     case ExprType::TernaryOp:
-      ptr(handler)->handle(static_cast<const TernaryOp*>(expr));
+      ptr(handler)->handle(expr->as<TernaryOp>());
       return;
     case ExprType::ReductionOp:
-      ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
+      ptr(handler)->handle(expr->as<ReductionOp>());
       return;
     case ExprType::BroadcastOp:
-      ptr(handler)->handle(static_cast<const BroadcastOp*>(expr));
+      ptr(handler)->handle(expr->as<BroadcastOp>());
       return;
     case ExprType::ForLoop:
-      ptr(handler)->handle(static_cast<const ForLoop*>(expr));
+      ptr(handler)->handle(expr->as<ForLoop>());
       return;
     case ExprType::IfThenElse:
-      ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
+      ptr(handler)->handle(expr->as<IfThenElse>());
       return;
     case ExprType::Allocate:
-      ptr(handler)->handle(static_cast<const Allocate*>(expr));
+      ptr(handler)->handle(expr->as<Allocate>());
       return;
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
@@ -212,9 +212,9 @@
 template <typename T>
 void Statement::constDispatch(T handler, const Statement* stmt) {
   if (stmt->isVal()) {
-    ptr(handler)->handle(static_cast<const Val*>(stmt));
+    ptr(handler)->handle(stmt->as<Val>());
   } else if (stmt->isExpr()) {
-    ptr(handler)->handle(static_cast<const Expr*>(stmt));
+    ptr(handler)->handle(stmt->as<Expr>());
   } else
     TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -228,7 +228,7 @@
  * implement Statement* mutate(Statement* stmt){ stmt->mutatorDispatch(this);
  * }
  * And therefore dispatch should never call:
- *   ptr(mutator)->mutate(static_cast<Statement*>(this));
+ *   ptr(mutator)->mutate(this->as<Statement>());
  */
 template <typename T>
 Statement* Val::mutatorDispatch(T mutator, Val* val) {
@@ -236,27 +236,27 @@
     case ValType::Scalar:
       switch (*(val->getDataType())) {
         case DataType::Bool:
-          return ptr(mutator)->mutate(static_cast<Bool*>(val));
+          return ptr(mutator)->mutate(val->as<Bool>());
         case DataType::Float:
-          return ptr(mutator)->mutate(static_cast<Float*>(val));
+          return ptr(mutator)->mutate(val->as<Float>());
         case DataType::Half:
-          return ptr(mutator)->mutate(static_cast<Half*>(val));
+          return ptr(mutator)->mutate(val->as<Half>());
         case DataType::Int:
-          return ptr(mutator)->mutate(static_cast<Int*>(val));
+          return ptr(mutator)->mutate(val->as<Int>());
         default:
           break;
       }
       break;
     case ValType::IterDomain:
-      return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
+      return ptr(mutator)->mutate(val->as<IterDomain>());
     case ValType::TensorDomain:
-      return ptr(mutator)->mutate(static_cast<TensorDomain*>(val));
+      return ptr(mutator)->mutate(val->as<TensorDomain>());
     case ValType::TensorView:
-      return ptr(mutator)->mutate(static_cast<TensorView*>(val));
+      return ptr(mutator)->mutate(val->as<TensorView>());
     case ValType::TensorIndex:
-      return ptr(mutator)->mutate(static_cast<TensorIndex*>(val));
+      return ptr(mutator)->mutate(val->as<TensorIndex>());
     case ValType::NamedScalar:
-      return ptr(mutator)->mutate(static_cast<NamedScalar*>(val));
+      return ptr(mutator)->mutate(val->as<NamedScalar>());
     default:
       break;
   }
@@ -267,25 +267,25 @@
 Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
   switch (*(expr->getExprType())) {
     case ExprType::Split:
-      return ptr(mutator)->mutate(static_cast<Split*>(expr));
+      return ptr(mutator)->mutate(expr->as<Split>());
     case ExprType::Merge:
-      return ptr(mutator)->mutate(static_cast<Merge*>(expr));
+      return ptr(mutator)->mutate(expr->as<Merge>());
     case ExprType::UnaryOp:
-      return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<UnaryOp>());
     case ExprType::BinaryOp:
-      return ptr(mutator)->mutate(static_cast<BinaryOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<BinaryOp>());
     case ExprType::TernaryOp:
-      return ptr(mutator)->mutate(static_cast<TernaryOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<TernaryOp>());
     case ExprType::ReductionOp:
-      return ptr(mutator)->mutate(static_cast<ReductionOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<ReductionOp>());
     case ExprType::BroadcastOp:
-      return ptr(mutator)->mutate(static_cast<BroadcastOp*>(expr));
+      return ptr(mutator)->mutate(expr->as<BroadcastOp>());
     case ExprType::ForLoop:
-      return ptr(mutator)->mutate(static_cast<ForLoop*>(expr));
+      return ptr(mutator)->mutate(expr->as<ForLoop>());
     case ExprType::IfThenElse:
-      return ptr(mutator)->mutate(static_cast<IfThenElse*>(expr));
+      return ptr(mutator)->mutate(expr->as<IfThenElse>());
     case ExprType::Allocate:
-      return ptr(mutator)->mutate(static_cast<Allocate*>(expr));
+      return ptr(mutator)->mutate(expr->as<Allocate>());
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
   }
@@ -294,10 +294,10 @@
 template <typename T>
 Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) {
   if (stmt->isVal()) {
-    return ptr(mutator)->mutate(static_cast<Val*>(stmt));
+    return ptr(mutator)->mutate(stmt->as<Val>());
   }
   if (stmt->isExpr()) {
-    return ptr(mutator)->mutate(static_cast<Expr*>(stmt));
+    return ptr(mutator)->mutate(stmt->as<Expr>());
   }
   TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
 }
@@ -321,27 +321,19 @@
 template void Expr::dispatch(OptInDispatch, Expr*);
 template void Expr::dispatch(OptInDispatch*, Expr*);
 
-template void Statement::constDispatch(
-    OptOutConstDispatch,
-    const Statement* const);
-template void Statement::constDispatch(
-    OptOutConstDispatch*,
-    const Statement* const);
-template void Val::constDispatch(OptOutConstDispatch, const Val* const);
-template void Val::constDispatch(OptOutConstDispatch*, const Val* const);
-template void Expr::constDispatch(OptOutConstDispatch, const Expr* const);
-template void Expr::constDispatch(OptOutConstDispatch*, const Expr* const);
+template void Statement::constDispatch(OptOutConstDispatch, const Statement*);
+template void Statement::constDispatch(OptOutConstDispatch*, const Statement*);
+template void Val::constDispatch(OptOutConstDispatch, const Val*);
+template void Val::constDispatch(OptOutConstDispatch*, const Val*);
+template void Expr::constDispatch(OptOutConstDispatch, const Expr*);
+template void Expr::constDispatch(OptOutConstDispatch*, const Expr*);
 
-template void Statement::constDispatch(
-    OptInConstDispatch,
-    const Statement* const);
-template void Statement::constDispatch(
-    OptInConstDispatch*,
-    const Statement* const);
-template void Val::constDispatch(OptInConstDispatch, const Val* const);
-template void Val::constDispatch(OptInConstDispatch*, const Val* const);
-template void Expr::constDispatch(OptInConstDispatch, const Expr* const);
-template void Expr::constDispatch(OptInConstDispatch*, const Expr* const);
+template void Statement::constDispatch(OptInConstDispatch, const Statement*);
+template void Statement::constDispatch(OptInConstDispatch*, const Statement*);
+template void Val::constDispatch(OptInConstDispatch, const Val*);
+template void Val::constDispatch(OptInConstDispatch*, const Val*);
+template void Expr::constDispatch(OptInConstDispatch, const Expr*);
+template void Expr::constDispatch(OptInConstDispatch*, const Expr*);
 
 template Statement* Statement::mutatorDispatch(OptOutMutator, Statement*);
 template Statement* Statement::mutatorDispatch(OptOutMutator*, Statement*);
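
The reformatted instantiations above also drop the second const from parameters such as `const Statement* const`. That trailing, top-level const only constrains the parameter variable inside a definition; it is not part of the function type, so the declarations are equivalent and no instantiation changes meaning. A minimal illustration (the handle declarations here are hypothetical, not the fuser API):

    // Both lines declare the same function: top-level const on the pointer
    // itself is ignored in the signature.
    void handle(const int* p);
    void handle(const int* const p);  // redeclaration, not a new overload

    // Low-level const is different -- this one is a distinct overload because
    // it changes what the pointer points to.
    void handle(int* p);
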
@@ -360,9 +352,11 @@
 void OptOutDispatch::handle(Statement* s) {
   Statement::dispatch(this, s);
 }
+
 void OptOutDispatch::handle(Expr* e) {
   Expr::dispatch(this, e);
 }
+
 void OptOutDispatch::handle(Val* v) {
   Val::dispatch(this, v);
 }
@@ -370,30 +364,36 @@
 void OptInDispatch::handle(Statement* s) {
   Statement::dispatch(this, s);
 }
+
 void OptInDispatch::handle(Expr* e) {
   Expr::dispatch(this, e);
 }
+
 void OptInDispatch::handle(Val* v) {
   Val::dispatch(this, v);
 }
 
-void OptOutConstDispatch::handle(const Statement* const s) {
+void OptOutConstDispatch::handle(const Statement* s) {
   Statement::constDispatch(this, s);
 }
-void OptOutConstDispatch::handle(const Expr* const e) {
+
+void OptOutConstDispatch::handle(const Expr* e) {
   Expr::constDispatch(this, e);
 }
-void OptOutConstDispatch::handle(const Val* const v) {
+
+void OptOutConstDispatch::handle(const Val* v) {
   Val::constDispatch(this, v);
 }
 
-void OptInConstDispatch::handle(const Statement* const s) {
+void OptInConstDispatch::handle(const Statement* s) {
   Statement::constDispatch(this, s);
 }
-void OptInConstDispatch::handle(const Expr* const e) {
+
+void OptInConstDispatch::handle(const Expr* e) {
   Expr::constDispatch(this, e);
 }
-void OptInConstDispatch::handle(const Val* const v) {
+
+void OptInConstDispatch::handle(const Val* v) {
   Val::constDispatch(this, v);
 }
 
@@ -415,9 +415,11 @@
 Statement* OptOutMutator::mutate(Statement* s) {
   return Statement::mutatorDispatch(this, s);
 }
+
 Statement* OptOutMutator::mutate(Expr* e) {
   return Expr::mutatorDispatch(this, e);
 }
+
 Statement* OptOutMutator::mutate(Val* v) {
   // If value is already mutated, return the mutation
   if (mutations.find(v) != mutations.end())
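
mutate(Val*) checks the mutations map before doing any work, so a value is rewritten at most once and every later use sees the same replacement. The function continues beyond this hunk; the sketch below is only a generic illustration of that memoize-the-replacement pattern, with made-up Node and doRewrite names, not the actual continuation of the nvFuser code.

    #include <unordered_map>

    struct Node {};  // placeholder for an IR value

    class MemoizedRewriter {
     public:
      virtual ~MemoizedRewriter() = default;

      Node* rewrite(Node* n) {
        auto it = memo_.find(n);        // already rewritten? reuse that result
        if (it != memo_.end())
          return it->second;
        Node* replacement = doRewrite(n);
        memo_.emplace(n, replacement);  // remember it for later users
        return replacement;
      }

     protected:
      // Default is "no change", mirroring the OptOut flavor of the dispatch.
      virtual Node* doRewrite(Node* n) {
        return n;
      }

     private:
      std::unordered_map<Node*, Node*> memo_;
    };
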
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h
index 5a3aa89..50f4451 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.h
+++ b/torch/csrc/jit/codegen/cuda/dispatch.h
@@ -48,41 +48,42 @@
 namespace jit {
 namespace fuser {
 
-struct Fusion;
+class Fusion;
 
 // Hierarchical dispatch functions for handle
-struct Statement;
-struct Expr;
-struct Val;
+class Statement;
+class Expr;
+class Val;
 
 // Vals
-struct IterDomain;
-struct TensorDomain;
-struct TensorView;
-struct TensorIndex;
-struct Bool;
-struct Float;
-struct Half;
-struct Int;
-struct NamedScalar;
+class IterDomain;
+class TensorDomain;
+class TensorView;
+class TensorIndex;
+class Bool;
+class Float;
+class Half;
+class Int;
+class NamedScalar;
 
 // Exprs
-struct Split;
-struct Merge;
-struct UnaryOp;
-struct BinaryOp;
-struct TernaryOp;
-struct ReductionOp;
-struct BroadcastOp;
-struct ForLoop;
-struct IfThenElse;
-struct Allocate;
+class Split;
+class Merge;
+class UnaryOp;
+class BinaryOp;
+class TernaryOp;
+class ReductionOp;
+class BroadcastOp;
+class ForLoop;
+class IfThenElse;
+class Allocate;
 
 /*
  * By default, all IR nodes are handled in this dispatch, and will call an empty
  * function on all nodes.
  */
-struct TORCH_CUDA_API OptOutConstDispatch {
+class TORCH_CUDA_API OptOutConstDispatch {
+ public:
   virtual ~OptOutConstDispatch() = default;
   OptOutConstDispatch() = default;
 
@@ -93,35 +94,36 @@
   OptOutConstDispatch& operator=(OptOutConstDispatch&& other) = default;
 
   // Hierarchical dispatch functions for handle
-  virtual void handle(const Statement* const);
-  virtual void handle(const Expr* const);
-  virtual void handle(const Val* const);
+  virtual void handle(const Statement*);
+  virtual void handle(const Expr*);
+  virtual void handle(const Val*);
 
   // Vals
-  virtual void handle(const IterDomain* const) {}
-  virtual void handle(const TensorDomain* const) {}
-  virtual void handle(const TensorView* const) {}
-  virtual void handle(const TensorIndex* const) {}
-  virtual void handle(const Bool* const) {}
-  virtual void handle(const Float* const) {}
-  virtual void handle(const Half* const) {}
-  virtual void handle(const Int* const) {}
-  virtual void handle(const NamedScalar* const) {}
+  virtual void handle(const IterDomain*) {}
+  virtual void handle(const TensorDomain*) {}
+  virtual void handle(const TensorView*) {}
+  virtual void handle(const TensorIndex*) {}
+  virtual void handle(const Bool*) {}
+  virtual void handle(const Float*) {}
+  virtual void handle(const Half*) {}
+  virtual void handle(const Int*) {}
+  virtual void handle(const NamedScalar*) {}
 
   // Exprs
-  virtual void handle(const Split* const) {}
-  virtual void handle(const Merge* const) {}
-  virtual void handle(const UnaryOp* const) {}
-  virtual void handle(const BinaryOp* const) {}
-  virtual void handle(const TernaryOp* const) {}
-  virtual void handle(const ReductionOp* const) {}
-  virtual void handle(const BroadcastOp* const) {}
-  virtual void handle(const ForLoop* const) {}
-  virtual void handle(const IfThenElse* const) {}
-  virtual void handle(const Allocate* const) {}
+  virtual void handle(const Split*) {}
+  virtual void handle(const Merge*) {}
+  virtual void handle(const UnaryOp*) {}
+  virtual void handle(const BinaryOp*) {}
+  virtual void handle(const TernaryOp*) {}
+  virtual void handle(const ReductionOp*) {}
+  virtual void handle(const BroadcastOp*) {}
+  virtual void handle(const ForLoop*) {}
+  virtual void handle(const IfThenElse*) {}
+  virtual void handle(const Allocate*) {}
 };
 
-struct TORCH_CUDA_API OptOutDispatch {
+class TORCH_CUDA_API OptOutDispatch {
+ public:
   virtual ~OptOutDispatch() = default;
   OptOutDispatch() = default;
 
@@ -160,7 +162,8 @@
   virtual void handle(Allocate*) {}
 };
 
-struct TORCH_CUDA_API OptInConstDispatch {
+class TORCH_CUDA_API OptInConstDispatch {
+ public:
   virtual ~OptInConstDispatch() = default;
   OptInConstDispatch() = default;
 
@@ -171,73 +174,74 @@
   OptInConstDispatch& operator=(OptInConstDispatch&& other) = default;
 
   // Hierarchical dispatch functions for handle
-  virtual void handle(const Statement* const);
-  virtual void handle(const Expr* const);
-  virtual void handle(const Val* const);
+  virtual void handle(const Statement*);
+  virtual void handle(const Expr*);
+  virtual void handle(const Val*);
 
   // Vals
-  virtual void handle(const IterDomain* const) {
+  virtual void handle(const IterDomain*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IterDomain.");
   }
-  virtual void handle(const TensorDomain* const) {
+  virtual void handle(const TensorDomain*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorDomain.");
   }
-  virtual void handle(const TensorView* const) {
+  virtual void handle(const TensorView*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorView.");
   }
-  virtual void handle(const TensorIndex* const) {
+  virtual void handle(const TensorIndex*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorIndex.");
   }
-  virtual void handle(const Bool* const) {
+  virtual void handle(const Bool*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool.");
   }
-  virtual void handle(const Float* const) {
+  virtual void handle(const Float*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float.");
   }
-  virtual void handle(const Half* const) {
+  virtual void handle(const Half*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half.");
   }
-  virtual void handle(const Int* const) {
+  virtual void handle(const Int*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int.");
   }
-  virtual void handle(const NamedScalar* const) {
+  virtual void handle(const NamedScalar*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for NamedScalar.");
   }
 
   // Exprs
-  virtual void handle(const Split* const) {
+  virtual void handle(const Split*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Split.");
   }
-  virtual void handle(const Merge* const) {
+  virtual void handle(const Merge*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge.");
   }
-  virtual void handle(const UnaryOp* const) {
+  virtual void handle(const UnaryOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp.");
   }
-  virtual void handle(const BinaryOp* const) {
+  virtual void handle(const BinaryOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp.");
   }
-  virtual void handle(const TernaryOp* const) {
+  virtual void handle(const TernaryOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp.");
   }
-  virtual void handle(const ReductionOp* const) {
+  virtual void handle(const ReductionOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp.");
   }
-  virtual void handle(const BroadcastOp* const) {
+  virtual void handle(const BroadcastOp*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp.");
   }
-  virtual void handle(const ForLoop* const) {
+  virtual void handle(const ForLoop*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ForLoop.");
   }
-  virtual void handle(const Allocate* const) {
+  virtual void handle(const Allocate*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Allocate.");
   }
-  virtual void handle(const IfThenElse* const) {
+  virtual void handle(const IfThenElse*) {
     TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IfThenElse.");
   }
 };
 
-struct TORCH_CUDA_API OptInDispatch {
+class TORCH_CUDA_API OptInDispatch {
+ public:
   virtual ~OptInDispatch() = default;
   OptInDispatch() = default;
 
@@ -314,7 +318,8 @@
   }
 };
 
-struct TORCH_CUDA_API OptOutMutator {
+class TORCH_CUDA_API OptOutMutator {
+ public:
   virtual ~OptOutMutator() = default;
   OptOutMutator() = default;
 
@@ -331,13 +336,11 @@
   virtual Statement* mutate(Expr* e);
   virtual Statement* mutate(Val* v);
 
-  /*
-   * We always want to dispatch through a Val, so we can capture and dispatch
-   * correctly members of nodes like Split->TensorDomain If we don't call the
-   * below function or manually cast to use mutate(Val* v) we can't intercept
-   * and mutate by capturing mutate(Val* v), which is what we do when we want to
-   * replace all instances of a value.
-   */
+  // We always want to dispatch through a Val, so we can capture and dispatch
+  // correctly members of nodes like Split->TensorDomain. If we don't call the
+  // below function or manually cast to use mutate(Val* v) we can't intercept
+  // and mutate by capturing mutate(Val* v), which is what we do when we want to
+  // replace all instances of a value.
   Statement* mutateAsVal(Val* v) {
     return mutate(v);
   }
@@ -352,7 +355,8 @@
 
   std::unordered_map<Val*, Val*> mutations;
 
-  //****Functions below defined in mutator.cpp*****///
+  //****Functions below defined in mutator.cpp*****
+
   // Vals
   virtual Statement* mutate(IterDomain*);
   virtual Statement* mutate(TensorDomain*);
@@ -377,7 +381,8 @@
   virtual Statement* mutate(Allocate*);
 };
 
-struct TORCH_CUDA_API OptInMutator {
+class TORCH_CUDA_API OptInMutator {
+ public:
   virtual ~OptInMutator() = default;
   OptInMutator() = default;
 
diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
index 4568712..2a333fc 100644
--- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
+++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp
@@ -38,40 +38,46 @@
 }
 
 c10::optional<Int::ScalarType> ExpressionEvaluator::evaluate(
-    const Statement* expr,
+    Val* val,
     const EvaluationContext* context) {
   TORCH_CHECK(context != nullptr);
   ExpressionEvaluator evaluator(context);
-  evaluator.OptInConstDispatch::handle(expr);
-  return evaluator.result_;
+  evaluator.traverseFrom(context->fusion(), {val}, false);
+  return evaluator.value(val);
 }
 
-void ExpressionEvaluator::handle(const Int* i) {
+c10::optional<Int::ScalarType> ExpressionEvaluator::value(
+    const Statement* stmt) const {
+  const auto it = values_.find(stmt);
+  return (it != values_.end()) ? c10::optional<Int::ScalarType>(it->second)
+                               : c10::nullopt;
+}
+
+void ExpressionEvaluator::handle(Int* i) {
   if (i->value().has_value()) {
-    result_ = i->value();
+    values_[i] = *i->value();
   } else if (const auto* def = context_->fusion()->origin(i)) {
-    result_ = evaluate(def, context_);
+    const auto& def_result = value(def);
+    if (def_result.has_value()) {
+      values_[i] = *def_result;
+    }
   } else {
     const auto& bound_value = context_->concreteValue(i);
     if (bound_value.has_value()) {
-      result_ = bound_value;
+      values_[i] = *bound_value;
     }
   }
 }
 
-void ExpressionEvaluator::handle(const NamedScalar* i) {
-  // nothing to do, leave the result "unknown"
-}
-
-void ExpressionEvaluator::handle(const UnaryOp* uop) {
-  const auto in = evaluate(uop->in(), context_);
+void ExpressionEvaluator::handle(UnaryOp* uop) {
+  const auto in = value(uop->in());
   if (in.has_value()) {
     switch (uop->getUnaryOpType()) {
       case UnaryOpType::Neg:
-        result_ = -*in;
+        values_[uop] = -*in;
         break;
       case UnaryOpType::Cast:
-        result_ = *in;
+        values_[uop] = *in;
         break;
       default:
         TORCH_CHECK(!"Unexpected operator type");
@@ -79,35 +85,34 @@
   }
 }
 
-void ExpressionEvaluator::handle(const BinaryOp* bop) {
-  TORCH_CHECK(bop->out()->isAnInt()); // not really needed
-  const auto lhs = evaluate(bop->lhs(), context_);
-  const auto rhs = evaluate(bop->rhs(), context_);
+void ExpressionEvaluator::handle(BinaryOp* bop) {
+  const auto lhs = value(bop->lhs());
+  const auto rhs = value(bop->rhs());
   if (lhs.has_value() && rhs.has_value()) {
     switch (bop->getBinaryOpType()) {
       case BinaryOpType::Add:
-        result_ = *lhs + *rhs;
+        values_[bop] = *lhs + *rhs;
         break;
       case BinaryOpType::Sub:
-        result_ = *lhs - *rhs;
+        values_[bop] = *lhs - *rhs;
         break;
       case BinaryOpType::Mul:
-        result_ = *lhs * *rhs;
+        values_[bop] = *lhs * *rhs;
         break;
       case BinaryOpType::Div:
         TORCH_CHECK(*rhs != 0);
-        result_ = *lhs / *rhs;
+        values_[bop] = *lhs / *rhs;
         break;
       case BinaryOpType::Mod:
         TORCH_CHECK(*rhs != 0);
-        result_ = *lhs % *rhs;
+        values_[bop] = *lhs % *rhs;
         break;
       case BinaryOpType::CeilDiv:
         TORCH_CHECK(*rhs != 0);
-        result_ = (*lhs + *rhs - 1) / *rhs;
+        values_[bop] = (*lhs + *rhs - 1) / *rhs;
         break;
       case BinaryOpType::And:
-        result_ = Int::ScalarType(*lhs && *rhs);
+        values_[bop] = Int::ScalarType(*lhs && *rhs);
         break;
       default:
         TORCH_CHECK(!"Unexpected operator type");
diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h
index 5719128..bc29aac 100644
--- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h
+++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h
@@ -2,8 +2,8 @@
 #pragma once
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/codegen/cuda/dispatch.h>
 #include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
 
 #include <c10/util/Optional.h>
 
@@ -20,7 +20,7 @@
 //
 class TORCH_CUDA_API EvaluationContext {
  public:
-  explicit EvaluationContext(const Fusion* fusion) : fusion_(fusion) {}
+  explicit EvaluationContext(Fusion* fusion) : fusion_(fusion) {}
 
   // Set the concrete value for a Int*
   void bind(const Val* value, Int::ScalarType concrete_value);
@@ -28,7 +28,7 @@
   // Retrieves the concrete value, or nullopt if not set
   c10::optional<Int::ScalarType> concreteValue(const Val* value) const;
 
-  const Fusion* fusion() const {
+  Fusion* fusion() const {
     return fusion_;
   }
 
@@ -37,18 +37,18 @@
 
  private:
   std::unordered_map<const Val*, Int::ScalarType> bindings_;
-  const Fusion* fusion_ = nullptr;
+  Fusion* fusion_ = nullptr;
 };
 
 // Evaluates expressions in a Fusion IR, using the passed in
 // context (EvaluationContext) to query for concrete_values. The
 // evaluation context may override concrete values in the IR as well.
-class TORCH_CUDA_API ExpressionEvaluator : private OptInConstDispatch {
+class TORCH_CUDA_API ExpressionEvaluator : private IterVisitor {
  public:
   // Returns the result of the specified expression, or nullopt if
   // the result cannot be evaluated
   static c10::optional<Int::ScalarType> evaluate(
-      const Statement* expr,
+      Val* val,
       const EvaluationContext* context);
 
  private:
@@ -57,15 +57,17 @@
 
   ~ExpressionEvaluator() override = default;
 
-  void handle(const Int*) override;
-  void handle(const NamedScalar*) override;
+  c10::optional<Int::ScalarType> value(const Statement* stmt) const;
 
-  void handle(const UnaryOp*) override;
-  void handle(const BinaryOp*) override;
+  using IterVisitor::handle;
+
+  void handle(Int*) override;
+  void handle(UnaryOp*) override;
+  void handle(BinaryOp*) override;
 
  private:
   const EvaluationContext* context_ = nullptr;
-  c10::optional<Int::ScalarType> result_;
+  std::unordered_map<const Statement*, Int::ScalarType> values_;
 };
 
 } // namespace fuser
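
For reference, the intended use of the reworked interface: build an EvaluationContext over a Fusion, bind() concrete values for the symbolic scalars, then ask evaluate() for any value in the graph. The sketch below follows the declarations above; the arith helper mul(), the default-constructed symbolic Int(), and the include set mirror patterns from the surrounding fuser code but should be read as approximate, not authoritative.

    #include <torch/csrc/jit/codegen/cuda/arith.h>
    #include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
    #include <torch/csrc/jit/codegen/cuda/fusion.h>

    using namespace torch::jit::fuser;

    void evaluateExample() {
      Fusion fusion;
      FusionGuard fg(&fusion);

      Int* a = new Int();  // symbolic scalars, owned (and freed) by the fusion
      Int* b = new Int();
      Val* c = mul(a, b);  // origin of c is a Mul BinaryOp

      EvaluationContext context(&fusion);
      context.bind(a, 6);
      context.bind(b, 7);

      // Traverses the definition of c in topological order and folds it.
      const auto result = ExpressionEvaluator::evaluate(c, &context);
      TORCH_CHECK(result.has_value() && *result == 42);
    }
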
diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index e4faf31..141c12d 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -1,7 +1,10 @@
+
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
 #include <torch/csrc/jit/codegen/cuda/ir_printer.h>
 #include <torch/csrc/jit/codegen/cuda/kernel.h>
+#include <torch/csrc/jit/codegen/cuda/lower2device.h>
 
 namespace torch {
 namespace jit {
@@ -34,9 +37,10 @@
 std::vector<Expr*> ExprSort::getExprs(
     Fusion* fusion,
     bool from_outputs_only,
-    bool breadth_first) {
+    bool breadth_first,
+    bool respect_compute_at) {
   ExprSort es;
-  es.traverse(fusion, from_outputs_only, breadth_first);
+  es.traverse(fusion, from_outputs_only, breadth_first, respect_compute_at);
   return es.exprs;
 }
 
@@ -56,22 +60,136 @@
   return io.inputs;
 }
 
-Fusion::~Fusion() {
-  {
-    auto it = val_set_.begin();
-    while (it != val_set_.end()) {
-      auto del = it;
-      it = ++it;
-      delete (*del);
+void swap(Fusion& a, Fusion& b) noexcept {
+  using std::swap;
+
+  // Swap the content
+  swap(a.val_set_, b.val_set_);
+  swap(a.expr_set_, b.expr_set_);
+  swap(a.val_deque_, b.val_deque_);
+
+  swap(a.val_type_name_map_, b.val_type_name_map_);
+  swap(a.val_name_counter_, b.val_name_counter_);
+  swap(a.expr_name_counter_, b.expr_name_counter_);
+
+  swap(a.origin_, b.origin_);
+  swap(a.uses_, b.uses_);
+  swap(a.values_map_, b.values_map_);
+
+  swap(a.inputs_, b.inputs_);
+  swap(a.outputs_, b.outputs_);
+
+  // Fixup the Statement::fusion_ links for a
+  for (auto val : a.val_set_) {
+    val->fusion_ = &a;
+  }
+  for (auto expr : a.expr_set_) {
+    expr->fusion_ = &a;
+  }
+
+  // Fixup the Statement::fusion_ links for b
+  for (auto val : b.val_set_) {
+    val->fusion_ = &b;
+  }
+  for (auto expr : b.expr_set_) {
+    expr->fusion_ = &b;
+  }
+}
+
+Fusion::Fusion(const Fusion& other) {
+  IrCloner ir_cloner(this);
+
+  for (auto val : other.val_set_) {
+    val_set_.insert(ir_cloner.clone(val));
+  }
+
+  for (auto expr : other.expr_set_) {
+    expr_set_.insert(ir_cloner.clone(expr));
+  }
+
+  for (auto val : other.val_deque_) {
+    val_deque_.push_back(ir_cloner.clone(val));
+  }
+
+  val_type_name_map_ = other.val_type_name_map_;
+  val_name_counter_ = other.val_name_counter_;
+  expr_name_counter_ = other.expr_name_counter_;
+
+  for (const auto& kv : other.origin_) {
+    auto val = ir_cloner.clone(kv.first);
+    auto expr = ir_cloner.clone(kv.second);
+    origin_.insert({val, expr});
+  }
+
+  for (const auto& kv : other.uses_) {
+    auto val = ir_cloner.clone(kv.first);
+    std::unordered_set<Expr*> val_uses;
+    for (auto expr : kv.second) {
+      val_uses.insert(ir_cloner.clone(expr));
     }
+    uses_.insert({val, std::move(val_uses)});
   }
-  auto it = expr_set_.begin();
-  while (it != expr_set_.end()) {
-    auto del = it;
-    it = ++it;
-    delete (*del);
+
+  for (const auto& kv : other.values_map_) {
+    auto from_val = ir_cloner.clone(kv.first);
+    auto to_val = ir_cloner.clone(kv.second);
+    values_map_.insert({from_val, to_val});
   }
-};
+
+  inputs_ = ir_cloner.clone(other.inputs_);
+  outputs_ = ir_cloner.clone(other.outputs_);
+}
+
+Fusion::Fusion(Fusion&& other) noexcept {
+  swap(*this, other);
+}
+
+Fusion& Fusion::operator=(const Fusion& other) {
+  Fusion copy(other);
+  clear();
+  swap(*this, copy);
+  return *this;
+}
+
+Fusion& Fusion::operator=(Fusion&& other) noexcept {
+  clear();
+  swap(*this, other);
+  return *this;
+}
+
+Fusion::~Fusion() {
+  clear();
+}
+
+void Fusion::clear() noexcept {
+  // Free the owned values
+  for (auto ptr : val_set_) {
+    delete ptr;
+  }
+
+  // Free the owned expressions
+  for (auto ptr : expr_set_) {
+    delete ptr;
+  }
+
+  val_set_.clear();
+  val_deque_.clear();
+  expr_set_.clear();
+
+  for (auto& kv : val_type_name_map_) {
+    kv.second = 0;
+  }
+
+  val_name_counter_ = 0;
+  expr_name_counter_ = 0;
+
+  origin_.clear();
+  uses_.clear();
+  values_map_.clear();
+
+  inputs_.clear();
+  outputs_.clear();
+}
 
 void Fusion::removeExpr(Expr* expr) {
   assertInFusion(expr, "Cannot remove expr ");
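
The new special members are the copy-and-swap idiom: copy-construct through IrCloner, implement both assignment operators as clear() followed by swap(), and have swap() repair the owner back-pointers. Stripped of the IR specifics, the same shape looks like the generic owning container below (not fuser code; Registry and its int payload are placeholders):

    #include <utility>
    #include <vector>

    class Registry {
     public:
      Registry() = default;

      Registry(const Registry& other) {
        nodes_.reserve(other.nodes_.size());
        for (const int* p : other.nodes_)
          nodes_.push_back(new int(*p));  // deep copy of the owned objects
      }

      Registry(Registry&& other) noexcept {
        swap(*this, other);  // leaves `other` empty but valid
      }

      Registry& operator=(const Registry& other) {
        Registry copy(other);  // may throw; *this is untouched so far
        clear();
        swap(*this, copy);
        return *this;
      }

      Registry& operator=(Registry&& other) noexcept {
        clear();
        swap(*this, other);
        return *this;
      }

      ~Registry() {
        clear();
      }

      void clear() noexcept {
        for (int* p : nodes_)
          delete p;
        nodes_.clear();
      }

      friend void swap(Registry& a, Registry& b) noexcept {
        using std::swap;
        swap(a.nodes_, b.nodes_);
      }

     private:
      std::vector<int*> nodes_;  // owned, released in clear()
    };
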
@@ -138,7 +256,14 @@
           " has a reduction axis, but this does nothing in the fusion.");
   }
 
-  IRInputOutput::addInput(input);
+  TORCH_CHECK(
+      input->getOrigin() == nullptr,
+      input,
+      " cannot be registered as an input as it is used as an output of an expression (",
+      input->getOrigin(),
+      ").");
+
+  inputs_.push_back(input);
 }
 
 void Fusion::addOutput(Val* const output) {
@@ -153,7 +278,7 @@
           output,
           " cannot be registered as an output as it has a broadcast axis.");
   }
-  IRInputOutput::addOutput(output);
+  outputs_.push_back(output);
 }
 
 bool Fusion::inFusion(const Statement* stmt) const {
@@ -177,10 +302,14 @@
   TORCH_CHECK(false, msg, " it was not found in the active fusion.");
 }
 
-std::vector<Expr*> Fusion::exprs(bool from_outputs_only, bool breadth_first) {
+std::vector<Expr*> Fusion::exprs(
+    bool from_outputs_only,
+    bool breadth_first,
+    bool respect_compute_at) {
   if (breadth_first)
     TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-  return ExprSort::getExprs(this, from_outputs_only, breadth_first);
+  return ExprSort::getExprs(
+      this, from_outputs_only, breadth_first, respect_compute_at);
 }
 
 std::unordered_set<Val*> Fusion::inputsOf(Val* val) {
@@ -190,20 +319,16 @@
 void Fusion::validateInputs() {
   std::unordered_set<Val*> all_inputs;
   for (Val* out : outputs()) {
-    auto outs_inputs = inputsOf(out);
-    std::set_union(
-        all_inputs.begin(),
-        all_inputs.end(),
-        outs_inputs.begin(),
-        outs_inputs.end(),
-        std::inserter(all_inputs, all_inputs.begin()));
+    for (Val* input : inputsOf(out)) {
+      all_inputs.insert(input);
+    }
   }
-  for (Val* inp : all_inputs) {
-    if (!inp->isConstScalar())
+  for (Val* input : all_inputs) {
+    if (!input->isConstScalar())
       TORCH_CHECK(
-          hasInput(inp),
+          hasInput(input),
           "Could not figure out how ",
-          inp,
+          input,
           " is generated, however it was not specified as an input.");
   }
 }
@@ -218,10 +343,30 @@
   std::cout << "}\n";
 }
 
+void Fusion::printValuesMap() {
+  IRPrinter ir_printer(std::cout);
+  ir_printer.follow_val_map = false;
+  std::cout << "\nValues map\n";
+  std::cout << "--------------------\n";
+  for (const auto& kv : values_map_) {
+    ir_printer.handle(kv.first);
+    std::cout << " -> ";
+    ir_printer.handle(kv.second);
+    std::cout << "\n";
+  }
+  std::cout << "--------------------\n\n";
+}
+
+void Fusion::printKernel() {
+  FusionGuard fg(this);
+  GPULower lower(this);
+  lower.printKernel(std::cout);
+}
+
 void Fusion::printMath() {
   FusionGuard fg(this);
-  IRMathPrinter op_exprs(std::cout);
-  op_exprs.handle(this);
+  for (auto expr : exprs(true))
+    std::cout << expr;
 }
 
 void Fusion::printTransforms() {
@@ -338,11 +483,28 @@
   return it->second;
 }
 
+bool Fusion::hasInput(const Val* val) const {
+  return std::find(inputs_.begin(), inputs_.end(), val) != inputs_.end();
+}
+
+bool Fusion::hasOutput(const Val* val) const {
+  return std::find(outputs_.begin(), outputs_.end(), val) != outputs_.end();
+}
+
+void Fusion::replaceInput(Val* replace, Val* with) {
+  std::replace(inputs_.begin(), inputs_.end(), replace, with);
+}
+
+void Fusion::replaceOutput(Val* replace, Val* with) {
+  std::replace(outputs_.begin(), outputs_.end(), replace, with);
+}
+
 StmtNameType Fusion::getValName(ValType vtype) {
-  if (val_type_name_map.find(vtype) != val_type_name_map.end())
-    return val_type_name_map[vtype]++;
+  if (val_type_name_map_.find(vtype) != val_type_name_map_.end())
+    return val_type_name_map_[vtype]++;
   return val_name_counter_++;
 }
+
 StmtNameType Fusion::getExprName() {
   return expr_name_counter_++;
 }
@@ -368,6 +530,26 @@
   return false;
 }
 
+bool Fusion::hasBlockReduction() {
+  for (auto expr : exprs(true))
+    for (auto out : expr->outputs())
+      if (out->getValType() == ValType::TensorView)
+        if (static_cast<TensorView*>(out)->hasBlockReduction())
+          return true;
+
+  return false;
+}
+
+bool Fusion::hasGridReduction() {
+  for (auto expr : exprs(true))
+    for (auto out : expr->outputs())
+      if (out->getValType() == ValType::TensorView)
+        if (static_cast<TensorView*>(out)->hasGridReduction())
+          return true;
+
+  return false;
+}
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h
index e0f240f..684dc91 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.h
+++ b/torch/csrc/jit/codegen/cuda/fusion.h
@@ -48,16 +48,16 @@
  * us mechanisms for dependency analysis and DCE including safety checks.
  */
 
-struct Fusion;
-struct TensorView;
+class Fusion;
+class TensorView;
 
 namespace cuda {
-struct CudaKernel;
+class CudaKernel;
 }
 
 // Fusion Guard is our "context manager". It holds the active fusion and allows
 // it to be accessed anywhere through FusionGuard::getCurFusion().
-struct TORCH_CUDA_API FusionGuard {
+class TORCH_CUDA_API FusionGuard {
  public:
   Fusion* prev_fusion;
 
@@ -72,7 +72,7 @@
 
 // Expr sort will take a fusion and return a topologically sorted list of
 // expressions.
-struct ExprSort : public IterVisitor {
+class ExprSort : public IterVisitor {
  private:
   std::vector<Expr*> exprs;
 
@@ -82,10 +82,11 @@
   static std::vector<Expr*> getExprs(
       Fusion* fusion,
       bool from_outputs_only,
-      bool breadth_first);
+      bool breadth_first,
+      bool respect_compute_at);
 };
 
-struct InputsOf : public IterVisitor {
+class InputsOf : public IterVisitor {
  private:
   std::unordered_set<Val*> inputs;
 
@@ -101,21 +102,25 @@
  * duplicating all associated values and exprs. Fusion is considered to be SSA,
  * though this could also change in the future if there is a good reason to do
  * so.
+ *
+ * The Fusion owns the whole IR graph (Vals and Exprs)
  */
+class TORCH_CUDA_API Fusion final {
+ public:
+  Fusion() = default;
 
-struct TORCH_CUDA_API Fusion : public IRInputOutput {
-  Fusion() {}
+  Fusion(const Fusion& other);
+  Fusion(Fusion&& other) noexcept;
 
-  // Not copyable
-  Fusion(const Fusion& other) = delete;
-  Fusion& operator=(const Fusion& other) = delete;
+  Fusion& operator=(const Fusion& other);
+  Fusion& operator=(Fusion&& other) noexcept;
 
-  Fusion(Fusion&& other) = delete;
-  Fusion& operator=(Fusion&& other) = delete;
-
-  // When destroyed clean up all IR associated with this fusion
   ~Fusion();
 
+  friend void swap(Fusion& a, Fusion& b) noexcept;
+
+  void clear() noexcept;
+
   // Break dependency chains associated with Expr, remove references to expr
   // delete expr.
   void removeExpr(Expr* expr);
@@ -153,7 +158,8 @@
    */
   std::vector<Expr*> exprs(
       bool from_outputs_only = false,
-      bool breadth_first = false);
+      bool breadth_first = false,
+      bool respect_compute_at = false);
 
   std::unordered_set<Val*> inputsOf(Val* val);
 
@@ -163,11 +169,16 @@
   // Print this fusion to cout.
   void print();
 
+  // Print value mapping
+  void printValuesMap();
+
   // Print Arith exprs used in outputs
   void printMath();
+
   // Print transformations used in fusion (can be very verbose)
   void printTransforms();
-
+  // Lower the fusion and print a kernel
+  void printKernel();
   // Register the Val with this fusion
   StmtNameType registerVal(Val* val);
 
@@ -203,22 +214,54 @@
   // Indicate to kernel to set itself up to generate random numbers
   bool hasRNG();
 
-  // Indicate to kernel to set itself up to generate random numbers
   bool hasReduction();
+  bool hasBlockReduction();
+  bool hasGridReduction();
+  size_t gridReductionTempBufferSize();
+
+  void setValuesMap(std::unordered_map<Val*, Val*> values_map) {
+    values_map_ = std::move(values_map);
+  }
+
+  Val* loweredVal(Val* value) const {
+    auto it = values_map_.find(value);
+    return it != values_map_.end() ? it->second : value;
+  }
+
+  const Val* loweredVal(const Val* value) const {
+    auto it = values_map_.find(const_cast<Val*>(value));
+    return it != values_map_.end() ? it->second : value;
+  }
+
+  const auto& inputs() const {
+    return inputs_;
+  }
+
+  const auto& outputs() const {
+    return outputs_;
+  }
+
+  bool hasInput(const Val* val) const;
+  bool hasOutput(const Val* val) const;
+
+  void replaceInput(Val* replace, Val* with);
+  void replaceOutput(Val* replace, Val* with);
 
  private:
-  // Sets of all Vals/Exprs registered with this fusion
-  std::unordered_set<Val*> val_set_;
-  std::deque<Val*> val_deque_;
-  std::unordered_set<Expr*> expr_set_;
-
   // Return an int that monotonically increases for each val/expr, some are
   // explicitly incremented by type.
   StmtNameType getValName(ValType vtype);
   StmtNameType getExprName();
 
+ private:
+  // Sets of all Vals/Exprs registered with this fusion
+  // (val_deque_ is not owning the objects)
+  std::unordered_set<Val*> val_set_;
+  std::deque<Val*> val_deque_;
+  std::unordered_set<Expr*> expr_set_;
+
   // map from valtype to individual name counters
-  std::unordered_map<ValType, StmtNameType, TypeHash> val_type_name_map = {
+  std::unordered_map<ValType, StmtNameType, TypeHash> val_type_name_map_ = {
       {ValType::TensorView, 0},
       {ValType::TensorDomain, 0},
       {ValType::IterDomain, 0},
@@ -231,6 +274,13 @@
   // Dependency tracking for Vals. Where did it come from? Where is it used?
   std::unordered_map<Val*, Expr*> origin_;
   std::unordered_map<Val*, std::unordered_set<Expr*>> uses_;
+
+  // Map a subset of values to the lowered equivalent (ex. sizes)
+  std::unordered_map<Val*, Val*> values_map_;
+
+  // Fusion inputs and outputs
+  std::vector<Val*> inputs_;
+  std::vector<Val*> outputs_;
 };
 
 } // namespace fuser
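
values_map_ gives the lowering pass a way to redirect a subset of values (for example, symbolic extents to the named sizes the generated kernel actually reads) without rewriting every expression; loweredVal() falls back to the original value when no mapping exists. A usage sketch against the interface above; the specific values and the "T0.size[0]" name are illustrative only:

    #include <torch/csrc/jit/codegen/cuda/fusion.h>
    #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

    using namespace torch::jit::fuser;

    void lowerSizeExample(Fusion& fusion) {
      FusionGuard fg(&fusion);

      // Hypothetical mapping: redirect a symbolic extent to the named scalar
      // the kernel will receive.
      Val* symbolic_size = new Int();
      Val* kernel_size = new NamedScalar("T0.size[0]", DataType::Int);

      fusion.setValuesMap({{symbolic_size, kernel_size}});

      TORCH_INTERNAL_ASSERT(fusion.loweredVal(symbolic_size) == kernel_size);
      // Unmapped values come back unchanged.
      TORCH_INTERNAL_ASSERT(fusion.loweredVal(kernel_size) == kernel_size);
    }
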
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp
index ce08528..898df6a 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -15,9 +15,8 @@
 
   auto outer_it = index_map_.find(outer_id);
   auto inner_it = index_map_.find(inner_id);
-  TORCH_INTERNAL_ASSERT(
-      outer_it != index_map_.end() && inner_it != index_map_.end(),
-      "Error in index compute, did not compute a necessary intermediate value.");
+  if (outer_it == index_map_.end() || inner_it == index_map_.end())
+    return;
 
   auto outer_ind = outer_it->second;
   auto inner_ind = inner_it->second;
@@ -32,9 +31,8 @@
   auto inner_id = merge->inner();
 
   auto out_it = index_map_.find(out_id);
-  TORCH_INTERNAL_ASSERT(
-      out_it != index_map_.end(),
-      "Error in index compute, did not compute a necessary intermediate value.");
+  if (out_it == index_map_.end())
+    return;
 
   auto out_ind = out_it->second;
 
@@ -58,21 +56,21 @@
   BackwardVisitor::handle(e);
 }
 
-IndexCompute::IndexCompute(TensorDomain* td, const std::vector<Val*>& indices) {
+IndexCompute::IndexCompute(
+    const TensorDomain* td,
+    const std::vector<Val*>& indices) {
   if (td->nDims() == 0 || indices.empty()) {
     indices_.push_back(new Int(0));
     return;
   }
 
-  bool exclude_reduction = td->nDims() > indices.size();
+  const bool exclude_reduction = td->nDims() > indices.size();
 
   TORCH_INTERNAL_ASSERT(
       td->noReductions().size() == indices.size() ||
           td->nDims() == indices.size(),
       "For IndexCompute the number of axes should match the number of dimensions in the TensorDomain.");
 
-  TORCH_INTERNAL_ASSERT(!td->hasRFactor(), "Not implemented yet.");
-
   {
     size_t i = 0;
     for (auto id : td->domain()) {
@@ -82,7 +80,7 @@
     }
   }
 
-  std::vector<Val*> domain_vals(td->domain().begin(), td->domain().end());
+  const std::vector<Val*> domain_vals(td->domain().begin(), td->domain().end());
 
   // Run the split/merge operations backwards. This will modify the index_map_
   // so it can be used to index the root TensorDomain. Each entry in the root
@@ -92,7 +90,6 @@
   // map at the rfactor IterDomains.
   traverseFrom(indices[0]->fusion(), domain_vals, false);
 
-  std::vector<Val*> inds;
   for (auto id : td->rootDomain()) {
     if (exclude_reduction && id->isReduction())
       continue;
@@ -105,66 +102,60 @@
 }
 
 std::vector<Val*> IndexCompute::get(
-    TensorDomain* td,
+    const TensorDomain* td,
     const std::vector<Val*>& _indices) {
   IndexCompute ic(td, _indices);
   return ic.indices_;
 }
 
 TensorIndex* Index::getGlobalProducerIndex(
-    TensorView* producer,
-    TensorView* consumer,
+    const TensorView* producer,
+    const TensorView* consumer,
     const std::vector<ForLoop*>& loops) {
-  // This replay will ignore reduction dimensions on the producer
-  auto pind =
-      TransformReplay::replayPasC(producer->domain(), consumer->domain(), -1);
+  // Grab indices from the loops
+  std::vector<Val*> indices(loops.size());
+  std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) {
+    return fl->index();
+  });
 
-  TORCH_INTERNAL_ASSERT(
-      loops.size() == consumer->nDims(),
-      "Dimensionality error in code generator while computing tensor indexes.");
+  // What would the consumer indices be if it was global, keeping in mind
+  // reduction axes:
+  const std::vector<Val*> c_inds =
+      IndexCompute::get(consumer->domain(), indices);
 
-  std::vector<ForLoop*> loops_adjusted;
-  size_t it_c = 0, it_p = 0;
-  while (it_c < consumer->nDims() && it_p < pind->noReductions().size()) {
-    if (consumer->axis(it_c)->isBroadcast() &&
-        !pind->noReductions()[it_p]->isBroadcast()) {
-      it_c++;
-    } else {
-      loops_adjusted.push_back(loops[it_c]);
-      it_c++;
-      it_p++;
+  // Computed consumer indices should have everything we need for the producer
+  std::vector<Val*> p_inds;
+  auto p_root = TensorDomain::noReductions(producer->getRootDomain());
+  // Number of root dims that are broadcasted
+  size_t bcast_dims = 0;
+  {
+    auto c_root = consumer->getRootDomain();
+    size_t it_c = 0, it_p = 0;
+    while (it_c < c_root.size() && it_p < p_root.size()) {
+      const bool is_bcast = p_root[it_p]->isBroadcast();
+      if (c_root[it_c]->isBroadcast() && !is_bcast) {
+        it_c++;
+      } else {
+        if (!is_bcast) {
+          p_inds.push_back(c_inds[it_c]);
+        } else {
+          bcast_dims++;
+        }
+        it_c++;
+        it_p++;
+      }
     }
   }
-
   TORCH_INTERNAL_ASSERT(
-      loops_adjusted.size() == pind->noReductions().size(),
-      "Dimensionality error in code generator while computing tensor indexes.");
-
-  std::vector<Val*> indices(loops_adjusted.size());
-  std::transform(
-      loops_adjusted.begin(),
-      loops_adjusted.end(),
-      indices.begin(),
-      [](ForLoop* fl) { return fl->index(); });
-  std::vector<Val*> computed_inds = IndexCompute::get(pind, indices);
-
-  auto root_domain = producer->getRootDomain();
-
-  TORCH_INTERNAL_ASSERT(
-      computed_inds.size() == root_domain.size(),
-      "Dimensionality error in code generator while computing indexing.");
-
-  for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) {
-    if (root_domain[i]->isReduction() || root_domain[i]->isBroadcast())
-      computed_inds.erase(computed_inds.begin() + i);
-  }
+      p_inds.size() == p_root.size() - bcast_dims,
+      "Dimensionality error in code generator while computing tensor indices.");
 
   std::vector<Val*> strided_inds;
-  for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) {
+  for (size_t i = 0; i < p_inds.size(); i++) {
     std::stringstream ss;
     ss << "T" << producer->name() << ".stride[" << i << "]";
     strided_inds.push_back(
-        mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int)));
+        mul(p_inds[i], new NamedScalar(ss.str(), DataType::Int)));
   }
 
   // Probably shouldn't ever hit this
@@ -176,8 +167,8 @@
 
 // Producer index for either shared or local memory
 TensorIndex* Index::getProducerIndex_impl(
-    TensorView* producer,
-    TensorView* consumer,
+    const TensorView* producer,
+    const TensorView* consumer,
     const std::vector<ForLoop*>& loops) {
   TORCH_INTERNAL_ASSERT(
       loops.size() == consumer->nDims(),
@@ -222,7 +213,7 @@
   std::vector<Val*> used_inds;
   std::vector<IterDomain*> used_ranges;
   bool unrolled = false;
-  for (decltype(loops_adjusted.size()) i{0}; i < loops_adjusted.size(); i++) {
+  for (size_t i = 0; i < loops_adjusted.size(); i++) {
     if (ranges[i]->parallel_method() == ParallelType::Unroll)
       unrolled = true;
     if (!unrolled && producer->hasComputeAt() &&
@@ -233,16 +224,16 @@
       continue;
     if (producer->getMemoryType() == MemoryType::Local && ranges[i]->isThread())
       continue;
-    if (ranges[i]->isBroadcast())
+    if (producer->domain()->noReductions()[i]->isBroadcast())
       continue;
 
     used_inds.push_back(indices[i]);
     used_ranges.push_back(ranges[i]);
   }
 
-  for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) {
+  for (size_t i = 0; i < used_inds.size(); i++) {
     Val* ind = used_inds[i];
-    for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++)
+    for (size_t j = i + 1; j < used_ranges.size(); j++)
       ind = mul(ind, used_ranges[j]->extent());
     used_inds[i] = ind;
   }
@@ -253,7 +244,7 @@
 }
 
 TensorIndex* Index::getGlobalConsumerIndex(
-    TensorView* consumer,
+    const TensorView* consumer,
     const std::vector<ForLoop*>& loops) {
   // If we're initializing a reduction buffer, we won't have the reduction
   // loops. If we're actually performing the reduction, we will.
@@ -273,7 +264,7 @@
       "Dimensionality error in code generator while computing indexing.");
 
   if (computed_inds.size() == root_dom.size())
-    for (decltype(root_dom.size()) i{0}; i < root_dom.size(); i++) {
+    for (size_t i = 0; i < root_dom.size(); i++) {
       // Do this backwards so erase offset will be right
       auto axis = root_dom.size() - i - 1;
       if (root_dom[axis]->isReduction() || root_dom[i]->isBroadcast())
@@ -281,7 +272,7 @@
     }
 
   std::vector<Val*> strided_inds;
-  for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) {
+  for (size_t i = 0; i < computed_inds.size(); i++) {
     std::stringstream ss;
     ss << "T" << consumer->name() << ".stride[" << i << "]";
     strided_inds.push_back(
@@ -297,7 +288,7 @@
 
 // Consumer index for either shared or local memory
 TensorIndex* Index::getConsumerIndex_impl(
-    TensorView* consumer,
+    const TensorView* consumer,
     const std::vector<ForLoop*>& loops) {
   // If we're initializing a reduction buffer, we won't have the reduction
   // loops. If we're actually performing the reduction, we will.
@@ -335,29 +326,40 @@
   std::vector<Val*> used_inds;
   std::vector<IterDomain*> used_ranges;
   bool unrolled = false;
-  for (decltype(loops.size()) i{0}; i < loops.size(); i++) {
-    if (have_reduction_iters && consumer->axis(i)->isReduction())
-      continue;
-    if (ranges[i]->parallel_method() == ParallelType::Unroll)
-      unrolled = true;
-    if (!unrolled && consumer->hasComputeAt() &&
-        i < consumer->getThisComputeAtAxis())
-      continue;
-    if (consumer->getMemoryType() == MemoryType::Shared &&
-        ranges[i]->isBlockDim())
-      continue;
-    if (consumer->getMemoryType() == MemoryType::Local && ranges[i]->isThread())
-      continue;
-    if (ranges[i]->isBroadcast())
-      continue;
+  {
+    size_t c_i = 0, l_i = 0;
+    while (c_i < consumer->nDims() && l_i < loops.size()) {
+      if (consumer->axis(c_i)->isReduction()) {
+        c_i++;
+        if (have_reduction_iters)
+          l_i++;
+        continue;
+      }
+      if (ranges[l_i]->parallel_method() == ParallelType::Unroll)
+        unrolled = true;
 
-    used_inds.push_back(indices[i]);
-    used_ranges.push_back(ranges[i]);
+      if ((!unrolled && consumer->hasComputeAt() &&
+           c_i < consumer->getThisComputeAtAxis()) ||
+          (consumer->getMemoryType() == MemoryType::Shared &&
+           ranges[l_i]->isBlockDim()) ||
+          (consumer->getMemoryType() == MemoryType::Local &&
+           ranges[l_i]->isThread()) ||
+          (consumer->axis(c_i)->isBroadcast())) {
+        c_i++;
+        l_i++;
+        continue;
+      }
+
+      used_inds.push_back(indices[l_i]);
+      used_ranges.push_back(ranges[l_i]);
+      l_i++;
+      c_i++;
+    }
   }
 
-  for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) {
+  for (size_t i = 0; i < used_inds.size(); i++) {
     Val* ind = used_inds[i];
-    for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++)
+    for (size_t j = i + 1; j < used_ranges.size(); j++)
       ind = mul(ind, used_ranges[j]->extent());
     used_inds[i] = ind;
   }
@@ -370,13 +372,17 @@
 
 // Producer is the inputs of an expression
 TensorIndex* Index::getProducerIndex(
-    TensorView* producer,
-    TensorView* consumer,
+    const TensorView* producer,
+    const TensorView* consumer,
     const std::vector<ForLoop*>& loops) {
   TORCH_INTERNAL_ASSERT(
       loops.size() == consumer->nDims() ||
       loops.size() == consumer->domain()->noReductions().size());
 
+  if (producer->domain()->noReductions().size() == 0) {
+    return new TensorIndex(producer, {});
+  }
+
   if (producer->getMemoryType() == MemoryType::Global)
     return getGlobalProducerIndex(producer, consumer, loops);
   return getProducerIndex_impl(producer, consumer, loops);
@@ -384,12 +390,16 @@
 
 // Consumer is the output of an expression
 TensorIndex* Index::getConsumerIndex(
-    TensorView* consumer,
+    const TensorView* consumer,
     const std::vector<ForLoop*>& loops) {
   TORCH_INTERNAL_ASSERT(
       loops.size() == consumer->nDims() ||
       loops.size() == consumer->domain()->noReductions().size());
 
+  if (consumer->domain()->noReductions().size() == 0) {
+    return new TensorIndex(consumer, {});
+  }
+
   if (consumer->getMemoryType() == MemoryType::Global)
     return getGlobalConsumerIndex(consumer, loops);
   return getConsumerIndex_impl(consumer, loops);
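
getGlobalProducerIndex now derives producer indices directly from the consumer's computed root indices by walking the two root domains in lockstep: a consumer-only broadcast advances just the consumer cursor, a producer broadcast axis is matched but contributes no index, and everything else pairs up one-to-one (reduction axes are already filtered out of the producer root before this loop). The standalone sketch below restates that pairing rule over plain flags; Axis and the int indices are stand-ins for the real IterDomain/Val types.

    #include <cstddef>
    #include <vector>

    struct Axis {
      bool is_broadcast = false;
    };

    // consumer_inds[i] is the index already computed for consumer root axis i.
    std::vector<int> producerIndices(
        const std::vector<Axis>& consumer_root,
        const std::vector<Axis>& producer_root,
        const std::vector<int>& consumer_inds) {
      std::vector<int> producer_inds;
      size_t it_c = 0, it_p = 0;
      while (it_c < consumer_root.size() && it_p < producer_root.size()) {
        const bool p_bcast = producer_root[it_p].is_broadcast;
        if (consumer_root[it_c].is_broadcast && !p_bcast) {
          it_c++;  // consumer-only broadcast: no matching producer axis
        } else {
          if (!p_bcast)
            producer_inds.push_back(consumer_inds[it_c]);
          // a producer broadcast axis is consumed but produces no index
          it_c++;
          it_p++;
        }
      }
      return producer_inds;
    }
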
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h
index 09f1b8e..c677115 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.h
+++ b/torch/csrc/jit/codegen/cuda/index_compute.h
@@ -55,7 +55,7 @@
 namespace jit {
 namespace fuser {
 
-struct IndexCompute : public BackwardVisitor {
+class IndexCompute : public BackwardVisitor {
  private:
   using BackwardVisitor::handle;
   void handle(Split*) override;
@@ -65,54 +65,54 @@
   // Otherwise warning on runBackward as it hides an overloaded virtual
   // using TransformIter::runBackward;
 
-  IndexCompute(TensorDomain* td, const std::vector<Val*>& _indices);
+  IndexCompute(const TensorDomain* td, const std::vector<Val*>& _indices);
   std::unordered_map<IterDomain*, Val*> index_map_;
   std::vector<Val*> indices_;
 
  public:
   static std::vector<Val*> get(
-      TensorDomain* td,
+      const TensorDomain* td,
       const std::vector<Val*>& _indices);
 };
 
 // Simple interface for IndexCompute
-struct Index {
+class Index {
  private:
   // Producer indexing if it's in shared or local memory
   static TensorIndex* getProducerIndex_impl(
-      TensorView* producer,
-      TensorView* consumer,
+      const TensorView* producer,
+      const TensorView* consumer,
       const std::vector<ForLoop*>& loops);
 
   // Consumer indexing if it's in shared or local memory
   static TensorIndex* getConsumerIndex_impl(
-      TensorView* consumer,
+      const TensorView* consumer,
       const std::vector<ForLoop*>& loops);
 
- public:
   // Producer if it's in global memory
   static TensorIndex* getGlobalProducerIndex(
-      TensorView* producer,
-      TensorView* consumer,
+      const TensorView* producer,
+      const TensorView* consumer,
       const std::vector<ForLoop*>& loops);
 
   // Consumer indexing if it's in global memory
   static TensorIndex* getGlobalConsumerIndex(
-      TensorView* consumer,
+      const TensorView* consumer,
       const std::vector<ForLoop*>& loops);
 
+ public:
   // Indexing functions
   // Consumer = Producer
   // i.e. T0 = T1... -> T0 is the consumer, T1 is the producer
   // Producer indexing dispatch
   static TensorIndex* getProducerIndex(
-      TensorView* producer,
-      TensorView* consumer,
+      const TensorView* producer,
+      const TensorView* consumer,
       const std::vector<ForLoop*>& loops);
 
   // Consumer index dispatch
   static TensorIndex* getConsumerIndex(
-      TensorView* consumer,
+      const TensorView* consumer,
       const std::vector<ForLoop*>& loops);
 
   // Will run inds through back prop index computation for tv
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
index 3027784..c7f5d37 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -1,10 +1,10 @@
-#include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
-
-#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
-#include <torch/csrc/jit/codegen/cuda/mutator.h>
 
 #include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
+#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
+#include <torch/csrc/jit/codegen/cuda/mutator.h>
 
 #include <torch/csrc/jit/ir/ir.h>
 
@@ -19,6 +19,12 @@
 namespace jit {
 namespace fuser {
 
+Statement::Statement(const Statement* src, IrCloner* ir_cloner) {
+  ir_cloner->registerClone(src, this);
+  name_ = src->name_;
+  fusion_ = ir_cloner->fusion();
+}
+
 Val* Statement::asVal() {
   TORCH_INTERNAL_ASSERT(isVal(), "Cannot cast to Val as this is not a Val.");
   return static_cast<Val*>(this);
@@ -29,6 +35,12 @@
   return static_cast<Expr*>(this);
 }
 
+void Statement::print() const {
+  IRPrinter ir_printer(std::cout);
+  ir_printer.handle(this);
+  std::cout << std::endl;
+}
+
 // When we create a Val we immediately register them with the active fusion.
 Val::Val(ValType _vtype, DataType _dtype, bool register_val)
     : vtype_{_vtype}, dtype_{_dtype} {
@@ -40,41 +52,45 @@
     this->name_ = this->fusion_->registerVal(this);
 }
 
+Val::Val(const Val* src, IrCloner* ir_cloner)
+    : Statement(src, ir_cloner), vtype_(src->vtype_), dtype_(src->dtype_) {}
+
 // Traverse origin of all values involved in constructing the provided val.
 // Check if all values involved are constant values, meaning the provided
 // val is also a constant value.
 namespace {
 
-struct ConstCheck : OptOutConstDispatch {
+class ConstCheck : OptOutConstDispatch {
  private:
-  bool is_const_ = false;
+  bool is_const_ = true;
 
-  void handle(const Bool* const b) override {
-    is_const_ = b->isConst();
+  void handle(const Bool* b) override {
+    is_const_ = is_const_ && b->isConst();
   }
 
-  void handle(const Float* const f) override {
-    is_const_ = f->isConst();
+  void handle(const Float* f) override {
+    is_const_ = is_const_ && f->isConst();
   }
 
-  void handle(const Half* const h) override {
-    is_const_ = h->isConst();
+  void handle(const Half* h) override {
+    is_const_ = is_const_ && h->isConst();
   }
 
-  void handle(const Int* const i) override {
-    is_const_ = i->isConst();
+  void handle(const Int* i) override {
+    is_const_ = is_const_ && i->isConst();
   }
 
-  void handle(const Expr* const expr) override {
+  void handle(const NamedScalar* ns) override {
+    is_const_ = is_const_ && false;
+  }
+
+  void handle(const Expr* expr) override {
     for (auto inp : expr->inputs()) {
       OptOutConstDispatch::handle(inp);
     }
   }
 
-  void handle(const NamedScalar* const ns) override {
-    is_const_ = false;
-  }
-  void handle(const Val* const val) override {
+  void handle(const Val* val) override {
     const Expr* orig = FusionGuard::getCurFusion()->origin(val);
     if (orig != nullptr)
       handle(orig);
@@ -83,7 +99,7 @@
   }
 
  public:
-  static bool isConst(const Val* const val) {
+  static bool isConst(const Val* val) {
     ConstCheck cc;
     cc.handle(val);
     return cc.is_const_;
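
Since `is_const_` now starts out true and is AND-ed across every scalar leaf reached through `origin()`, a value is reported constant only when all of its leaves are constant, and never when a `NamedScalar` is involved. A minimal sketch of that behavior, assuming the `add()` arithmetic helper and the `Val::isConstScalar()` entry point from elsewhere in this stack (both assumptions, neither shown in this diff):

#include <torch/csrc/jit/codegen/cuda/arith.h>        // add() -- assumed helper
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

void constCheckSketch() {
  using namespace torch::jit::fuser;
  Fusion fusion;
  FusionGuard fg(&fusion);

  Val* c = new Int(3);                  // constant leaf
  Val* s = new Int();                   // symbolic leaf
  Val* all_const = add(c, new Int(4));  // every leaf constant
  Val* mixed = add(c, s);               // reaches a symbolic leaf

  bool a = all_const->isConstScalar();  // expected: true (assumed entry point)
  bool b = mixed->isConstScalar();      // expected: false
  (void)a;
  (void)b;
}
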
@@ -123,6 +139,9 @@
   return (fusion_->origin(this));
 }
 
+Scope::Scope(const Scope* src, IrCloner* ir_cloner)
+    : exprs_(ir_cloner->clone(src->exprs_)) {}
+
 void Scope::insert_before(Expr* ref, Expr* expr) {
   auto it = exprs_.begin();
   while (it != exprs_.end()) {
@@ -176,76 +195,6 @@
   this->exprs_ = std::vector<Expr*>();
 }
 
-bool IRInputOutput::hasInput(const Val* const input) const {
-  for (auto val : inputs_)
-    if (val == input)
-      return true;
-  return false;
-}
-
-bool IRInputOutput::hasOutput(const Val* const output) const {
-  for (auto val : outputs_)
-    if (val == output)
-      return true;
-  return false;
-}
-
-void IRInputOutput::replaceInput(Val* replace, Val* with) {
-  bool changed = false;
-  for (decltype(inputs_.size()) i{0}; i < inputs_.size(); i++) {
-    if (inputs_[i] == replace) {
-      inputs_[i] = with;
-      changed = true;
-      break;
-    }
-  }
-  TORCH_INTERNAL_ASSERT(
-      changed,
-      "Error detected when trying to replace input ",
-      replace,
-      " with ",
-      with,
-      " .");
-}
-
-void IRInputOutput::replaceOutput(Val* replace, Val* with) {
-  bool changed = false;
-  for (decltype(outputs_.size()) i{0}; i < outputs_.size(); i++) {
-    if (outputs_[i] == replace) {
-      outputs_[i] = with;
-      changed = true;
-      break;
-    }
-  }
-  TORCH_INTERNAL_ASSERT(
-      changed,
-      "Error detected when trying to replace output ",
-      replace,
-      " with ",
-      with,
-      " .");
-}
-
-void IRInputOutput::removeInput(Val* val) {
-  auto it = inputs_.begin();
-  for (; it != inputs_.end(); ++it) {
-    if ((*it) == val)
-      break;
-  }
-  TORCH_INTERNAL_ASSERT(it != inputs_.end());
-  inputs_.erase(it);
-}
-
-void IRInputOutput::removeOutput(Val* val) {
-  auto it = outputs_.begin();
-  for (; it != outputs_.end(); ++it) {
-    if ((*it) == val)
-      break;
-  }
-  TORCH_INTERNAL_ASSERT(it != outputs_.end());
-  outputs_.erase(it);
-}
-
 // We don't register with the active fusion in Expr as this needs to be done
 // after inputs and outputs are registered with the Expr
 Expr::Expr(ExprType _type) : type_{_type} {
@@ -255,6 +204,25 @@
   this->fusion_ = fusion;
 }
 
+Expr::Expr(const Expr* src, IrCloner* ir_cloner)
+    : Statement(src, ir_cloner),
+      type_(src->type_),
+      inputs_(ir_cloner->clone(src->inputs_)),
+      outputs_(ir_cloner->clone(src->outputs_)) {}
+
+bool Expr::sameAs(const Expr* const other) const {
+  if (getExprType() != other->getExprType())
+    return false;
+  if (inputs().size() != other->inputs().size() ||
+      outputs().size() != other->outputs().size())
+    return false;
+  for (size_t i = 0; i < inputs().size(); i++) {
+    if (!input(i)->sameAs(other->input(i)))
+      return false;
+  }
+  return true;
+}
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
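
The out-of-line `Expr::sameAs` above compares the expression type, the input/output counts, and the inputs element-wise; outputs themselves are only compared by count. A hedged sketch of what that means in practice (the `add()` helper and the use of `Fusion::origin()` to fetch a value's defining expression are assumptions drawn from the surrounding code):

#include <torch/csrc/jit/codegen/cuda/arith.h>   // add() -- assumed helper
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

void sameAsSketch() {
  using namespace torch::jit::fuser;
  Fusion fusion;
  FusionGuard fg(&fusion);

  Val* a = new Int(1);
  Val* b = new Int(2);

  Expr* e1 = fusion.origin(add(a, b));  // one BinaryOp over a and b
  Expr* e2 = fusion.origin(add(a, b));  // a second, distinct BinaryOp

  // Same ExprType, same arity, and inputs compare sameAs() pairwise (here
  // they are literally the same Vals), so this is expected to be true even
  // though e1 and e2 are different nodes.
  bool same = e1->sameAs(e2);
  (void)same;
}
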
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index cc2507a..c0564b2 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -23,7 +23,7 @@
 /*
  * This file defines the base IR structure. Any IR node in this system will
  * inherit from one of the following classes: Statement, Expr, Val,
- * IRInputOutput IR is any information that the code generation stack may need
+ * IrInputOutput. IR is any information that the code generation stack may need
  * for analysis. By analysis we're referring to anything done in response to a
  * user facing call of this stack. This could be careful tracking of user calls,
  * and any transformation including optimizing transformations, user declared
@@ -38,13 +38,14 @@
 constexpr StmtNameType UNINITIALIZED_STMTNAMETYPE =
     std::numeric_limits<unsigned int>::max();
 
-struct Fusion;
-struct FusionGuard;
-struct Expr;
-struct Val;
-struct UnaryOp;
-struct BinaryOp;
-struct IterDomain;
+class Fusion;
+class FusionGuard;
+class Expr;
+class Val;
+class UnaryOp;
+class BinaryOp;
+class IterDomain;
+class IrCloner;
 
 /*
  * Statement is the highest level node representation. Everything that is
@@ -57,7 +58,15 @@
  * Basically being able to succinctly traverse down the inheritance stack of
  * a Statement at runtime. This is currently implemented in dispatch.h
  */
-struct TORCH_CUDA_API Statement {
+class TORCH_CUDA_API Statement {
+  friend void swap(Fusion&, Fusion&) noexcept;
+
+ public:
+  Statement() = default;
+
+  // Cloning constructor
+  Statement(const Statement* src, IrCloner* ir_cloner);
+
   virtual ~Statement() = default;
 
   // Dispatch functions, definitions in dispatch.cpp
@@ -71,21 +80,21 @@
   static Statement* mutatorDispatch(T mutator, Statement*);
 
   // Accessor functions to types. Vals always have a DataType, Exprs never do
-  virtual c10::optional<ValType> getValType() const noexcept {
+  virtual c10::optional<ValType> getValType() const {
     return c10::nullopt;
   }
   virtual c10::optional<DataType> getDataType() const {
     return c10::nullopt;
   }
-  virtual c10::optional<ExprType> getExprType() const noexcept {
+  virtual c10::optional<ExprType> getExprType() const {
     return c10::nullopt;
   }
 
   // Short cut to figure out if it is a value/expression
-  bool isVal() const noexcept {
+  bool isVal() const {
     return getValType() != c10::nullopt;
   }
-  bool isExpr() const noexcept {
+  bool isExpr() const {
     return getExprType() != c10::nullopt;
   }
 
@@ -119,12 +128,12 @@
   }
 
   // Return the fusion this statement belongs to
-  Fusion* fusion() const noexcept {
+  Fusion* fusion() const {
     return fusion_;
   }
 
   // Return the int that represents its name
-  StmtNameType name() const noexcept {
+  StmtNameType name() const {
     return name_;
   }
 
@@ -142,6 +151,8 @@
     return this == other;
   }
 
+  void print() const;
+
  protected:
   StmtNameType name_ = UNINITIALIZED_STMTNAMETYPE;
   Fusion* fusion_ = nullptr;
@@ -165,13 +176,16 @@
  *     - Accessor functions for members
  *     - Must call Val constructor, Val constructor registers with fusion
  *     - Implementation of bool sameAs(...)
+ *     - Must implement a "cloning" constructor, ex.
+ *        Int::Int(const Int* src, IrCloner* ir_cloner)
  * 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
  * 3) Default mutator function should be added to mutator.cpp
- * 4) Printing functions should be added to ir_iostream.h/.cpp
+ * 4a) Printing functions should be added to ir_iostream.h/.cpp
+ * 4b) Graphviz generation must be added to ir_graphviz.h/.cpp
  * 5) An enum value must be added to ValType in type.h
  * 6) A string entry must be added in val_type_string_map
  */
-struct TORCH_CUDA_API Val : public Statement {
+class TORCH_CUDA_API Val : public Statement {
  public:
   virtual ~Val() = default;
 
@@ -182,10 +196,13 @@
   // to throw, fusion's destructor will get called, but the pointer to this Val
   // will be invalid. When fusion tries to delete this value it will cause a seg
   // fault, instead of showing the thrown error.
-  Val(ValType _vtype,
+  explicit Val(
+      ValType _vtype,
       DataType _dtype = DataType::Null,
       bool register_val = true);
 
+  Val(const Val* src, IrCloner* ir_cloner);
+
   // TODO: Values are unique and not copyable
   Val(const Val& other) = delete;
   Val& operator=(const Val& other) = delete;
@@ -193,7 +210,7 @@
   Val(Val&& other) = delete;
   Val& operator=(Val&& other) = delete;
 
-  c10::optional<ValType> getValType() const noexcept override {
+  c10::optional<ValType> getValType() const override {
     return vtype_;
   }
 
@@ -244,13 +261,12 @@
   const DataType dtype_;
 };
 
-// TODO: We should use this for the following:
-//    Fusion
-//    IfThenElse
-//    ForLoop
-struct TORCH_CUDA_API Scope {
+class TORCH_CUDA_API Scope {
  public:
-  const std::vector<Expr*>& exprs() const noexcept {
+  Scope() = default;
+  Scope(const Scope* src, IrCloner* ir_cloner);
+
+  const std::vector<Expr*>& exprs() const {
     return exprs_;
   }
 
@@ -274,6 +290,14 @@
     return exprs_.size();
   }
 
+  auto& operator[](size_t i) {
+    return exprs_[i];
+  }
+
+  auto& operator[](size_t i) const {
+    return exprs_[i];
+  }
+
   // Insert expr before ref
   void insert_before(Expr* ref, Expr* expr);
 
@@ -293,72 +317,6 @@
 };
 
 /*
- * IRInputOutput is a function on Vals. Has inputs and outputs that are all
- * Vals. It is used for anything that connects values and therefore would be
- * used during dependency analysis. Typically classes that inherit from
- * IRInputOutput will do so by inheriting from Exprs. Expr's are expected for
- * most dependency based operations like IterVisitor, or DependencyCheck.
- *
- * Examples:
- *   binary operations on tensors, scalar values, or a combination, a thread all
- *   reduce, for loops
- */
-struct TORCH_CUDA_API IRInputOutput {
-  virtual ~IRInputOutput() = default;
-
-  // Returns if Val is an input or output of this IRInputOutput instance
-  bool hasInput(const Val* const input) const;
-  bool hasOutput(const Val* const output) const;
-
-  // Input/output accessors
-  void addInputAt(std::deque<Val*>::size_type pos, Val* input) {
-    inputs_.insert(inputs_.begin() + pos, input);
-  }
-
-  void addOutputAt(std::deque<Val*>::size_type pos, Val* output) {
-    outputs_.insert(outputs_.begin() + pos, output);
-  }
-
-  const std::deque<Val*>& inputs() const noexcept {
-    return inputs_;
-  }
-  const std::deque<Val*>& outputs() const noexcept {
-    return outputs_;
-  }
-
-  Val* input(std::deque<Val*>::size_type idx) const {
-    return inputs_[idx];
-  }
-  Val* output(std::deque<Val*>::size_type idx) const {
-    return outputs_[idx];
-  }
-
-  void addInput(Val* input) {
-    inputs_.push_back(input);
-  }
-  void addOutput(Val* output) {
-    outputs_.push_back(output);
-  }
-
-  void replaceInput(Val* replace, Val* with);
-  void replaceOutput(Val* replace, Val* with);
-
-  void removeInput(Val* val);
-  void removeOutput(Val* val);
-
-  std::deque<Val*>::size_type nInputs() const noexcept {
-    return inputs_.size();
-  }
-  std::deque<Val*>::size_type nOutputs() const noexcept {
-    return outputs_.size();
-  }
-
- protected:
-  std::deque<Val*> inputs_;
-  std::deque<Val*> outputs_;
-};
-
-/*
  * An Expr represents a "computation." These are functions that take inputs
  * and produce outputs, inputs and outputs all being Vals. There are
  * specializations of BinaryOp which takes 2 inputs and produces 1 output, and
@@ -396,11 +354,12 @@
  * 6) An enum value must be added to ExprType in type.h
  * 7) A string entry must be added in expr_type_string_map
  */
-struct TORCH_CUDA_API Expr : public Statement, IRInputOutput {
+class TORCH_CUDA_API Expr : public Statement {
  public:
-  virtual ~Expr() = default;
   Expr() = delete;
-  Expr(ExprType _type);
+  explicit Expr(ExprType _type);
+  Expr(const Expr* src, IrCloner* ir_cloner);
+  virtual ~Expr() = default;
 
   Expr(const Expr& other) = delete;
   Expr& operator=(const Expr& other) = delete;
@@ -408,24 +367,31 @@
   Expr(Expr&& other) = delete;
   Expr& operator=(Expr&& other) = delete;
 
-  c10::optional<ExprType> getExprType() const noexcept override {
-    return type_;
-  }
-  ExprType type() const noexcept {
+  c10::optional<ExprType> getExprType() const override {
     return type_;
   }
 
-  bool sameAs(const Expr* const other) const {
-    if (getExprType() != other->getExprType())
-      return false;
-    if (inputs().size() != other->inputs().size() ||
-        outputs().size() != other->outputs().size())
-      return false;
-    for (size_t i = 0; i < inputs().size(); i++) {
-      if (!input(i)->sameAs(other->input(i)))
-        return false;
-    }
-    return true;
+  ExprType type() const {
+    return type_;
+  }
+
+  bool sameAs(const Expr* const other) const;
+
+  // Input/output accessors
+  const auto& inputs() const {
+    return inputs_;
+  }
+
+  const auto& outputs() const {
+    return outputs_;
+  }
+
+  auto input(size_t index) const {
+    return inputs_[index];
+  }
+
+  auto output(size_t index) const {
+    return outputs_[index];
   }
 
   // Dispatch functions, definitions in dispatch.cpp
@@ -438,8 +404,19 @@
   template <typename T>
   static Statement* mutatorDispatch(T mutator, Expr*);
 
+ protected:
+  void addInput(Val* input) {
+    inputs_.push_back(input);
+  }
+
+  void addOutput(Val* output) {
+    outputs_.push_back(output);
+  }
+
  private:
-  ExprType type_;
+  ExprType type_ = ExprType::Invalid;
+  std::vector<Val*> inputs_;
+  std::vector<Val*> outputs_;
 };
 
 } // namespace fuser
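
For the new checklist item above (every Val needs a "cloning" constructor), the pattern is the one Statement/Val/Expr establish in this change: delegate to the base cloning constructor, which registers the clone with the IrCloner, then copy or clone the node-local members. A hedged sketch for a hypothetical scalar node; the class, its member, and the omitted dispatch/printer plumbing are illustrative only:

#include <c10/util/Optional.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>

// Hypothetical node -- a real one also needs dispatch, mutator, printer,
// graphviz, and ValType/string-map entries per the checklist above.
class MyScalar : public torch::jit::fuser::Val {
 public:
  MyScalar()
      : Val(torch::jit::fuser::ValType::Scalar,
            torch::jit::fuser::DataType::Int) {}

  // The cloning constructor: Val(src, ir_cloner) registers the clone and
  // copies vtype_/dtype_; node-local members are handled here.
  MyScalar(const MyScalar* src, torch::jit::fuser::IrCloner* ir_cloner)
      : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}

 private:
  c10::optional<int64_t> maybe_value_;
};
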
diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
new file mode 100644
index 0000000..49e557d
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
@@ -0,0 +1,133 @@
+
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
+#include <torch/csrc/jit/codegen/cuda/fusion.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+Statement* IrCloner::clone(const Statement* statement) {
+  if (statement == nullptr) {
+    return nullptr;
+  }
+
+  // Have we already cloned this node?
+  const auto it = clones_map_.find(statement);
+  if (it != clones_map_.end()) {
+    return it->second;
+  } else {
+    // Clone the new node, saving/restoring this->clone_
+    // since the cloning can be reentrant
+    auto saved_clone = clone_;
+    handle(statement);
+    auto new_node = clone_;
+    clone_ = saved_clone;
+
+    // The base cloning constructor (Statement) should have
+    // registered the new node. Failure to do so indicates
+    // that something went horribly wrong.
+    TORCH_INTERNAL_ASSERT(new_node != nullptr);
+    TORCH_INTERNAL_ASSERT(clones_map_[statement] == new_node);
+
+    return new_node;
+  }
+}
+
+void IrCloner::registerClone(const Statement* src, Statement* clone) {
+  TORCH_CHECK(src != nullptr);
+  TORCH_CHECK(clone != nullptr);
+  TORCH_CHECK(clones_map_.insert({src, clone}).second);
+}
+
+void IrCloner::handle(const Statement* s) {
+  OptInConstDispatch::handle(s);
+}
+
+void IrCloner::handle(const Val* v) {
+  OptInConstDispatch::handle(v);
+}
+
+void IrCloner::handle(const Expr* e) {
+  OptInConstDispatch::handle(e);
+}
+
+void IrCloner::handle(const TensorDomain* td) {
+  clone_ = new TensorDomain(td, this);
+}
+
+void IrCloner::handle(const IterDomain* id) {
+  clone_ = new IterDomain(id, this);
+}
+
+void IrCloner::handle(const TensorIndex* ti) {
+  clone_ = new TensorIndex(ti, this);
+}
+
+void IrCloner::handle(const Bool* b) {
+  clone_ = new Bool(b, this);
+}
+
+void IrCloner::handle(const Float* f) {
+  clone_ = new Float(f, this);
+}
+
+void IrCloner::handle(const Half* h) {
+  clone_ = new Half(h, this);
+}
+
+void IrCloner::handle(const Int* i) {
+  clone_ = new Int(i, this);
+}
+
+void IrCloner::handle(const NamedScalar* named_scalar) {
+  clone_ = new NamedScalar(named_scalar, this);
+}
+
+void IrCloner::handle(const TensorView* tv) {
+  clone_ = new TensorView(tv, this);
+}
+
+void IrCloner::handle(const UnaryOp* op) {
+  clone_ = new UnaryOp(op, this);
+}
+
+void IrCloner::handle(const BinaryOp* op) {
+  clone_ = new BinaryOp(op, this);
+}
+
+void IrCloner::handle(const TernaryOp* op) {
+  clone_ = new TernaryOp(op, this);
+}
+
+void IrCloner::handle(const BroadcastOp* op) {
+  clone_ = new BroadcastOp(op, this);
+}
+
+void IrCloner::handle(const ReductionOp* op) {
+  clone_ = new ReductionOp(op, this);
+}
+
+void IrCloner::handle(const ForLoop* for_loop) {
+  clone_ = new ForLoop(for_loop, this);
+}
+
+void IrCloner::handle(const IfThenElse* if_then_else) {
+  clone_ = new IfThenElse(if_then_else, this);
+}
+
+void IrCloner::handle(const Allocate* allocate) {
+  clone_ = new Allocate(allocate, this);
+}
+
+void IrCloner::handle(const Split* split) {
+  clone_ = new Split(split, this);
+}
+
+void IrCloner::handle(const Merge* merge) {
+  clone_ = new Merge(merge, this);
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h
new file mode 100644
index 0000000..4097121
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h
@@ -0,0 +1,91 @@
+
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+class Fusion;
+
+// Clones nodes from an existing Fusion
+class TORCH_CUDA_API IrCloner : private OptInConstDispatch {
+  friend class Statement;
+
+ public:
+  explicit IrCloner(Fusion* new_fusion) : fusion_(new_fusion) {}
+
+  Statement* clone(const Statement* statement);
+
+  template <class T>
+  T* clone(const T* node) {
+    return node ? clone(node->template as<Statement>())->template as<T>()
+                : nullptr;
+  }
+
+  template <class T>
+  std::vector<T*> clone(const std::vector<T*>& container) {
+    std::vector<T*> copy;
+    for (auto p : container) {
+      copy.push_back(clone(p));
+    }
+    return copy;
+  }
+
+  Fusion* fusion() const {
+    return fusion_;
+  }
+
+ private:
+  void registerClone(const Statement* src, Statement* clone);
+
+  void handle(const Statement*) override;
+  void handle(const Val*) override;
+  void handle(const Expr*) override;
+
+  void handle(const TensorDomain*) override;
+  void handle(const TensorView*) override;
+  void handle(const IterDomain*) override;
+  void handle(const TensorIndex*) override;
+
+  void handle(const Bool*) override;
+  void handle(const Float*) override;
+  void handle(const Half*) override;
+  void handle(const Int*) override;
+  void handle(const NamedScalar*) override;
+
+  void handle(const UnaryOp*) override;
+  void handle(const BinaryOp*) override;
+  void handle(const TernaryOp*) override;
+  void handle(const BroadcastOp*) override;
+  void handle(const ReductionOp*) override;
+
+  void handle(const ForLoop*) override;
+  void handle(const IfThenElse*) override;
+  void handle(const Allocate*) override;
+
+  void handle(const Split*) override;
+  void handle(const Merge*) override;
+
+ private:
+  // The destination Fusion container
+  Fusion* fusion_ = nullptr;
+
+  // The dispatch interface doesn't allow returning values from the
+  // individual `handle()` methods, so each handler stores its
+  // result here
+  Statement* clone_ = nullptr;
+
+  // We keep track of the original -> clone map so we don't
+  // duplicate clones of the same object if referenced multiple times
+  std::unordered_map<const Statement*, Statement*> clones_map_;
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
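
Usage-wise, the templated `clone()` overloads let a caller clone typed nodes and whole containers into a destination Fusion, and `clones_map_` guarantees that a node referenced from several places is materialized exactly once. A hedged sketch of the API (in practice this runs inside the Fusion copy machinery rather than in user code):

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>

void irClonerSketch(
    const torch::jit::fuser::TensorView* tv,
    const std::vector<torch::jit::fuser::Expr*>& exprs) {
  using namespace torch::jit::fuser;

  Fusion new_fusion;
  IrCloner cloner(&new_fusion);

  TensorView* tv_copy = cloner.clone(tv);                // typed clone
  std::vector<Expr*> exprs_copy = cloner.clone(exprs);   // container clone

  // A repeated clone of the same node hits clones_map_ and returns the
  // copy created the first time instead of a fresh one.
  bool deduplicated = (cloner.clone(tv) == tv_copy);
  (void)deduplicated;
  (void)exprs_copy;
}
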
diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
index 3c3b259..e434b14 100644
--- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
@@ -31,6 +31,17 @@
 
   ~IrNodeLabel() override = default;
 
+  void handle(const Bool* b) override {
+    if (b->isSymbolic()) {
+      label_ << "b" << b->name();
+    } else {
+      if (detail_level_ >= DetailLevel::Explicit) {
+        label_ << "b" << b->name() << "=";
+      }
+      label_ << *b->value();
+    }
+  }
+
   void handle(const Float* f) override {
     if (f->isSymbolic()) {
       label_ << "f" << f->name();
@@ -42,6 +53,17 @@
     }
   }
 
+  void handle(const Half* h) override {
+    if (h->isSymbolic()) {
+      label_ << "h" << h->name();
+    } else {
+      if (detail_level_ >= DetailLevel::Explicit) {
+        label_ << "h" << h->name() << "=";
+      }
+      label_ << *h->value();
+    }
+  }
+
   void handle(const Int* i) override {
     if (i->isSymbolic()) {
       label_ << "i" << i->name();
@@ -90,13 +112,11 @@
   }
 
   void handle(const Split* split) override {
-    label_ << "Split(IterDomain=" << split->in()
-           << ", factor=" << IrNodeLabel::gen(split->factor()) << ")";
+    label_ << "Split(factor=" << IrNodeLabel::gen(split->factor()) << ")";
   }
 
   void handle(const Merge* merge) override {
-    label_ << "Merge(IterDomainOuter=" << merge->outer()
-           << ", IterDomainInner=" << merge->inner() << ")";
+    label_ << "Merge";
   }
 
  private:
@@ -286,7 +306,7 @@
 
 void IrGraphGenerator::handle(const Statement* s) {
   OptInConstDispatch::handle(s);
-};
+}
 
 void IrGraphGenerator::handle(const Val* v) {
   if (!visited(v)) {
@@ -296,14 +316,14 @@
     }
     OptInConstDispatch::handle(v);
   }
-};
+}
 
 void IrGraphGenerator::handle(const Expr* e) {
   if (!visited(e)) {
     visited_.insert(e);
     OptInConstDispatch::handle(e);
   }
-};
+}
 
 void IrGraphGenerator::handle(const TensorDomain* td) {
   graph_def_ << "    " << getid(td) << " [label=\"TensorDomain\", "
@@ -341,10 +361,18 @@
   }
 }
 
+void IrGraphGenerator::handle(const Bool* b) {
+  printValue(b, IrNodeLabel::gen(b, detail_level_));
+}
+
 void IrGraphGenerator::handle(const Float* f) {
   printValue(f, IrNodeLabel::gen(f, detail_level_));
 }
 
+void IrGraphGenerator::handle(const Half* h) {
+  printValue(h, IrNodeLabel::gen(h, detail_level_));
+}
+
 void IrGraphGenerator::handle(const Int* i) {
   printValue(i, IrNodeLabel::gen(i, detail_level_));
 }
@@ -394,7 +422,7 @@
   label << uop->getUnaryOpType();
   printExpr(uop, label.str());
 
-  // UnaryOp inputs & outputs
+  // inputs & outputs
   addArc(uop->in(), uop);
   addArc(uop, uop->out());
 }
@@ -405,12 +433,43 @@
   label << bop->getBinaryOpType();
   printExpr(bop, label.str());
 
-  // BinaryOp inputs & outputs
+  // inputs & outputs
   addArc(bop->lhs(), bop);
   addArc(bop->rhs(), bop, "[color=blue]");
   addArc(bop, bop->out());
 }
 
+void IrGraphGenerator::handle(const TernaryOp* op) {
+  // node
+  std::stringstream label;
+  label << op->getTernaryOpType();
+  printExpr(op, label.str());
+
+  // inputs & outputs
+  addArc(op->in1(), op);
+  addArc(op->in2(), op, "[color=blue]");
+  addArc(op->in3(), op, "[color=brown]");
+  addArc(op, op->out());
+}
+
+void IrGraphGenerator::handle(const BroadcastOp* op) {
+  printExpr(op, "Broadcast");
+  addArc(op->in(), op);
+  addArc(op, op->out());
+}
+
+void IrGraphGenerator::handle(const ReductionOp* op) {
+  // node
+  std::stringstream label;
+  label << "Reduction(" << op->getReductionOpType() << ")";
+  printExpr(op, label.str());
+
+  // inputs & outputs
+  addArc(op->in(), op);
+  addArc(op->init(), op, "[color=blue]");
+  addArc(op, op->out());
+}
+
 void IrGraphGenerator::handle(const ForLoop* for_loop) {
   printExpr(for_loop, "ForLoop");
   addArc(for_loop->index(), for_loop);
diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h
index 386521d..c022ec0 100644
--- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h
+++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h
@@ -68,12 +68,17 @@
   void handle(const IterDomain*) override;
   void handle(const TensorIndex*) override;
 
+  void handle(const Bool*) override;
   void handle(const Float*) override;
+  void handle(const Half*) override;
   void handle(const Int*) override;
   void handle(const NamedScalar*) override;
 
   void handle(const UnaryOp*) override;
   void handle(const BinaryOp*) override;
+  void handle(const TernaryOp*) override;
+  void handle(const BroadcastOp*) override;
+  void handle(const ReductionOp*) override;
 
   void handle(const ForLoop*) override;
   void handle(const IfThenElse*) override;
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index 64be541..5451c75 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -21,7 +21,8 @@
  * This value can be a symbolic value (defined after the kernel
  * is compiled) or a constant value (inlined into the kernel definition).
  */
-struct TORCH_CUDA_API Bool : public Val {
+class TORCH_CUDA_API Bool : public Val {
+ public:
   ~Bool() = default;
 
   Bool() : Val(ValType::Scalar, DataType::Bool), maybe_value_{c10::nullopt} {}
@@ -29,6 +30,8 @@
   Bool(bool _value)
       : Val(ValType::Scalar, DataType::Bool), maybe_value_{_value} {}
 
+  Bool(const Bool* src, IrCloner* ir_cloner);
+
   Bool(const Bool& other) = delete;
   Bool& operator=(const Bool& other) = delete;
 
@@ -41,7 +44,7 @@
   bool isConst() const {
     return maybe_value_.has_value();
   }
-  c10::optional<bool> value() const noexcept {
+  c10::optional<bool> value() const {
     return maybe_value_;
   }
 
@@ -56,7 +59,8 @@
  * Float32. This value can be a symbolic value (defined after the kernel
  * is compiled) or a constant value (inlined into the kernel definition).
  */
-struct TORCH_CUDA_API Float : public Val {
+class TORCH_CUDA_API Float : public Val {
+ public:
   using ScalarType = double;
 
   ~Float() = default;
@@ -66,6 +70,8 @@
   Float(ScalarType _value)
       : Val(ValType::Scalar, DataType::Float), maybe_value_{_value} {}
 
+  Float(const Float* src, IrCloner* ir_cloner);
+
   Float(const Float& other) = delete;
   Float& operator=(const Float& other) = delete;
 
@@ -78,7 +84,7 @@
   bool isConst() const {
     return maybe_value_.has_value();
   }
-  c10::optional<ScalarType> value() const noexcept {
+  c10::optional<ScalarType> value() const {
     return maybe_value_;
   }
 
@@ -93,7 +99,8 @@
  * This value can be a symbolic value (defined after the kernel
  * is compiled) or a constant value (inlined into the kernel definition).
  */
-struct TORCH_CUDA_API Half : public Val {
+class TORCH_CUDA_API Half : public Val {
+ public:
   ~Half() = default;
 
   Half() : Val(ValType::Scalar, DataType::Half), maybe_value_{c10::nullopt} {}
@@ -101,6 +108,8 @@
   Half(float _value)
       : Val(ValType::Scalar, DataType::Half), maybe_value_{_value} {}
 
+  Half(const Half* src, IrCloner* ir_cloner);
+
   Half(const Half& other) = delete;
   Half& operator=(const Half& other) = delete;
 
@@ -113,7 +122,7 @@
   bool isConst() const {
     return maybe_value_.has_value();
   }
-  c10::optional<float> value() const noexcept {
+  c10::optional<float> value() const {
     return maybe_value_;
   }
 
@@ -125,7 +134,8 @@
 
 // An Int64 value. If used for indexing it's set as size_t. Otherwise it's an
 // inlined literal in the kernel.
-struct TORCH_CUDA_API Int : public Val {
+class TORCH_CUDA_API Int : public Val {
+ public:
   using ScalarType = int64_t;
 
   ~Int() = default;
@@ -135,6 +145,8 @@
   Int(ScalarType _value)
       : Val(ValType::Scalar, DataType::Int), maybe_value_{_value} {}
 
+  Int(const Int* src, IrCloner* ir_cloner);
+
   Int(const Int& other) = delete;
   Int& operator=(const Int& other) = delete;
 
@@ -147,7 +159,7 @@
   bool isConst() const {
     return maybe_value_.has_value();
   }
-  c10::optional<ScalarType> value() const noexcept {
+  c10::optional<ScalarType> value() const {
     return maybe_value_;
   }
 
@@ -157,11 +169,13 @@
   const c10::optional<ScalarType> maybe_value_;
 };
 
-struct TransformReplay;
-struct TransformIter;
-struct OptOutMutator;
-struct LoopNestGenerator;
-struct GPULower;
+class ComputeAt;
+class TransformReplay;
+class TransformIter;
+class OptOutMutator;
+class LoopNestGenerator;
+class GPULower;
+
 /*
  * TensorView is our primitive Tensor Type used in code generation. It can be
  * thought of as representing physical memory, however, its dimensionality is
@@ -178,7 +192,8 @@
  * we iterate over the 3D TensorDomain [I, J, K], where K is the fastest
  * changing dimension.
  */
-struct TORCH_CUDA_API TensorView : public Val {
+class TORCH_CUDA_API TensorView : public Val {
+ public:
   ~TensorView() = default;
 
   TensorView(const TensorView& other) = delete;
@@ -194,11 +209,15 @@
   TensorView(const std::shared_ptr<Value>& jit_value)
       : TensorView(jit_value->type()->cast<c10::TensorType>()) {}
 
-  TensorDomain* domain() const noexcept {
+  TensorView(const TensorView* src, IrCloner* ir_cloner);
+
+  TensorDomain* domain() const {
     return domain_;
   }
 
   bool hasReduction() const;
+  bool hasBlockReduction() const;
+  bool hasGridReduction() const;
   bool hasBroadcast() const;
 
   // Is there an active computeAt TensorView/Axis
@@ -207,7 +226,7 @@
   }
 
   // Return the TensorView we're computing at
-  TensorView* getComputeAtView() const noexcept {
+  TensorView* getComputeAtView() const {
     return compute_at_view_;
   }
 
@@ -216,18 +235,20 @@
   IterDomain* axis(int pos) const;
 
   // Return compute at axis relative to this domain
-  unsigned int getThisComputeAtAxis() const noexcept {
+  unsigned int getThisComputeAtAxis() const {
     return this_compute_at_axis_;
   }
 
   // Return compute at axis relative to compute at view
-  unsigned int getRelativeComputeAtAxis() const noexcept {
+  unsigned int getRelativeComputeAtAxis() const {
     return relative_compute_at_axis_;
   }
 
   // Will check if an axis is inside computeAtAxis and will fetch the reference
   // to be used in code generation.
   std::pair<IterDomain*, TensorView*> getComputeAtAxis(int pos) {
+    TORCH_INTERNAL_ASSERT(
+        nDims() > 0, "Tried to access a computeAt axis in a 0-dim TensorView");
     if (!hasComputeAt() || getThisComputeAtAxis() <= (unsigned int)pos)
       return std::pair<IterDomain*, TensorView*>(axis(pos), this);
     return compute_at_view_->getComputeAtAxis(getComputeAtRelPos(pos));
@@ -259,26 +280,36 @@
   // Reorder axes according to old2new[old_pos] = new_pos
   TensorView* reorder(const std::unordered_map<int, int>& old2new);
 
-  /*
-   * WARNING: Does not return this TensorView, returns a new tensorview consumed
-   * to create this!! Take reduction axes out of this domain, and create a new
-   * domain. New domain will be used to create this domain. For example: TV1[I0,
-   * I1] = TV0[I0, R0, R1, I1] TV0->rfactor({1}) TV0 is transformed to ->
-   * TV0[I0, R1, I1] The TensorView returned is: TV2[I0, R0, I3, I1] The
-   * reduction will now beset as: TV1[I0, R1, I1] = TV2[I0, R0, I3, I1] TV0[I0,
-   * I1] = TV1[I0, R1, I1]
-   */
+  // WARNING: rFactor does not return this TensorView, it returns a new
+  //  TensorView consumed by this one!
+  //
+  // Take reduction axes out of this domain, and create a new
+  // domain. New domain will be used to create this domain.
+  //
+  // For example:
+  //  TV1[I0, R1, R2, I3] = TV0[I0, I1, I2, I3]
+  //
+  // After:
+  //  TV1->rfactor({1}), TV1 is transformed to -> TV1[I0, R2, I3]
+  //
+  // The TensorView returned is: TV2[I0, R1, I2, I3]
+  //
+  // The reduction will now be set as:
+  //  TV2[I0, R1, I2, I3] = TV0[I0, I1, I2, I3]
+  //  TV1[I0, R2, I3] = TV2[I0, R1, I2, I3]
+  //
   TensorView* rFactor(const std::vector<int>& axes);
 
-  MemoryType getMemoryType() const noexcept {
+  MemoryType getMemoryType() const {
     return memory_type_;
   }
 
   friend TORCH_CUDA_API TransformReplay;
-  // friend TORCH_CUDA_API TransformIter;
   friend TORCH_CUDA_API OptOutMutator;
-  friend TORCH_CUDA_API GPULower;
   friend TORCH_CUDA_API LoopNestGenerator;
+  friend ComputeAt;
+  friend void IrFixComputeAt(Fusion*);
+  friend void IrAdjustMemoryTypes(Fusion* fusion);
 
  protected:
   // Make an exact copy of this tensor (similar to clone()), however, also grabs
@@ -296,6 +327,10 @@
 
   void setComputeAt(TensorView* computeAtView, int axis);
 
+  // Set all computeAt members without checking any correctness. Useful for
+  // computeAt with outputs relative to each other
+  void setComputeAt(TensorView* computeAtView, int thisPos, int relPos);
+
   void setMemoryType(MemoryType mt) {
     memory_type_ = mt;
     bool is_inp_or_out =
@@ -307,12 +342,6 @@
   }
 
  private:
-  // Transform this view like consumer, mark compute_at_(viw,axis)
-  void computeAt_impl(TensorView* consumer, int axis);
-
-  // Transform this view like producer, mark producer as compute_at_(this, axis)
-  void forwardComputeAt_impl(TensorView* producer, int axis);
-
   // Make a copy of the domain (used for Tensor based constructor), likely to be
   // removed soon.
   void copyDomain(const TensorDomain* td);
@@ -321,7 +350,8 @@
   int getComputeAtRelPos(int pos);
   void setThisComputeAtAxis();
 
-  TensorDomain* domain_;
+ private:
+  TensorDomain* domain_ = nullptr;
   TensorView* compute_at_view_ = nullptr;
   // compute at axis in compute at view
   unsigned int relative_compute_at_axis_ = 0;
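
To make the rFactor contract above concrete, here is a hedged scheduling sketch; the `sum()` reduction helper and the axis naming are assumptions based on this stack, not a verbatim transcript of its output:

#include <torch/csrc/jit/codegen/cuda/arith.h>   // sum() -- assumed helper
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>

void rFactorSketch(torch::jit::fuser::TensorView* tv0 /* [I0, I1] */) {
  using namespace torch::jit::fuser;

  TensorView* tv1 = sum(tv0, {1});  // tv1[I0, R1] = tv0[I0, I1]
  tv1->split(1, 128);               // tv1[I0, R1o, R1i{128}]

  // rFactor returns the *new* producer; tv1 itself is rewritten to
  // consume it, exactly as the comment block describes:
  //   tv2[I0, R1o, I1i{128}] = tv0[I0, I1]
  //   tv1[I0, R1i{128}]      = tv2[I0, R1o, I1i{128}]
  TensorView* tv2 = tv1->rFactor({1});
  (void)tv2;
}
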
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index f77f165..9420227 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -5,7 +5,6 @@
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
 
 /*
  * Nodes in here should generally not be used by users. They should be behind
@@ -30,24 +29,27 @@
  *   3) Reduction across a dimension i.e. val.sum(axis=2)
  *   4) split/merge
  */
-struct TORCH_CUDA_API UnaryOp : public Expr {
+class TORCH_CUDA_API UnaryOp : public Expr {
+ public:
   ~UnaryOp() = default;
   UnaryOp(UnaryOpType _type, Val* _out, Val* _in);
 
+  UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);
+
   UnaryOp(const UnaryOp& other) = delete;
   UnaryOp& operator=(const UnaryOp& other) = delete;
 
   UnaryOp(UnaryOp&& other) = delete;
   UnaryOp& operator=(UnaryOp&& other) = delete;
 
-  Val* out() const noexcept {
+  Val* out() const {
     return out_;
   }
-  Val* in() const noexcept {
+  Val* in() const {
     return in_;
   }
 
-  UnaryOpType getUnaryOpType() const noexcept {
+  UnaryOpType getUnaryOpType() const {
     return unary_op_type_;
   }
 
@@ -55,8 +57,8 @@
 
  private:
   const UnaryOpType unary_op_type_;
-  Val* const out_;
-  Val* const in_;
+  Val* const out_ = nullptr;
+  Val* const in_ = nullptr;
 };
 
 /*
@@ -65,27 +67,30 @@
  *  1) Add/mul/div/mod/sub (A * B)
  *  2) LT (A < B)
  */
-struct TORCH_CUDA_API BinaryOp : public Expr {
+class TORCH_CUDA_API BinaryOp : public Expr {
+ public:
   ~BinaryOp() = default;
   BinaryOp(BinaryOpType _type, Val* _out, Val* _lhs, Val* _rhs);
 
+  BinaryOp(const BinaryOp* src, IrCloner* ir_cloner);
+
   BinaryOp(const BinaryOp& other) = delete;
   BinaryOp& operator=(const BinaryOp& other) = delete;
 
   BinaryOp(BinaryOp&& other) = delete;
   BinaryOp& operator=(BinaryOp&& other) = delete;
 
-  Val* out() const noexcept {
+  Val* out() const {
     return out_;
   }
-  Val* lhs() const noexcept {
+  Val* lhs() const {
     return lhs_;
   }
-  Val* rhs() const noexcept {
+  Val* rhs() const {
     return rhs_;
   }
 
-  BinaryOpType getBinaryOpType() const noexcept {
+  BinaryOpType getBinaryOpType() const {
     return binary_op_type_;
   }
 
@@ -93,37 +98,40 @@
 
  private:
   const BinaryOpType binary_op_type_;
-  Val* const out_;
-  Val* const lhs_;
-  Val* const rhs_;
+  Val* const out_ = nullptr;
+  Val* const lhs_ = nullptr;
+  Val* const rhs_ = nullptr;
 };
 
 /*
  * Broadcast _in to match _out. broadcast_dims are relative to out. Where
  * broadcast_dims.size() + _in->nDims() == _out->nDims().
  */
-struct TORCH_CUDA_API BroadcastOp : public Expr {
+class TORCH_CUDA_API BroadcastOp : public Expr {
+ public:
   ~BroadcastOp() = default;
   BroadcastOp(Val* _out, Val* _in);
 
+  BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner);
+
   BroadcastOp(const BroadcastOp& other) = delete;
   BroadcastOp& operator=(const BroadcastOp& other) = delete;
 
   BroadcastOp(BroadcastOp&& other) = delete;
   BroadcastOp& operator=(BroadcastOp&& other) = delete;
 
-  Val* out() const noexcept {
+  Val* out() const {
     return out_;
   }
-  Val* in() const noexcept {
+  Val* in() const {
     return in_;
   }
 
   bool sameAs(const BroadcastOp* const other) const;
 
  private:
-  Val* const out_;
-  Val* const in_;
+  Val* const out_ = nullptr;
+  Val* const in_ = nullptr;
 };
 
 /*
@@ -133,64 +141,75 @@
  * tensor. The output tensor's size will be the size of all
  * non-reduction/non-broadcast dimensions.
  */
-struct TORCH_CUDA_API ReductionOp : public Expr {
+class TORCH_CUDA_API ReductionOp : public Expr {
+ public:
   ~ReductionOp() = default;
   ReductionOp(BinaryOpType _reduction_op_type, Val* _init, Val* _out, Val* _in);
 
+  ReductionOp(const ReductionOp* src, IrCloner* ir_cloner);
+
   ReductionOp(const ReductionOp& other) = delete;
   ReductionOp& operator=(const ReductionOp& other) = delete;
 
   ReductionOp(ReductionOp&& other) = delete;
   ReductionOp& operator=(ReductionOp&& other) = delete;
 
-  Val* out() const noexcept {
+  Val* out() const {
     return out_;
   }
-  Val* in() const noexcept {
+  Val* in() const {
     return in_;
   }
-  Val* init() const noexcept {
+  Val* init() const {
     return init_;
   }
 
-  BinaryOpType getReductionOpType() const noexcept {
+  BinaryOpType getReductionOpType() const {
     return reduction_op_type_;
   }
 
   bool sameAs(const ReductionOp* const other) const;
 
+  std::vector<IterDomain*> getReductionDomains() const;
+
+  std::unordered_map<ParallelType, IterDomain*> getParallelReductionDomains()
+      const;
+
  private:
   const BinaryOpType reduction_op_type_;
-  Val* const init_;
-  Val* const out_;
-  Val* const in_;
+  Val* const init_ = nullptr;
+  Val* const out_ = nullptr;
+  Val* const in_ = nullptr;
 };
 
-struct TORCH_CUDA_API TernaryOp : public Expr {
+class TORCH_CUDA_API TernaryOp : public Expr {
+ public:
   ~TernaryOp() = default;
   TernaryOp(TernaryOpType _type, Val* _out, Val* _in1, Val* _in2, Val* _in3);
 
+  TernaryOp(const TernaryOp* src, IrCloner* ir_cloner);
+
   TernaryOp(const TernaryOp& other) = delete;
   TernaryOp& operator=(const TernaryOp& other) = delete;
 
   TernaryOp(TernaryOp&& other) = delete;
   TernaryOp& operator=(TernaryOp&& other) = delete;
 
-  Val* out() const noexcept {
+  Val* out() const {
     return out_;
   }
 
-  Val* in1() const noexcept {
+  Val* in1() const {
     return in1_;
   }
-  Val* in2() const noexcept {
+  Val* in2() const {
     return in2_;
   }
-  Val* in3() const noexcept {
+  Val* in3() const {
     return in3_;
   }
 
-  TernaryOpType getTernaryOpType() const noexcept {
+  TernaryOpType getTernaryOpType() const {
     return ternary_op_type_;
   }
 
@@ -198,10 +217,10 @@
 
  private:
   const TernaryOpType ternary_op_type_;
-  Val* const out_;
-  Val* const in1_;
-  Val* const in2_;
-  Val* const in3_;
+  Val* const out_ = nullptr;
+  Val* const in1_ = nullptr;
+  Val* const in2_ = nullptr;
+  Val* const in3_ = nullptr;
 };
 
 /*
@@ -210,7 +229,8 @@
  * IterDomains to form an ND iterable. We directly set parallelization strategies
  * on IterDomains.
  */
-struct TORCH_CUDA_API IterDomain : public Val {
+class TORCH_CUDA_API IterDomain : public Val {
+ public:
   ~IterDomain() = default;
 
   IterDomain() = delete;
@@ -223,6 +243,8 @@
       bool _rfactor_domain = false,
       bool _broadcast_domain = false);
 
+  IterDomain(const IterDomain* src, IrCloner* ir_cloner);
+
   bool sameAs(const IterDomain* const other) const;
 
   // Returns a new IterDomain matching properties of this
@@ -241,15 +263,15 @@
       IterDomain* in,
       unsigned int factor);
 
-  bool isReduction() const noexcept {
+  bool isReduction() const {
     return is_reduction_domain_;
   }
 
-  bool isRFactorProduct() const noexcept {
+  bool isRFactorProduct() const {
     return is_rfactor_domain_;
   }
 
-  bool isBroadcast() const noexcept {
+  bool isBroadcast() const {
     return is_broadcast_domain_;
   }
 
@@ -280,10 +302,6 @@
 
   void parallelize(ParallelType t) {
     parallel_method_ = t;
-    if (isBlockDim())
-      TORCH_CHECK(
-          !isReduction(),
-          "Cannot parallelize reductions across a block dimension.");
 
     // Currently a limitation as we allocate shared memory as static (not based
     // off a dynamic size.)
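
Dropping the check above is what lets a reduction IterDomain be bound to a grid ("block") dimension, which the new `hasBlockReduction()`/`hasGridReduction()` queries and the grid-reduction kernel arguments elsewhere in this change build on. A hedged sketch (the `sum()` helper is again an assumption):

#include <torch/csrc/jit/codegen/cuda/arith.h>   // sum() -- assumed helper
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>

void gridReductionSketch(torch::jit::fuser::TensorView* tv0 /* [I0, I1] */) {
  using namespace torch::jit::fuser;

  TensorView* tv1 = sum(tv0, {1});                // tv1[I0, R1]
  tv1->axis(0)->parallelize(ParallelType::BIDy);  // iteration axis on the grid
  tv1->axis(1)->parallelize(ParallelType::BIDx);  // reduction axis on the grid:
                                                  // previously rejected, now a
                                                  // grid reduction
}
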
@@ -307,11 +325,11 @@
           " .");
   }
 
-  ParallelType parallel_method() const noexcept {
+  ParallelType parallel_method() const {
     return parallel_method_;
   }
 
-  Val* start() const noexcept {
+  Val* start() const {
     return start_;
   }
   Val* extent() const;
@@ -326,13 +344,14 @@
   IterDomain& operator=(IterDomain&& other) = delete;
 
  private:
-  Val* const start_;
-  Val* const extent_;
+  Val* const start_ = nullptr;
+  Val* const extent_ = nullptr;
   ParallelType parallel_method_ = ParallelType::Serial;
-  bool is_reduction_domain_;
-  bool is_rfactor_domain_;
-  bool is_broadcast_domain_;
+  bool is_reduction_domain_ = false;
+  bool is_rfactor_domain_ = false;
+  bool is_broadcast_domain_ = false;
 };
+
 /*
  * TensorDomain holds a vector of IterDomains. It holds an IterDomain for every
  * logical axis in its associated tensor. TensorDomain does not directly hold
@@ -346,7 +365,9 @@
  * operations that take in a TensorDomain, apply a transformation and output
  * a tensor domain.
  */
-struct TORCH_CUDA_API TensorDomain : public Val {
+class TORCH_CUDA_API TensorDomain : public Val {
+ public:
+  TensorDomain() = delete;
   ~TensorDomain() = default;
 
   TensorDomain(const TensorDomain& other) = delete;
@@ -355,7 +376,7 @@
   TensorDomain(TensorDomain&& other) = delete;
   TensorDomain& operator=(TensorDomain&& other) = delete;
 
-  TensorDomain(std::vector<IterDomain*> _domain);
+  explicit TensorDomain(std::vector<IterDomain*> _domain);
 
   TensorDomain(
       std::vector<IterDomain*> _root_domain,
@@ -366,6 +387,8 @@
       std::vector<IterDomain*> _rfactor_domain,
       std::vector<IterDomain*> _domain);
 
+  TensorDomain(const TensorDomain* src, IrCloner* ir_cloner);
+
   std::vector<IterDomain*>::size_type nDims() const {
     return domain_.size();
   }
@@ -376,27 +399,29 @@
       const std::vector<IterDomain*>& lhs,
       const std::vector<IterDomain*>& rhs);
 
-  const std::vector<IterDomain*>& domain() const noexcept {
+  const std::vector<IterDomain*>& domain() const {
     return domain_;
   }
 
   bool hasReduction() const;
+  bool hasBlockReduction() const;
+  bool hasGridReduction() const;
   bool hasBroadcast() const;
   bool hasRFactor() const;
 
-  const std::vector<IterDomain*>& noReductions() const noexcept {
+  const std::vector<IterDomain*>& noReductions() const {
     return no_reduction_domain_;
   }
 
-  const std::vector<IterDomain*>& noBroadcasts() const noexcept {
+  const std::vector<IterDomain*>& noBroadcasts() const {
     return no_bcast_domain_;
   }
 
-  const std::vector<IterDomain*>& rootDomain() const noexcept {
+  const std::vector<IterDomain*>& rootDomain() const {
     return root_domain_;
   };
 
-  const std::vector<IterDomain*>& rfactorDomain() const noexcept {
+  const std::vector<IterDomain*>& rfactorDomain() const {
     return rfactor_domain_;
   };
 
@@ -447,7 +472,8 @@
  * Represents a split on an IterDomain by "factor"
  * TODO: Implement split by nparts
  */
-struct TORCH_CUDA_API Split : public Expr {
+class TORCH_CUDA_API Split : public Expr {
+ public:
   ~Split() = default;
 
   Split(const Split& other) = delete;
@@ -458,25 +484,27 @@
 
   Split(IterDomain* _outer, IterDomain* _inner, IterDomain* _in, Int* _factor);
 
-  IterDomain* outer() const noexcept {
+  Split(const Split* src, IrCloner* ir_cloner);
+
+  IterDomain* outer() const {
     return outer_;
   }
-  IterDomain* inner() const noexcept {
+  IterDomain* inner() const {
     return inner_;
   }
-  IterDomain* in() const noexcept {
+  IterDomain* in() const {
     return in_;
   }
-  Int* factor() const noexcept {
+  Int* factor() const {
     return factor_;
   }
   bool sameAs(const Split* const other) const;
 
  private:
-  IterDomain* const outer_;
-  IterDomain* const inner_;
-  IterDomain* const in_;
-  Int* const factor_;
+  IterDomain* const outer_ = nullptr;
+  IterDomain* const inner_ = nullptr;
+  IterDomain* const in_ = nullptr;
+  Int* const factor_ = nullptr;
 };
 
 /*
@@ -486,32 +514,35 @@
  * if there is one.
  * TODO: Should this be a unary op type?
  */
-struct TORCH_CUDA_API Merge : public Expr {
+class TORCH_CUDA_API Merge : public Expr {
+ public:
   ~Merge() = default;
   Merge(IterDomain* _out, IterDomain* _outer, IterDomain* _inner);
 
+  Merge(const Merge* src, IrCloner* ir_cloner);
+
   Merge(const Merge& other) = delete;
   Merge& operator=(const Merge& other) = delete;
 
   Merge(Merge&& other) = delete;
   Merge& operator=(Merge&& other) = delete;
 
-  IterDomain* out() const noexcept {
+  IterDomain* out() const {
     return out_;
   }
-  IterDomain* outer() const noexcept {
+  IterDomain* outer() const {
     return outer_;
   }
-  IterDomain* inner() const noexcept {
+  IterDomain* inner() const {
     return inner_;
   }
 
   bool sameAs(const Merge* const other) const;
 
  private:
-  IterDomain* const out_;
-  IterDomain* const outer_;
-  IterDomain* const inner_;
+  IterDomain* const out_ = nullptr;
+  IterDomain* const outer_ = nullptr;
+  IterDomain* const inner_ = nullptr;
 };
 
 /*
@@ -523,7 +554,8 @@
  * TODO: Change implementation of Exprs contained in the scope to be more similar
  * to Fusion where we can do proper dependency analysis.
  */
-struct TORCH_CUDA_API ForLoop : public Expr {
+class TORCH_CUDA_API ForLoop : public Expr {
+ public:
   ~ForLoop() = default;
   ForLoop(
       Val* _index,
@@ -531,38 +563,40 @@
       const std::vector<Expr*>& _body = {},
       Expr* parent_scope = nullptr);
 
+  ForLoop(const ForLoop* src, IrCloner* ir_cloner);
+
   ForLoop(const ForLoop& other) = delete;
   ForLoop& operator=(const ForLoop& other) = delete;
 
   ForLoop(ForLoop&& other) = delete;
   ForLoop& operator=(ForLoop&& other) = delete;
 
-  Val* index() const noexcept {
+  Val* index() const {
     return index_;
   }
 
-  IterDomain* iter_domain() const noexcept {
+  IterDomain* iter_domain() const {
     return iter_domain_;
   }
 
-  Scope& body() noexcept {
+  Scope& body() {
     return body_;
   }
 
-  const Scope& constBody() const noexcept {
+  const Scope& constBody() const {
     return body_;
   }
 
   bool sameAs(const ForLoop* other) const;
-  Expr* parentScope() const noexcept {
+  Expr* parentScope() const {
     return parent_scope_;
   }
 
  private:
-  Val* const index_;
+  Val* const index_ = nullptr;
   IterDomain* const iter_domain_;
   Scope body_;
-  Expr* parent_scope_;
+  Expr* parent_scope_ = nullptr;
 };
 
 /*
@@ -574,7 +608,8 @@
  * TODO: Change implementation of Exprs contained in the scope to be more similar
  * to Fusion where we can do proper dependency analysis.
  */
-struct TORCH_CUDA_API IfThenElse : public Expr {
+class TORCH_CUDA_API IfThenElse : public Expr {
+ public:
   ~IfThenElse() = default;
   IfThenElse(
       Bool* _cond,
@@ -582,47 +617,49 @@
       const std::vector<Expr*>& _else_body = {},
       Expr* _parent_scope = nullptr);
 
+  IfThenElse(const IfThenElse* src, IrCloner* ir_cloner);
+
   IfThenElse(const IfThenElse& other) = delete;
   IfThenElse& operator=(const IfThenElse& other) = delete;
 
   IfThenElse(IfThenElse&& other) = delete;
   IfThenElse& operator=(IfThenElse&& other) = delete;
 
-  Bool* cond() const noexcept {
+  Bool* cond() const {
     return cond_;
   }
 
-  const Scope& constBody() const noexcept {
+  const Scope& constBody() const {
     return body_;
   }
 
-  const Scope& constElseBody() const noexcept {
+  const Scope& constElseBody() const {
     return else_body_;
   }
 
-  Scope& body() noexcept {
+  Scope& body() {
     return body_;
   }
 
-  Scope& elseBody() noexcept {
+  Scope& elseBody() {
     return else_body_;
   }
 
-  bool hasElse() const noexcept {
+  bool hasElse() const {
     return !else_body_.empty();
   }
 
   bool sameAs(const IfThenElse* other) const;
 
-  Expr* parentScope() const noexcept {
+  Expr* parentScope() const {
     return parent_scope_;
   }
 
  private:
-  Bool* const cond_;
+  Bool* const cond_ = nullptr;
   Scope body_;
   Scope else_body_;
-  Expr* parent_scope_;
+  Expr* parent_scope_ = nullptr;
 };
 
 /*
@@ -630,7 +667,8 @@
  * TensorView. It is not the flattened index, which needs to be computed using
  * stride information.
  */
-struct TORCH_CUDA_API TensorIndex : public Val {
+class TORCH_CUDA_API TensorIndex : public Val {
+ public:
   ~TensorIndex() = default;
 
   TensorIndex(const TensorIndex& other) = delete;
@@ -655,6 +693,8 @@
         "Cannot index with a value other than an int.");
   }
 
+  TensorIndex(const TensorIndex* src, IrCloner* ir_cloner);
+
   std::vector<Val*>::size_type nDims() const {
     return indices_.size();
   }
@@ -663,18 +703,18 @@
   // uint.
   Val* index(int i) const;
 
-  const std::vector<Val*>& indices() const noexcept {
+  const std::vector<Val*>& indices() const {
     return indices_;
   }
 
-  const TensorView* view() const noexcept {
+  const TensorView* view() const {
     return view_;
   }
 
   bool sameAs(const TensorIndex* const other) const;
 
  private:
-  const TensorView* view_;
+  const TensorView* view_ = nullptr;
   std::vector<Val*> indices_;
 };
 
@@ -687,7 +727,8 @@
  * TODO: The components of Allocate like Type and Name could be separated from
  * the associated TensorView. Perhaps that is more appropriate?
  */
-struct TORCH_CUDA_API Allocate : public Expr {
+class TORCH_CUDA_API Allocate : public Expr {
+ public:
   ~Allocate() = default;
 
   Allocate(const Allocate& other) = delete;
@@ -698,19 +739,21 @@
 
   Allocate(Val* _tv, Val* size);
 
+  Allocate(const Allocate* src, IrCloner* ir_cloner);
+
   DataType buf_type() const;
-  Val* extent() const noexcept {
+  Val* extent() const {
     return extent_;
   }
-  Val* buffer() const noexcept {
+  Val* buffer() const {
     return buffer_;
   }
 
   bool sameAs(const Allocate* other) const;
 
  private:
-  Val* buffer_;
-  Val* extent_;
+  Val* buffer_ = nullptr;
+  Val* extent_ = nullptr;
 };
 
 /*
@@ -720,20 +763,23 @@
  * - blockDim.z
  * - T3.stride[2]
  */
-struct TORCH_CUDA_API NamedScalar : public Val {
+class TORCH_CUDA_API NamedScalar : public Val {
+ public:
   ~NamedScalar() = default;
   NamedScalar() = delete;
 
   NamedScalar(std::string _name, DataType dtype)
       : Val(ValType::NamedScalar, dtype), name_(_name) {}
 
+  NamedScalar(const NamedScalar* src, IrCloner* ir_cloner);
+
   NamedScalar(const NamedScalar& other) = delete;
   NamedScalar& operator=(const NamedScalar& other) = delete;
 
   NamedScalar(NamedScalar&& other) = delete;
   NamedScalar& operator=(NamedScalar&& other) = delete;
 
-  const std::string& name() const noexcept {
+  const std::string& name() const {
     return name_;
   }
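
As a concrete instance of the comment above, a NamedScalar simply wraps a name that the generated kernel can reference verbatim; the printer change later in this diff emits exactly that name. A minimal sketch:

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>

void namedScalarSketch() {
  using namespace torch::jit::fuser;
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Printed by IRPrinter::handle(const NamedScalar*) simply as "blockDim.x".
  Val* bdx = new NamedScalar("blockDim.x", DataType::Int);
  (void)bdx;
}
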
 
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
index 6a9146f..7923218 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -8,30 +8,49 @@
 namespace jit {
 namespace fuser {
 
-namespace {
 // Make sure we can inline something, before we attempt to.
-void check_inlineable(const IRInputOutput* const irio) {
-  for (auto inp : irio->inputs())
+static void checkInlineable(const Expr* expr) {
+  for (auto input : expr->inputs()) {
     TORCH_CHECK(
-        inp->isScalar(),
+        input->isScalar(),
         "Printing inline computations involving values other than scalars is not currently supported.");
+  }
   TORCH_CHECK(
-      irio->nOutputs() == 1,
+      expr->outputs().size() == 1,
       "Cannot print inline computations if there's more than one output.");
   TORCH_CHECK(
-      irio->output(0)->isScalar(),
+      expr->output(0)->isScalar(),
       "Printing inline computations involving values other than scalars is not currently supported.");
 }
-} // namespace
+
+void IRPrinter::handle(const Statement* s) {
+  OptInConstDispatch::handle(s);
+}
+
+void IRPrinter::handle(const Val* v) {
+  if (follow_val_map) {
+    // Follow a single mapping (permutation chains are not expected)
+    v = FusionGuard::getCurFusion()->loweredVal(v);
+    TORCH_INTERNAL_ASSERT(v == FusionGuard::getCurFusion()->loweredVal(v));
+  }
+  OptInConstDispatch::handle(v);
+}
+
+void IRPrinter::handle(const Expr* e) {
+  OptInConstDispatch::handle(e);
+}
 
 void IRPrinter::printHeader(Fusion* fusion, const std::string& kernel_name_) {
   os << "__global__ void " << kernel_name_ << "(";
 
-  std::deque<Val*> vals;
-  for (decltype(fusion->nInputs()) i{0}; i < fusion->nInputs(); i++)
-    vals.push_back(fusion->input(i));
-  for (decltype(fusion->nOutputs()) i{0}; i < fusion->nOutputs(); i++)
-    vals.push_back(fusion->output(i));
+  std::vector<Val*> vals;
+
+  for (auto val : fusion->inputs()) {
+    vals.push_back(val);
+  }
+  for (auto val : fusion->outputs()) {
+    vals.push_back(val);
+  }
 
   for (Val* val : vals) {
     switch (val->getValType().value()) {
@@ -57,6 +76,11 @@
 
   if (fusion->hasRNG())
     os << ", unsigned long long seed, unsigned long long offset";
+
+  if (fusion->hasGridReduction()) {
+    os << ", void* work_buf, unsigned* sync_flags";
+  }
+
   os << "){\n";
   indent_size++;
   if (fusion->hasRNG()) {
@@ -65,6 +89,12 @@
     indent();
     os << "Philox rnd(seed, idx, offset);\n";
   }
+  if (fusion->hasBlockReduction() || fusion->hasGridReduction()) {
+    indent();
+    // TODO: Dynamic sizing possible? blockReduce originally used 1024
+    // values of a given type
+    os << "__shared__ float shared_mem[1024];\n";
+  }
 }
 
 void IRPrinter::handle(Fusion* fusion) {
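
Putting printHeader() together: for a fusion with a grid reduction, the emitted signature now carries a reduction work buffer and sync flags, and a fixed-size shared-memory buffer is declared up front. An illustrative shape of that output; the tensor/scalar parameter spellings are placeholders rather than what the printer literally writes:

__global__ void kernel1(float* T0, float* T1,     // placeholder tensor params
                        unsigned long long seed,  // only if the fusion uses RNG
                        unsigned long long offset,
                        void* work_buf, unsigned* sync_flags) {
  __shared__ float shared_mem[1024];
  // ... body emitted by the remaining IRPrinter handlers ...
}
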
@@ -74,9 +104,13 @@
   }
 }
 
-void IRPrinter::handle(const TensorDomain* const td) {
+void IRPrinter::handle(const TensorDomain* td) {
+  if (td->nDims() == 0) {
+    os << "[ 0 ]";
+    return;
+  }
   os << "[ ";
-  for (std::vector<const IterDomain*>::size_type i = 0; i < td->nDims(); i++) {
+  for (size_t i = 0; i < td->nDims(); i++) {
     handle(td->axis(i));
     if (i != td->nDims() - 1)
       os << ", ";
@@ -84,7 +118,7 @@
   os << " ]";
 }
 
-void IRPrinter::handle(const TensorView* const tv) {
+void IRPrinter::handle(const TensorView* tv) {
   os << "T" << tv->name();
   handle(tv->domain());
 
@@ -95,7 +129,7 @@
   }
 }
 
-void IRPrinter::handle(const IterDomain* const id) {
+void IRPrinter::handle(const IterDomain* id) {
   if (id->isReduction())
     os << "r";
   else if (id->isBroadcast())
@@ -127,9 +161,14 @@
     os << "rf";
 }
 
-void IRPrinter::handle(const TensorIndex* const ti) {
-  os << "T" << ti->view()->name() << "[ ";
+void IRPrinter::handle(const TensorIndex* ti) {
+  os << "T" << ti->view()->name();
+  if (ti->nDims() == 0) {
+    os << "[ 0 ]";
+    return;
+  }
 
+  os << "[ ";
   bool first = true;
   for (auto* ind : ti->indices()) {
     if (!first)
@@ -140,7 +179,7 @@
   os << " ]";
 }
 
-void IRPrinter::handle(const Bool* const b) {
+void IRPrinter::handle(const Bool* b) {
   if (print_inline_ && FusionGuard::getCurFusion()->origin(b) != nullptr) {
     os << "( ";
     handle(FusionGuard::getCurFusion()->origin(b));
@@ -155,7 +194,7 @@
   }
 }
 
-void IRPrinter::handle(const Float* const f) {
+void IRPrinter::handle(const Float* f) {
   if (print_inline_ && FusionGuard::getCurFusion()->origin(f) != nullptr) {
     os << "( ";
     handle(FusionGuard::getCurFusion()->origin(f));
@@ -173,7 +212,7 @@
   }
 }
 
-void IRPrinter::handle(const Half* const h) {
+void IRPrinter::handle(const Half* h) {
   if (print_inline_ && FusionGuard::getCurFusion()->origin(h) != nullptr) {
     os << "( ";
     handle(FusionGuard::getCurFusion()->origin(h));
@@ -188,12 +227,19 @@
   }
 }
 
-void IRPrinter::handle(const Int* const i) {
-  if (print_inline_ && FusionGuard::getCurFusion()->origin(i) != nullptr) {
-    os << "( ";
-    handle(FusionGuard::getCurFusion()->origin(i));
-    os << " )";
-    return;
+void IRPrinter::handle(const Int* i) {
+  // Make sure we didn't bypass the value mapping
+  // (for example calling IRPrinter::handle() with an Int*)
+  TORCH_CHECK(
+      !follow_val_map || i == FusionGuard::getCurFusion()->loweredVal(i));
+
+  if (print_inline_) {
+    if (auto def = FusionGuard::getCurFusion()->origin(i)) {
+      os << "( ";
+      handle(def);
+      os << " )";
+      return;
+    }
   }
 
   if (i->isSymbolic()) {
@@ -203,27 +249,21 @@
   }
 }
 
-void IRPrinter::handle(const NamedScalar* const i) {
+void IRPrinter::handle(const NamedScalar* i) {
   os << i->name();
 }
 
-namespace {
-
-bool isTV(const Val* const val) {
-  return (
-      val->getValType().value() == ValType::TensorView ||
-      val->getValType().value() == ValType::TensorIndex);
+static bool isTV(const Val* val) {
+  return val->getValType().value() == ValType::TensorView ||
+      val->getValType().value() == ValType::TensorIndex;
 }
 
 // Check if we're a TensorView op that we can generate code for.
-bool isTVOp(const Expr* const expr) {
-  if (expr->nOutputs() == 1 && isTV(expr->output(0)))
-    return true;
-  return false;
+static bool isTVOp(const Expr* expr) {
+  return expr->outputs().size() == 1 && isTV(expr->outputs().front());
 }
-} // namespace
 
-void IRPrinter::handle(const UnaryOp* const uop) {
+void IRPrinter::handle(const UnaryOp* uop) {
   bool istvop = isTVOp(uop);
   if (!print_inline_) {
     indent();
@@ -235,7 +275,7 @@
     }
     os << " = ";
   } else {
-    check_inlineable(uop);
+    checkInlineable(uop);
   }
 
   if (auto inline_uop = inline_op_str(uop->getUnaryOpType())) {
@@ -269,7 +309,7 @@
     os << ";\n";
 }
 
-void IRPrinter::handle(const BinaryOp* const bop) {
+void IRPrinter::handle(const BinaryOp* bop) {
   bool istvop = isTVOp(bop);
   if (!print_inline_) {
     indent();
@@ -284,7 +324,7 @@
 
     os << " = ";
   } else {
-    check_inlineable(bop);
+    checkInlineable(bop);
   }
 
   if (auto inline_bop = inline_op_str(bop->getBinaryOpType())) {
@@ -314,7 +354,7 @@
     os << ";\n";
 }
 
-void IRPrinter::handle(const TernaryOp* const top) {
+void IRPrinter::handle(const TernaryOp* top) {
   bool istvop = isTVOp(top);
   if (!print_inline_) {
     indent();
@@ -329,7 +369,7 @@
 
     os << " = ";
   } else {
-    check_inlineable(top);
+    checkInlineable(top);
   }
 
   os << top->getTernaryOpType() << "(";
@@ -355,7 +395,7 @@
     os << ";\n";
 }
 
-void IRPrinter::handle(const ReductionOp* const rop) {
+void IRPrinter::handle(const ReductionOp* rop) {
   // Check if we've lowered yet.
 
   bool lowered = rop->out()->getValType() == ValType::TensorIndex;
@@ -367,52 +407,77 @@
     return;
   }
 
-  TensorIndex* out = static_cast<TensorIndex*>(rop->out());
+  auto out = rop->out()->as<TensorIndex>();
   auto vec_domain = out->view()->domain()->domain();
 
-  IterDomain *tidx = nullptr, *tidy = nullptr, *tidz = nullptr;
-  bool is_thread_reduce = false;
-  for (auto id : vec_domain) {
-    if (id->isThreadDim() && id->isReduction()) {
-      switch (id->parallel_method()) {
-        case (ParallelType::TIDz):
-          tidz = id;
-          break;
-        case (ParallelType::TIDy):
-          tidy = id;
-          break;
-        case (ParallelType::TIDx):
-          tidx = id;
-          break;
-        default:
-          TORCH_INTERNAL_ASSERT(
-              false, "Did not recognize parallel type for reduction.");
-      }
-      is_thread_reduce = true;
-    }
-  }
+  bool has_block_reduce = out->view()->hasBlockReduction();
+  bool has_grid_reduce = out->view()->hasGridReduction();
 
-  if (!is_thread_reduce) {
+  if (!has_block_reduce && !has_grid_reduce) {
     handle(new BinaryOp(rop->getReductionOpType(), out, out, rop->in()));
     return;
   }
+
+  auto par_domains = rop->getParallelReductionDomains();
+  bool tidx = par_domains.find(ParallelType::TIDx) != par_domains.end();
+  bool tidy = par_domains.find(ParallelType::TIDy) != par_domains.end();
+  bool tidz = par_domains.find(ParallelType::TIDz) != par_domains.end();
+  bool bidx = par_domains.find(ParallelType::BIDx) != par_domains.end();
+  bool bidy = par_domains.find(ParallelType::BIDy) != par_domains.end();
+  bool bidz = par_domains.find(ParallelType::BIDz) != par_domains.end();
+
   auto d_type = rop->out()->getDataType().value();
   auto op_type = rop->getReductionOpType();
-  indent();
-  // Thread all reduce.
-  os << "blockReduce< " << (tidx != nullptr ? "true" : "false") << ", "
-     << (tidy != nullptr ? "true" : "false") << ", "
-     << (tidz != nullptr ? "true" : "false") << " >"
-     << " ( ";
-  handle(rop->out());
-  os << ", ";
-  handle(rop->in());
-  os << ", ";
-  os << "reduction_" << op_type << "_" << d_type;
-  os << ");\n";
+  const std::string block_result = "block_result";
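+  // When a grid reduction follows the block reduction, the block result is
+  // staged in a local temporary and then fed into gridReduce.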
+  if (has_block_reduce) {
+    if (has_grid_reduce) {
+      indent();
+      os << d_type << " " << block_result << ";\n";
+    }
+    indent();
+    // Reduce across all participating threads in the block.
+    os << "blockReduce< " << (tidx ? "true" : "false") << ", "
+       << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false") << " >"
+       << " ( ";
+    if (has_grid_reduce) {
+      os << block_result;
+    } else {
+      handle(rop->out());
+    }
+    os << ", ";
+    handle(rop->in());
+    os << ", ";
+    os << "reduction_" << op_type << "_" << d_type;
+    os << ", threadIdx, blockDim";
+    os << ", reinterpret_cast<" << d_type << "*>(shared_mem)";
+    os << ");\n";
+  }
+  if (has_grid_reduce) {
+    indent();
+    // Since block-level reduction is already done, those dimensions
+    // with tidx/y/z being true do not participate in the grid reduction.
+    os << "reduction::gridReduce< " << (bidx ? "true" : "false") << ", "
+       << (bidy ? "true" : "false") << ", " << (bidz ? "true" : "false") << ", "
+       << (!tidx ? "true" : "false") << ", " << (!tidy ? "true" : "false")
+       << ", " << (!tidz ? "true" : "false") << " >"
+       << " ( ";
+    handle(rop->out());
+    os << ", ";
+    if (has_block_reduce) {
+      os << block_result;
+    } else {
+      handle(rop->in());
+    }
+    os << ", ";
+    os << "reduction_" << op_type << "_" << d_type;
+    os << ", static_cast<" << d_type << "*>(work_buf)";
+    os << ", sync_flags";
+    os << ", reinterpret_cast<" << d_type << "*>(shared_mem)";
+    os << ");\n";
+  }
 }
 
-void IRPrinter::handle(const BroadcastOp* const bop) {
+void IRPrinter::handle(const BroadcastOp* bop) {
   indent();
   handle(bop->out());
   os << "\n";
@@ -424,8 +489,8 @@
   os << ";\n";
 }
 
-void IRPrinter::handle(const ForLoop* const fl) {
-  if (fl->iter_domain()->isThread()) {
+void IRPrinter::handle(const ForLoop* fl) {
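+  // Loops over thread-parallel or broadcast domains are not emitted as
+  // explicit for-loops; their bodies are printed directly.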
+  if (fl->iter_domain()->isThread() || fl->iter_domain()->isBroadcast()) {
     for (auto& expr : fl->constBody().exprs())
       handle(expr);
     return;
@@ -452,7 +517,7 @@
   os << "}\n";
 }
 
-void IRPrinter::handle(const IfThenElse* const ite) {
+void IRPrinter::handle(const IfThenElse* ite) {
   indent();
 
   // IF
@@ -480,7 +545,7 @@
   os << "}\n";
 }
 
-void IRPrinter::handle(const Allocate* const a) {
+void IRPrinter::handle(const Allocate* a) {
   indent();
   os << a->buf_type();
   if (a->buffer()->getValType() == ValType::TensorView) {
@@ -501,7 +566,7 @@
   }
 }
 
-void IRPrinter::handle(const Split* const s) {
+void IRPrinter::handle(const Split* s) {
   os << "Split: ";
   handle(s->in());
   os << " by factor " << s->factor() << " -> ";
@@ -511,20 +576,20 @@
   os << "\n";
 }
 
-void IRPrinter::handle(const Merge* const m) {
+void IRPrinter::handle(const Merge* m) {
   os << "Merge: ";
   handle(m->outer());
   os << " and ";
   handle(m->inner());
   os << " -> ";
   handle(m->out());
-  --indent_size;
   os << "\n";
 }
 
 namespace {
 
-struct ReductionOps : OptOutDispatch {
+class ReductionOps : OptOutDispatch {
+ public:
   std::set<std::pair<BinaryOpType, DataType>> rops;
   void handle(ReductionOp* rop) override {
     rops.emplace(std::pair<BinaryOpType, DataType>{
@@ -541,6 +606,7 @@
     return ROPs.rops;
   }
 };
+
 } // namespace
 
 void IRPrinter::printReductionOps(Fusion* fusion) {
@@ -566,7 +632,6 @@
     const std::vector<Expr*>& exprs,
     const std::string& kernel_name) {
   Fusion* fusion = FusionGuard::getCurFusion();
-
   printReductionOps(fusion);
   printHeader(fusion, kernel_name);
   for (auto* expr : exprs) {
@@ -575,7 +640,7 @@
   os << "}\n";
 }
 
-std::ostream& operator<<(std::ostream& os, const Statement* const stmt) {
+std::ostream& operator<<(std::ostream& os, const Statement* stmt) {
   IRPrinter p(os);
   p.handle(stmt);
   return os;
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h
index 5034dbf..1857648 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.h
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h
@@ -10,37 +10,35 @@
 namespace jit {
 namespace fuser {
 
-struct Fusion;
+class Fusion;
 
-struct Statement;
+class Statement;
 
-struct Val;
-struct Expr;
+class Val;
+class Expr;
 
-struct UnaryOp;
-struct BinaryOp;
-struct TernaryOp;
-struct ReductionOp;
-struct BroadcastOp;
+class UnaryOp;
+class BinaryOp;
+class TernaryOp;
+class ReductionOp;
+class BroadcastOp;
 
-struct ForLoop;
-struct IfThenElse;
+class ForLoop;
+class IfThenElse;
 
-struct TensorDomain;
-struct TensorView;
-struct IterDomain;
-struct TensorIndex;
+class TensorDomain;
+class TensorView;
+class IterDomain;
+class TensorIndex;
 
-struct TensorContiguity;
+class Split;
+class Merge;
 
-struct Split;
-struct Merge;
-
-struct Bool;
-struct Float;
-struct Half;
-struct Int;
-struct Add;
+class Bool;
+class Float;
+class Half;
+class Int;
+class Add;
 
 /*
  * Define pretty printing functions for all nodes. handle is used so we can take
@@ -50,7 +48,7 @@
  * stream operator <<.
  */
 
-struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch {
+class TORCH_CUDA_API IRPrinter : public OptInConstDispatch {
  public:
   std::ostream& os;
   bool print_inline_ = false;
@@ -58,6 +56,9 @@
   // Track the indentation size for pretty printing
   int indent_size = 0;
 
+  // Handle value mapping
+  bool follow_val_map = true;
+
   // Indent the generated code
   void indent() {
     for (int i = 0; i < indent_size; i++)
@@ -72,54 +73,48 @@
 
   IRPrinter(std::ostream& _os) : os(_os) {}
 
-  virtual void handle(Fusion* const f);
+  virtual void handle(Fusion* f);
 
   // handle calls some non-const fusion ops,
   // even though fusion should remain unchanged.
   // Need to look into this.
-  virtual void handle(const Fusion* const f) {
+  virtual void handle(const Fusion* f) {
     handle(const_cast<Fusion*>(f));
   }
+
   virtual void handle(Fusion& f) {
     handle(&f);
   }
 
-  virtual void handle(const Statement* const s) {
-    OptInConstDispatch::handle(s);
-  };
+  void handle(const Statement* s) override;
+  void handle(const Val* v) override;
+  void handle(const Expr* e) override;
 
-  virtual void handle(const Val* const v) {
-    OptInConstDispatch::handle(v);
-  };
-  virtual void handle(const Expr* const e) {
-    OptInConstDispatch::handle(e);
-  };
+  void handle(const TensorDomain*) override;
+  void handle(const TensorView*) override;
+  void handle(const IterDomain*) override;
+  void handle(const TensorIndex*) override;
 
-  virtual void handle(const TensorDomain* const) override;
-  virtual void handle(const TensorView* const) override;
-  virtual void handle(const IterDomain* const) override;
-  virtual void handle(const TensorIndex* const) override;
+  void handle(const Bool*) override;
+  void handle(const Float*) override;
+  void handle(const Half*) override;
+  void handle(const Int*) override;
+  void handle(const NamedScalar*) override;
 
-  virtual void handle(const Bool* const) override;
-  virtual void handle(const Float* const) override;
-  virtual void handle(const Half* const) override;
-  virtual void handle(const Int* const) override;
-  virtual void handle(const NamedScalar* const) override;
+  void handle(const UnaryOp*) override;
+  void handle(const BinaryOp*) override;
+  void handle(const TernaryOp*) override;
+  void handle(const ReductionOp*) override;
+  void handle(const BroadcastOp*) override;
 
-  virtual void handle(const UnaryOp* const) override;
-  virtual void handle(const BinaryOp* const) override;
-  virtual void handle(const TernaryOp* const) override;
-  virtual void handle(const ReductionOp* const) override;
-  virtual void handle(const BroadcastOp* const) override;
+  void handle(const ForLoop*) override;
+  void handle(const IfThenElse*) override;
+  void handle(const Allocate*) override;
 
-  virtual void handle(const ForLoop* const) override;
-  virtual void handle(const IfThenElse* const) override;
-  virtual void handle(const Allocate* const) override;
+  void handle(const Split*) override;
+  void handle(const Merge*) override;
 
-  virtual void handle(const Split* const) override;
-  virtual void handle(const Merge* const) override;
-
-  void print_inline(const Statement* const stmt) {
+  void print_inline(const Statement* stmt) {
     bool prev = print_inline_;
     print_inline_ = true;
     handle(stmt);
@@ -135,7 +130,7 @@
 
 TORCH_CUDA_API std::ostream& operator<<(
     std::ostream& os,
-    const Statement* const stmt);
+    const Statement* stmt);
 TORCH_CUDA_API std::ostream& operator<<(std::ostream& os, Fusion* f);
 TORCH_CUDA_API std::ostream& operator<<(std::ostream& os, Fusion& f);
 
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 03ff9c4..1ad4f30 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
 #include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
 #include <torch/csrc/jit/codegen/cuda/transform_iter.h>
@@ -13,36 +14,8 @@
 namespace fuser {
 
 namespace {
-struct ScalarCheck : OptInDispatch {
-  Val* v1_;
-  Val* v2_;
-  bool same = false;
 
-  void handle(Bool* b) override {
-    same = static_cast<Bool*>(v1_)->sameAs(static_cast<Bool*>(v2_));
-  }
-
-  void handle(Float* f) override {
-    same = static_cast<Float*>(v1_)->sameAs(static_cast<Float*>(v2_));
-  }
-
-  void handle(Half* h) override {
-    same = static_cast<Half*>(v1_)->sameAs(static_cast<Half*>(v2_));
-  }
-
-  void handle(Int* i) override {
-    same = static_cast<Int*>(v1_)->sameAs(static_cast<Int*>(v2_));
-  }
-
-  void handle(NamedScalar* ns) override {
-    same =
-        static_cast<NamedScalar*>(v1_)->sameAs(static_cast<NamedScalar*>(v2_));
-  }
-
-  ScalarCheck(Val* _v1, Val* _v2) : v1_(_v1), v2_(_v2) {
-    OptInDispatch::handle(v1_);
-  }
-
+class ScalarCheck : OptInDispatch {
  public:
   static bool sameAs(Val* v1, Val* v2) {
     if (v1 == v2)
@@ -55,29 +28,73 @@
       return false;
 
     ScalarCheck sc(v1, v2);
-    return sc.same;
+    return sc.same_;
   }
+
+ private:
+  void handle(Bool* b) override {
+    same_ = static_cast<Bool*>(v1_)->sameAs(static_cast<Bool*>(v2_));
+  }
+
+  void handle(Float* f) override {
+    same_ = static_cast<Float*>(v1_)->sameAs(static_cast<Float*>(v2_));
+  }
+
+  void handle(Half* h) override {
+    same_ = static_cast<Half*>(v1_)->sameAs(static_cast<Half*>(v2_));
+  }
+
+  void handle(Int* i) override {
+    same_ = static_cast<Int*>(v1_)->sameAs(static_cast<Int*>(v2_));
+  }
+
+  void handle(NamedScalar* ns) override {
+    same_ =
+        static_cast<NamedScalar*>(v1_)->sameAs(static_cast<NamedScalar*>(v2_));
+  }
+
+  ScalarCheck(Val* _v1, Val* _v2) : v1_(_v1), v2_(_v2) {
+    OptInDispatch::handle(v1_);
+  }
+
+ private:
+  Val* v1_ = nullptr;
+  Val* v2_ = nullptr;
+  bool same_ = false;
 };
+
 } // namespace
 
+Bool::Bool(const Bool* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
 bool Bool::sameAs(const Bool* const other) const {
   if (isConst() && other->isConst())
     return *value() == *(other->value());
   return this == other;
 }
 
+Float::Float(const Float* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
 bool Float::sameAs(const Float* const other) const {
   if (isConst() && other->isConst())
     return *value() == *(other->value());
   return this == other;
 }
 
+Half::Half(const Half* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
 bool Half::sameAs(const Half* const other) const {
   if (isConst() && other->isConst())
     return *value() == *(other->value());
   return this == other;
 }
 
+Int::Int(const Int* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
+
 bool Int::sameAs(const Int* const other) const {
   if (isConst() && other->isConst())
     return *value() == *(other->value());
@@ -91,6 +108,12 @@
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
+UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      unary_op_type_(src->unary_op_type_),
+      out_(ir_cloner->clone(src->out_)),
+      in_(ir_cloner->clone(src->in_)) {}
+
 bool UnaryOp::sameAs(const UnaryOp* const other) const {
   if (this->type() != other->type())
     return false;
@@ -109,6 +132,13 @@
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
+BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      binary_op_type_(src->binary_op_type_),
+      out_(ir_cloner->clone(src->out_)),
+      lhs_(ir_cloner->clone(src->lhs_)),
+      rhs_(ir_cloner->clone(src->rhs_)) {}
+
 bool BinaryOp::sameAs(const BinaryOp* other) const {
   if (getBinaryOpType() != other->getBinaryOpType())
     return false;
@@ -136,6 +166,14 @@
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
+TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      ternary_op_type_(src->ternary_op_type_),
+      out_(ir_cloner->clone(src->out_)),
+      in1_(ir_cloner->clone(src->in1_)),
+      in2_(ir_cloner->clone(src->in2_)),
+      in3_(ir_cloner->clone(src->in3_)) {}
+
 bool TernaryOp::sameAs(const TernaryOp* other) const {
   if (getTernaryOpType() != other->getTernaryOpType())
     return false;
@@ -160,7 +198,7 @@
         ndims++;
 
     TORCH_INTERNAL_ASSERT(
-        ndims == (int)static_cast<TensorView*>(in_)->nDims(),
+        ndims == (int)in_->as<TensorView>()->domain()->noReductions().size(),
         "Invalid broadcast op. Non-broadcasted dims don't match from input to output.");
   } else {
     TORCH_INTERNAL_ASSERT(
@@ -177,6 +215,11 @@
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
+BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      out_(ir_cloner->clone(src->out_)),
+      in_(ir_cloner->clone(src->in_)) {}
+
 bool BroadcastOp::sameAs(const BroadcastOp* const other) const {
   return other->in() == in() && other->out() == out();
 }
@@ -199,6 +242,13 @@
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
+ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      reduction_op_type_(src->reduction_op_type_),
+      init_(ir_cloner->clone(src->init_)),
+      out_(ir_cloner->clone(src->out_)),
+      in_(ir_cloner->clone(src->in_)) {}
+
 bool ReductionOp::sameAs(const ReductionOp* other) const {
   return (
       this->in()->sameAs(other->in()) &&
@@ -206,6 +256,37 @@
       this->init()->sameAs(other->init()));
 }
 
+std::vector<IterDomain*> ReductionOp::getReductionDomains() const {
+  const Val* out_val = out();
+  TORCH_INTERNAL_ASSERT(
+      out_val->getValType() == ValType::TensorView ||
+          out_val->getValType() == ValType::TensorIndex,
+      "Output of reduction must be TensorView or TensorIndex");
+  // out is a TensorIndex after lowering
+  if (out_val->getValType() == ValType::TensorIndex) {
+    out_val = static_cast<const TensorIndex*>(out_val)->view();
+  }
+  auto vec_domain = out_val->as<TensorView>()->domain()->domain();
+  vec_domain.erase(
+      std::remove_if(
+          vec_domain.begin(),
+          vec_domain.end(),
+          [](IterDomain* id) { return !id->isReduction(); }),
+      vec_domain.end());
+  return vec_domain;
+}
+
+std::unordered_map<ParallelType, IterDomain*> ReductionOp::
+    getParallelReductionDomains() const {
+  std::unordered_map<ParallelType, IterDomain*> parallel_domains;
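+  // Keep only the reduction domains bound to a parallel dimension
+  // (TIDx/y/z or BIDx/y/z).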
+  for (auto d : getReductionDomains()) {
+    if (d->isThread()) {
+      parallel_domains.insert(std::make_pair(d->parallel_method(), d));
+    }
+  }
+  return parallel_domains;
+}
+
 IterDomain::IterDomain(
     Val* _start,
     Val* _extent,
@@ -240,6 +321,15 @@
   this->name_ = fusion_->registerVal(this);
 }
 
+IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner),
+      start_(ir_cloner->clone(src->start_)),
+      extent_(ir_cloner->clone(src->extent_)),
+      parallel_method_(src->parallel_method_),
+      is_reduction_domain_(src->is_reduction_domain_),
+      is_rfactor_domain_(src->is_rfactor_domain_),
+      is_broadcast_domain_(src->is_broadcast_domain_) {}
+
 bool IterDomain::sameAs(const IterDomain* const other) const {
   bool is_same = isReduction() == other->isReduction() &&
       parallel_method() == other->parallel_method();
@@ -393,6 +483,14 @@
   this->name_ = fusion_->registerVal(this);
 }
 
+TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner),
+      root_domain_(ir_cloner->clone(src->root_domain_)),
+      domain_(ir_cloner->clone(src->domain_)),
+      no_bcast_domain_(ir_cloner->clone(src->no_bcast_domain_)),
+      no_reduction_domain_(ir_cloner->clone(src->no_reduction_domain_)),
+      rfactor_domain_(ir_cloner->clone(src->rfactor_domain_)) {}
+
 bool TensorDomain::sameAs(const TensorDomain* const other) const {
   if (nDims() != other->nDims())
     return false;
@@ -433,6 +531,18 @@
   return no_reduction_domain_.size() != domain_.size();
 }
 
+bool TensorDomain::hasBlockReduction() const {
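+  // A reduction axis bound to a thread dimension (TIDx/y/z) makes this a
+  // block reduction; one bound to a block dimension (BIDx/y/z) makes it a
+  // grid reduction (see hasGridReduction below).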
+  return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
+    return id->isReduction() && id->isThreadDim();
+  });
+}
+
+bool TensorDomain::hasGridReduction() const {
+  return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
+    return id->isReduction() && id->isBlockDim();
+  });
+}
+
 bool TensorDomain::hasBroadcast() const {
   return no_bcast_domain_.size() != domain_.size();
 }
@@ -444,6 +554,8 @@
 // i here is int, as we want to accept negative value and ::size_type can be a
 // uint.
 IterDomain* TensorDomain::axis(int i) const {
+  TORCH_INTERNAL_ASSERT(
+      nDims() > 0, "Tried to access an axis in a 0-dim domain");
   if (i < 0)
     i += nDims();
   TORCH_CHECK(
@@ -456,6 +568,7 @@
 }
 
 size_t TensorDomain::posOf(IterDomain* id) const {
+  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to find an axis in a 0-dim domain");
   size_t i = 0;
   while (i < domain_.size()) {
     if (domain_[i] == id)
@@ -468,6 +581,7 @@
 // Split "axis" into 2 axes where the inner axes is size of "factor"
 // and outer axis is size axis.extent() / factor
 void TensorDomain::split(int axis_, unsigned int factor) {
+  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim domain");
   if (axis_ < 0)
     axis_ += nDims();
 
@@ -485,6 +599,7 @@
 
 // Merge "axis" and "axis+1" into 1 dimension
 void TensorDomain::merge(int axis_o, int axis_i) {
+  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain");
   if (axis_o < 0)
     axis_o += nDims();
 
@@ -519,6 +634,9 @@
 
 // Reorder axes according to map[old_pos] = new_pos
 void TensorDomain::reorder(const std::unordered_map<int, int>& old2new_) {
+  TORCH_INTERNAL_ASSERT(
+      !(nDims() == 0 && old2new_.size() > 0),
+      "Tried to reorder a 0-dim domain");
   domain_ = orderedAs(domain_, old2new_);
   resetDomains();
 }
@@ -526,6 +644,10 @@
 std::vector<IterDomain*> TensorDomain::orderedAs(
     const std::vector<IterDomain*>& dom,
     const std::unordered_map<int, int>& old2new_) {
+  TORCH_INTERNAL_ASSERT(
+      !(dom.size() == 0 && old2new_.size() > 0),
+      "Tried to reorder a 0-dim domain");
+
   // Even though these checks are already in TensorView, we want to redo them as
   // we can enter this function from other places, not through TensorView
 
@@ -678,6 +800,8 @@
 // pair is in order where second is the consumer of first
 std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
     const std::vector<int>& axes_) {
+  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim domain");
+
   std::vector<int> axes(axes_.size());
 
   auto ndims = nDims();
@@ -736,6 +860,13 @@
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
+Split::Split(const Split* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      outer_(ir_cloner->clone(src->outer_)),
+      inner_(ir_cloner->clone(src->inner_)),
+      in_(ir_cloner->clone(src->in_)),
+      factor_(ir_cloner->clone(src->factor_)) {}
+
 bool Split::sameAs(const Split* const other) const {
   return (
       outer()->sameAs(other->outer()) && inner()->sameAs(other->inner()) &&
@@ -750,6 +881,12 @@
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
+Merge::Merge(const Merge* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      out_(ir_cloner->clone(src->out_)),
+      outer_(ir_cloner->clone(src->outer_)),
+      inner_(ir_cloner->clone(src->inner_)) {}
+
 bool Merge::sameAs(const Merge* const other) const {
   return (
       out()->sameAs(other->out()) && outer()->sameAs(other->outer()) &&
@@ -775,6 +912,13 @@
     body().push_back(expr);
 }
 
+ForLoop::ForLoop(const ForLoop* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      index_(ir_cloner->clone(src->index_)),
+      iter_domain_(ir_cloner->clone(src->iter_domain_)),
+      body_(&src->body_, ir_cloner),
+      parent_scope_(ir_cloner->clone(src->parent_scope_)) {}
+
 bool ForLoop::sameAs(const ForLoop* other) const {
   if (this->iter_domain() != other->iter_domain())
     return false;
@@ -798,6 +942,13 @@
     else_body_.push_back(expr);
 }
 
+IfThenElse::IfThenElse(const IfThenElse* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      cond_(src->cond_),
+      body_(&src->body_, ir_cloner),
+      else_body_(&src->else_body_, ir_cloner),
+      parent_scope_(ir_cloner->clone(src->parent_scope_)) {}
+
 bool IfThenElse::sameAs(const IfThenElse* other) const {
   if (!(this->cond()->sameAs(other->cond()) &&
         this->constBody().sameAs(other->constBody()) &&
@@ -806,6 +957,11 @@
   return true;
 }
 
+TensorIndex::TensorIndex(const TensorIndex* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner),
+      view_(ir_cloner->clone(src->view_)),
+      indices_(ir_cloner->clone(src->indices_)) {}
+
 bool TensorIndex::sameAs(const TensorIndex* const other) const {
   if (nDims() != other->nDims())
     return false;
@@ -821,6 +977,8 @@
 }
 
 Val* TensorIndex::index(int i) const {
+  TORCH_INTERNAL_ASSERT(
+      nDims() > 0, "Tried to get an index of a 0-dim TensorIndex");
   if (i < 0)
     i += nDims();
   assert(i >= 0 && i < nDims());
@@ -846,6 +1004,11 @@
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }
 
+Allocate::Allocate(const Allocate* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      buffer_(ir_cloner->clone(src->buffer_)),
+      extent_(ir_cloner->clone(src->extent_)) {}
+
 DataType Allocate::buf_type() const {
   return buffer_->getDataType().value();
 }
@@ -861,6 +1024,9 @@
   return true;
 }
 
+NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner), name_(src->name_) {}
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
index b86501d..e7df6d0 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
@@ -9,11 +9,13 @@
 
 /* ITER VISITOR */
 
-std::vector<Statement*> IterVisitor::next(Statement* statement) {
+std::vector<Statement*> IterVisitor::next(
+    Statement* statement,
+    bool respect_compute_at) {
   if (statement->isVal())
     return next(static_cast<Val*>(statement));
   else if (statement->isExpr())
-    return next(static_cast<Expr*>(statement));
+    return next(static_cast<Expr*>(statement), respect_compute_at);
   else
     TORCH_INTERNAL_ASSERT(
         false, "IterVisitor could not detect type in next_dispatch.");
@@ -26,13 +28,44 @@
   return {};
 }
 
-std::vector<Statement*> IterVisitor::next(Expr* expr) {
+std::vector<Statement*> IterVisitor::next(Expr* expr, bool respect_compute_at) {
   FusionGuard::getCurFusion()->assertInFusion(expr, "Cannot traverse expr, ");
-  return {expr->inputs().begin(), expr->inputs().end()};
+  std::vector<Statement*> next_stmts{expr->inputs().begin(),
+                                     expr->inputs().end()};
+  if (respect_compute_at) {
+    TORCH_INTERNAL_ASSERT(
+        expr->outputs().size() == 1,
+        "Expressions with multiple outputs are not supported");
+    if (expr->output(0)->getValType().value() == ValType::TensorView) {
+      auto out = expr->output(0)->as<const TensorView>();
+      // Move input TVs that are computed at this expression backward
+      // so that they are visited later. If multiple inputs are computed
+      // at this expression, move the TVs that are computed at an inner
+      // loop nest further backward.
+      std::stable_sort(
+          next_stmts.begin(),
+          next_stmts.end(),
+          [out](const Statement* stmt0, const Statement* stmt1) {
+            std::array<const Statement*, 2> inputs{stmt0, stmt1};
+            std::array<int, 2> compute_at_axes{-1, -1};
+            for (int i = 0; i < 2; ++i) {
+              if (inputs[i]->getValType().value() == ValType::TensorView) {
+                auto tv = inputs[i]->as<TensorView>();
+                if (tv->getComputeAtView() == out) {
+                  compute_at_axes[i] = tv->getRelativeComputeAtAxis();
+                }
+              }
+            }
+            return compute_at_axes[0] < compute_at_axes[1];
+          });
+    }
+  }
+  return next_stmts;
 }
 
-// Remove any stmt in stmts that is in visited
 namespace {
+
+// Remove any stmt in stmts that is in visited
 void remove_visited(
     std::vector<Statement*>& stmts,
     const std::unordered_set<Statement*>& visited) {
@@ -47,50 +80,52 @@
     to_erase.pop_back();
   }
 }
+
 } // namespace
 
 void IterVisitor::traverseFrom(
     Fusion* const fusion,
     const std::vector<Val*>& from,
-    bool traverseAllPaths) {
+    bool traverseAllPaths,
+    bool respectComputeAt) {
   FusionGuard fg(fusion);
   std::unordered_set<Statement*> visited;
   stmt_stack.clear();
   stmt_stack.emplace_back(from.rbegin(), from.rend());
+  // True when returning to a node after visiting all of its input
+  // nodes. Nodes are only visited when this is true.
+  bool all_inputs_visited = false;
 
   while (!stmt_stack.empty()) {
-    auto next_stmts = next(stmt_stack.back().back());
-
-    // Remove statements we already visited if we're not traversing all paths
-    if (!traverseAllPaths)
-      remove_visited(next_stmts, visited);
-
-    // Traverse down until we get to a leaf
-    while (!next_stmts.empty()) {
-      stmt_stack.emplace_back(next_stmts.rbegin(), next_stmts.rend());
-      next_stmts = next(stmt_stack.back().back());
-      // Remove statements we already visited if we're not traversing all paths
-      if (!traverseAllPaths)
-        remove_visited(next_stmts, visited);
-    }
-
-    // Traverse back up
-    // Mark visited
-    visited.emplace(stmt_stack.back().back());
-    // Handle
-    handle(stmt_stack.back().back());
-    // Remove
-    stmt_stack.back().pop_back();
-
-    while (!stmt_stack.empty() && stmt_stack.back().empty()) {
+    auto& current_inputs = stmt_stack.back();
+    // When current_inputs is empty, all the input nodes have been
+    // visited. Return to the output node by popping the stack and
+    // record that all of its inputs have been visited.
+    if (current_inputs.empty()) {
       stmt_stack.pop_back();
-      if (!stmt_stack.empty()) {
-        // Mark visited
-        visited.emplace(stmt_stack.back().back());
-        // Handle
-        handle(stmt_stack.back().back());
-        // Remove
-        stmt_stack.back().pop_back();
+      all_inputs_visited = true;
+      continue;
+    }
+    const auto& stmt = current_inputs.back();
+    // Visit stmt when all_inputs_visited is true.
+    if (all_inputs_visited) {
+      // Mark visited
+      visited.insert(stmt);
+      // Handle
+      handle(stmt);
+      current_inputs.pop_back();
+      all_inputs_visited = false;
+    } else {
+      // Visit input nodes.
+      auto next_stmts = next(stmt, respectComputeAt);
+      if (!traverseAllPaths) {
+        remove_visited(next_stmts, visited);
+      }
+      if (next_stmts.empty()) {
+        all_inputs_visited = true;
+      } else {
+        stmt_stack.emplace_back(next_stmts.rbegin(), next_stmts.rend());
+        all_inputs_visited = false;
       }
     }
   }
@@ -100,7 +135,8 @@
     Fusion* const fusion,
     bool from_outputs_only,
     bool breadth_first,
-    bool traverse_all_paths) {
+    bool traverse_all_paths,
+    bool respect_compute_at) {
   FusionGuard fg(fusion);
   if (breadth_first)
     TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
@@ -109,7 +145,8 @@
     auto term_outs = IterVisitor::getTerminatingOutputs(fusion);
     std::vector<Val*> term_val_outs(term_outs.begin(), term_outs.end());
     if (!term_val_outs.empty())
-      traverseFrom(fusion, term_val_outs, traverse_all_paths);
+      traverseFrom(
+          fusion, term_val_outs, traverse_all_paths, respect_compute_at);
     return;
   }
 
@@ -120,28 +157,31 @@
       leaves.push_back(val);
 
   if (!leaves.empty())
-    traverseFrom(fusion, leaves, traverse_all_paths);
+    traverseFrom(fusion, leaves, traverse_all_paths, respect_compute_at);
 }
 
 void IterVisitor::traverse(
     Fusion* const fusion,
     bool from_outputs_only,
-    bool breadth_first) {
-  traverse_(fusion, from_outputs_only, breadth_first, false);
+    bool breadth_first,
+    bool respect_compute_at) {
+  traverse_(
+      fusion, from_outputs_only, breadth_first, false, respect_compute_at);
 }
 
 void IterVisitor::traverseAllPaths(
     Fusion* const fusion,
     bool from_outputs_only,
-    bool breadth_first) {
-  traverse_(fusion, from_outputs_only, breadth_first, true);
+    bool breadth_first,
+    bool respect_compute_at) {
+  traverse_(fusion, from_outputs_only, breadth_first, true, respect_compute_at);
 }
 
 namespace {
 
 // Expr sort will take a fusion and return a topologically sorted list of
 // expressions.
-struct Exprs : public IterVisitor {
+class Exprs : public IterVisitor {
  private:
   std::vector<Expr*> exprs;
 
@@ -161,7 +201,7 @@
 
 // Expr sort will take a fusion and return a topologically sorted list of
 // expressions.
-struct Inputs : public IterVisitor {
+class Inputs : public IterVisitor {
  private:
   std::unordered_set<Val*> inputs;
 
@@ -179,6 +219,7 @@
     return inps.inputs;
   }
 };
+
 } // namespace
 
 std::unordered_set<Val*> IterVisitor::getTerminatingOutputs(
@@ -186,10 +227,12 @@
   FusionGuard fg(fusion);
 
   std::unordered_set<Val*> used_vals;
-  for (auto expr : Exprs::getExprs(
-           fusion,
-           std::vector<Val*>(
-               fusion->outputs().begin(), fusion->outputs().end()))) {
+
+  const auto exprs = Exprs::getExprs(
+      fusion,
+      std::vector<Val*>(fusion->outputs().begin(), fusion->outputs().end()));
+
+  for (auto expr : exprs) {
     for (auto inp : expr->inputs())
       used_vals.emplace(inp);
   }
@@ -209,7 +252,7 @@
 
 namespace {
 
-struct AllVals : public IterVisitor {
+class AllVals : public IterVisitor {
  private:
   std::unordered_set<Val*> vals;
 
@@ -342,8 +385,45 @@
 /* DEPENDENCY CHECKING */
 
 namespace {
+
+// Looks for and returns all values between the given dependencies and the
+// "of" values, including both endpoints.
+struct Dependencies : public IterVisitor {
+  std::unordered_set<Val*> dependencies_;
+  std::unordered_set<Val*> vals;
+
+  std::vector<Statement*> next(Val* v) override {
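+    // Stop the backward traversal once one of the given dependencies is reached.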
+    if (dependencies_.find(v) != dependencies_.end())
+      return std::vector<Statement*>();
+    return IterVisitor::next(v);
+  }
+
+  void handle(Val* val) override {
+    vals.emplace(val);
+  }
+
+  Dependencies(
+      std::unordered_set<Val*> _dependencies,
+      const std::vector<Val*>& of)
+      : dependencies_(std::move(_dependencies)) {
+    traverseFrom(of[0]->fusion(), of, false);
+  };
+
+ public:
+  static std::unordered_set<Val*> getAllVals(
+      const std::unordered_set<Val*>& dependencies,
+      const std::vector<Val*>& of) {
+    if (of.empty())
+      return std::unordered_set<Val*>();
+
+    Dependencies deps(dependencies, of);
+    return deps.vals;
+  }
+};
+
 // Looks for and returns
-struct DependencyChains : public IterVisitor {
+class DependencyChains : public IterVisitor {
+ public:
   std::deque<std::deque<Val*>> dep_chains;
   bool is_dependency = false;
   std::unordered_set<Val*> dependencies_;
@@ -394,6 +474,7 @@
     return dp.dep_chains[0];
   }
 
+  // I don't think this is actually hooked up, but leaving for now.
   static std::deque<std::deque<Val*>> getDependencyChains(
       Val* dependency,
       Val* of) {
@@ -403,14 +484,14 @@
     return dp.dep_chains;
   }
 
-  static std::deque<std::deque<Val*>> getDependencyChainsTo(Val* dependency) {
+  static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency) {
     DependencyChains dp(dependency, true);
     if (dp.dep_chains.empty())
       return std::deque<std::deque<Val*>>();
     return dp.dep_chains;
   }
 
-  static std::deque<std::deque<Val*>> getDependencyChainsTo(
+  static std::deque<std::deque<Val*>> getAllUseChains(
       const std::unordered_set<Val*>& dependencies) {
     DependencyChains dp(dependencies, true);
     if (dp.dep_chains.empty())
@@ -437,9 +518,14 @@
   return DependencyChains::getDependencyChains(dependency, of);
 }
 
-std::deque<std::deque<Val*>> DependencyCheck::getAllDependencyChainsTo(
-    Val* dependency) {
-  return DependencyChains::getDependencyChainsTo(dependency);
+std::deque<std::deque<Val*>> DependencyCheck::getAllUseChains(Val* producer) {
+  return DependencyChains::getAllUseChains(producer);
+}
+
+std::unordered_set<Val*> DependencyCheck::getAllValsBetween(
+    const std::unordered_set<Val*>& dependencies,
+    const std::vector<Val*>& of) {
+  return Dependencies::getAllVals(dependencies, of);
 }
 
 } // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h
index 1dcb6bd..117c5f1 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.h
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h
@@ -12,11 +12,11 @@
 namespace jit {
 namespace fuser {
 
-struct Statement;
-struct Val;
-struct Expr;
+class Statement;
+class Val;
+class Expr;
 
-struct Fusion;
+class Fusion;
 
 enum class ValType;
 
@@ -32,7 +32,8 @@
  * TODO: We may want to have ordering of outputs to inputs. I'm not sure why we
  * would want this, but seems like it would be a reasonable request.
  */
-struct TORCH_CUDA_API IterVisitor : public OptOutDispatch {
+class TORCH_CUDA_API IterVisitor : public OptOutDispatch {
+ public:
   virtual ~IterVisitor() = default;
 
   IterVisitor() = default;
@@ -48,23 +49,25 @@
   // These functions will start at outputs and propagate up through the DAG
   // to inputs based on depth first traversal. Next could be called on a node
   // multiple times.
-  virtual std::vector<Statement*> next(Statement* stmt);
-  virtual std::vector<Statement*> next(Expr* expr);
+  virtual std::vector<Statement*> next(
+      Statement* stmt,
+      bool respect_compute_at);
+  virtual std::vector<Statement*> next(Expr* expr, bool respect_compute_at);
   virtual std::vector<Statement*> next(Val* v);
 
   // This handle function is called on every Statement* in topological order,
   // starting from outputs to inputs.
-  virtual void handle(Statement* s) override {
+  void handle(Statement* s) override {
     OptOutDispatch::handle(s);
   }
   // This handle function is called on every Expr* in topological order,
   // starting from outputs to inputs.
-  virtual void handle(Expr* e) override {
+  void handle(Expr* e) override {
     OptOutDispatch::handle(e);
   }
   // This handle function is called on every Val* in topological order,
   // starting from outputs to inputs.
-  virtual void handle(Val* v) override {
+  void handle(Val* v) override {
     OptOutDispatch::handle(v);
   }
 
@@ -79,7 +82,8 @@
       Fusion* const fusion,
       bool from_outputs_only = false,
       bool breadth_first = false,
-      bool traverse_all_paths = false);
+      bool traverse_all_paths = false,
+      bool respect_compute_at = false);
 
  public:
   // Starts at nodes provided in from, traverses from these nodes to inputs.
@@ -90,23 +94,28 @@
   void traverseFrom(
       Fusion* const fusion,
       const std::vector<Val*>& from,
-      bool traverseAllPaths = false);
+      bool traverseAllPaths = false,
+      bool respectComputeAt = false);
 
   // from_outputs_only = true start from outputs registered with fusion,
   // from_outputs_only = false start from all leaf nodes,
   // bool breadth_first = true is not implemented yet
+  // respect_compute_at = true traverse computeAt input exprs later
   void traverse(
       Fusion* const fusion,
       bool from_outputs_only = false,
-      bool breadth_first = false);
+      bool breadth_first = false,
+      bool respect_compute_at = false);
 
   // from_outputs_only = true start from outputs registered with fusion,
   // from_outputs_only = false start from all leaf nodes,
   // bool breadth_first = true is not implemented yet
+  // respect_compute_at = true traverse computeAt input exprs later
   void traverseAllPaths(
       Fusion* const fusion,
       bool from_outputs_only = false,
-      bool breadth_first = false);
+      bool breadth_first = false,
+      bool respect_compute_at = false);
 
   static std::unordered_set<Val*> getTerminatingOutputs(Fusion* const);
 
@@ -128,7 +137,8 @@
  * outputs to guarantee that we will traverse all outputs of all exprs during
  * the backward traversal.
  */
-struct TORCH_CUDA_API BackwardVisitor : public OptOutDispatch {
+class TORCH_CUDA_API BackwardVisitor : public OptOutDispatch {
+ public:
   virtual ~BackwardVisitor() = default;
 
   BackwardVisitor() = default;
@@ -187,7 +197,7 @@
       bool traverseAllPaths = false);
 };
 
-struct TORCH_CUDA_API DependencyCheck {
+class TORCH_CUDA_API DependencyCheck {
  public:
   // Returns if "dependency" is a dependency of "of".
   static bool isDependencyOf(Val* dependency, Val* of);
@@ -206,7 +216,12 @@
   // Finds all Val* paths from all leaf nodes to "dependency". Returns those
   // paths. deque[i].back() are leaf nodes, and deque[i][0] is "dependency".
   // Returns an empty deque if there are no uses of dependency found.
-  static std::deque<std::deque<Val*>> getAllDependencyChainsTo(Val* dependency);
+  static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency);
+
+  // Grab all values that exist between and including provided vals
+  static std::unordered_set<Val*> getAllValsBetween(
+      const std::unordered_set<Val*>& dependencies,
+      const std::vector<Val*>& of);
 };
 
 } // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp
index 567cefa..6d7a091 100644
--- a/torch/csrc/jit/codegen/cuda/kernel.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -9,6 +9,7 @@
 #include <torch/csrc/jit/codegen/cuda/kernel_arg.h>
 #include <torch/csrc/jit/codegen/cuda/kernel_resource_strings.h>
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
+#include <torch/csrc/jit/codegen/cuda/parser.h>
 
 #include <torch/csrc/jit/resource_guard.h>
 #include <fstream>
@@ -19,10 +20,11 @@
 namespace fuser {
 namespace cuda {
 
-constexpr auto CG_NAMESPACE = "CudaCodeGen";
-constexpr auto KERNEL_NAME = "kernel";
+constexpr auto kCgNamespace = "CudaCodeGen";
+constexpr auto kKernelName = "kernel";
 
 namespace {
+
 // See NOTE [ USE OF NVRTC AND DRIVER API ]
 static const at::cuda::NVRTC& nvrtc() {
   return at::globalContext().getNVRTC();
@@ -147,24 +149,25 @@
 
 std::pair<std::string, std::string> codeGeneration(Fusion* fusion) {
   std::stringstream str_stream;
-  str_stream << "namespace " << CG_NAMESPACE << " {\n"
+  str_stream << "namespace " << kCgNamespace << " {\n"
              << code_template_tensor_struct << "\n"
              << code_fp16_support << "\n"
              << code_random_number_gen << "\n"
              << code_helper_funcs << "\n"
-             << code_template_block_reduction << "\n";
+             << code_template_block_reduction << "\n"
+             << code_template_grid_reduction << "\n";
   std::stringstream cdg;
   GPULower gpulw(fusion);
-  gpulw.printKernel(str_stream, KERNEL_NAME);
+  gpulw.printKernel(str_stream, kKernelName);
   str_stream << "\n} // namespace";
 
-  std::string func_name = std::string(CG_NAMESPACE) + "::" + KERNEL_NAME;
+  std::string func_name = std::string(kCgNamespace) + "::" + kKernelName;
   return std::make_pair(func_name, str_stream.str());
-};
+}
 
 bool validateKernelArgTensor(
     const at::Tensor& arg,
-    const Val* const param,
+    const Val* param,
     int device_index,
     std::stringstream& msg) {
   // Arg is a tensor. Param must be a tensor too.
@@ -219,7 +222,7 @@
 
 bool validateKernelArgScalar(
     const c10::TypePtr& arg_type,
-    const Val* const param,
+    const Val* param,
     std::stringstream& msg) {
   if (!param->isScalar()) {
     msg << "Argument is a scalar, but the parameter is not.";
@@ -249,7 +252,7 @@
 
 bool validateKernelArg(
     const c10::IValue& arg,
-    const Val* const param,
+    const Val* param,
     int device_index,
     std::stringstream& msg) {
   if (arg.type()->kind() != c10::TypeKind::TensorType) {
@@ -271,7 +274,7 @@
       "Wrong number of kernel inputs.");
   for (size_t i = 0; i < inputs.size(); ++i) {
     const IValue& arg = inputs[i];
-    const Val* const param = entry.fusion_->inputs()[i];
+    const Val* param = entry.fusion_->inputs()[i];
     std::stringstream msg;
     TORCH_INTERNAL_ASSERT(
         validateKernelArg(arg, param, entry.device_, msg),
@@ -290,7 +293,7 @@
       "Wrong number of kernel outputs.");
   for (size_t i = 0; i < outputs.size(); ++i) {
     const at::Tensor& arg = outputs[i];
-    const Val* const param = entry.fusion_->outputs()[i];
+    const Val* param = entry.fusion_->outputs()[i];
     std::stringstream msg;
     TORCH_INTERNAL_ASSERT(
         validateKernelArgTensor(arg, param, entry.device_, msg),
@@ -300,6 +303,73 @@
         msg.str());
   }
 }
+
+size_t size(const dim3& d) {
+  return (size_t)d.x * (size_t)d.y * (size_t)d.z;
+}
+
+dim3 dimensionOfReductionBlock(
+    const dim3& block_dim,
+    bool x_thread,
+    bool y_thread,
+    bool z_thread) {
+  return dim3{x_thread ? block_dim.x : 1,
+              y_thread ? block_dim.y : 1,
+              z_thread ? block_dim.z : 1};
+}
+
+int sizeOfReductionBlock(
+    const dim3& block_dim,
+    bool x_thread,
+    bool y_thread,
+    bool z_thread) {
+  return size(
+      dimensionOfReductionBlock(block_dim, x_thread, y_thread, z_thread));
+}
+
+// Returns the total number of reduction segments.
+size_t numberOfReductionSegments(
+    const dim3& grid_dim,
+    bool x_block,
+    bool y_block,
+    bool z_block) {
+  return (x_block ? 1 : grid_dim.x) * (y_block ? 1 : grid_dim.y) *
+      (z_block ? 1 : grid_dim.z);
+}
+
+std::array<size_t, 2> gridReductionTempBufferSizes(CudaKernel* entry) {
+  size_t buffer_size = 0;
+  size_t sync_flag_size = 0;
+  for (auto expr : entry->fusion_->exprs(true)) {
+    if (expr->getExprType() != ExprType::ReductionOp)
+      continue;
+    ReductionOp* rop = static_cast<ReductionOp*>(expr);
+    auto domains = rop->getParallelReductionDomains();
+    bool x_block = domains.find(ParallelType::BIDx) != domains.end();
+    bool y_block = domains.find(ParallelType::BIDy) != domains.end();
+    bool z_block = domains.find(ParallelType::BIDz) != domains.end();
+    // No buffer needed unless it's a grid reduction
+    if (!x_block && !y_block && !z_block)
+      continue;
+    // The assumption here is that the block-level reduction over the
+    // thread-parallel domains is done prior to this grid reduction, so
+    // those domains do not need to participate in the grid reduction.
+    bool x_thread = domains.find(ParallelType::TIDx) == domains.end();
+    bool y_thread = domains.find(ParallelType::TIDy) == domains.end();
+    bool z_thread = domains.find(ParallelType::TIDz) == domains.end();
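+    // Work buffer: one element per grid-participating thread in every block;
+    // sync flags: one counter per reduction segment.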
+    auto rb_size =
+        sizeOfReductionBlock(entry->block_, x_thread, y_thread, z_thread);
+    auto num_blocks = size(entry->grid_);
+    auto element_size = dataTypeSize(*(rop->out()->getDataType()));
+    auto required_temp_buffer_size = num_blocks * rb_size * element_size;
+    buffer_size = std::max(buffer_size, required_temp_buffer_size);
+    auto flag_size = sizeof(unsigned) *
+        numberOfReductionSegments(entry->grid_, x_block, y_block, z_block);
+    sync_flag_size = std::max(sync_flag_size, flag_size);
+  }
+  return {{buffer_size, sync_flag_size}};
+}
+
 } // namespace
 
 bool NaivePWKernelArgsReq::matchKernelSize(const at::ArrayRef<IValue> inputs) {
@@ -327,6 +397,16 @@
   std::tie(func_name, code) = codeGeneration(entry->fusion_.get());
 
   static int32_t compiled_kernel_id = 0;
+  // We increment the id here instead of at the end of the function to avoid
+  // errors during jit compilation that would make the debug message confusing.
+  compiled_kernel_id++;
+  const char* debug_env = getenv("PYTORCH_CUDA_FUSER_DEBUG");
+  if (debug_env && atoi(debug_env)) {
+    std::cout << "\n==== codegen output for kernel: " << compiled_kernel_id
+              << " ====" << std::endl
+              << code << std::endl
+              << "====================================" << std::endl;
+  }
 
   // vvv NVRTC COMPILATION vvv
 
@@ -452,7 +532,8 @@
 void runKernel(
     CudaKernel* entry,
     const at::ArrayRef<IValue> inputs,
-    std::vector<at::Tensor> outputs) {
+    const std::vector<at::Tensor>& outputs,
+    const std::vector<int64_t>& broadcasted_shape) {
   validateKernelArgs(*entry, inputs, outputs);
 
   const auto prior_device = at::cuda::current_device();
@@ -461,10 +542,32 @@
 
   // TODO: Proper API to establish reasonable launch configurations;
   // Naive launch config;
-  size_t numel = outputs[0].numel();
+  const size_t numel = outputs[0].numel();
 
-  // TODO: we can't randomly clap down this until we got striding.
-  const auto nBlocks = ceilDiv(numel, 128 * entry->unroll_factor_);
+  int blocks = 1;
+  int thread_x = 1;
+  int thread_y = 1;
+  if (!entry->reduction_axes_.empty()) {
+    // TODO: MAJOR HACK! Expr evaluation makes launch configuration much easier
+    blocks = numel;
+    // Translated to `fcd_reduction`
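+    // ("fcd" = fastest changing dimension: the innermost input axis is reduced)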
+    if (entry->reduction_axes_.back() ==
+        outputs[0].dim() + ((int)entry->reduction_axes_.size()) - 1) {
+      thread_x = kFcdReductionThreadX;
+      thread_y = 1;
+    } else {
+      thread_x = kNonFcdReductionThreadX;
+      thread_y = kNonFcdReductionThreadY;
+    }
+  } else {
+    // TODO: we can't randomly clamp this down until we get striding.
+    blocks = ceilDiv(numel, kPwThreadX * entry->unroll_factor_);
+    thread_x = kPwThreadX;
+    thread_y = 1;
+  }
+  const auto nBlocks = blocks;
+  const auto nThreadx = thread_x;
+  const auto nThready = thread_y;
 
   KernelArgumentHolder kernel_args;
 
@@ -473,7 +576,7 @@
   // from I/O expected by the generated CUDA kernel.
   for (auto& input : inputs) {
     if (input.isTensor()) {
-      kernel_args.push(input.toTensor(), outputs[0].sizes());
+      kernel_args.push(input.toTensor(), broadcasted_shape);
     } else {
       kernel_args.push(input);
     }
@@ -505,8 +608,8 @@
       nBlocks,
       1,
       1,
-      128,
-      1,
+      nThreadx,
+      nThready,
       1,
       0,
       stream,
@@ -522,7 +625,7 @@
 void runTestKernel(
     CudaKernel* entry,
     const at::ArrayRef<IValue> inputs,
-    std::vector<at::Tensor> outputs) {
+    const std::vector<at::Tensor>& outputs) {
   validateKernelArgs(*entry, inputs, outputs);
 
   const auto prior_device = at::cuda::current_device();
@@ -540,9 +643,6 @@
   KernelArgumentHolder kernel_args;
 
   auto exprs = entry->fusion_->exprs(true);
-  bool has_reduction = std::any_of(exprs.begin(), exprs.end(), [](Expr* expr) {
-    return expr->getExprType() == ExprType::ReductionOp;
-  });
 
   // Naive I/O setup, I'm ignoring all the potential transformation (i.e. I/O
   // allocated here from the subgraph could be, and very likely are, different
@@ -555,11 +655,7 @@
       TORCH_INTERNAL_ASSERT(
           !entry->fusion_->outputs().empty(),
           "No output found for this kernel, aborting.");
-      if (has_reduction) {
-        kernel_args.push(input.toTensor());
-      } else {
-        kernel_args.push(input.toTensor(), outputs[0].sizes());
-      }
+      kernel_args.push(input.toTensor());
     } else {
       kernel_args.push(input);
     }
@@ -585,6 +681,22 @@
     kernel_args.push(philox_engine_inputs.second);
   }
 
+  // When the kernel has grid (global) reductions, it needs two additional
+  // temporary buffers: one for intermediate results and another for
+  // synchronization among thread blocks.
+  if (entry->fusion_->hasGridReduction()) {
+    auto temp_buf_type = at::kFloat;
+    auto temp_buf_sizes = gridReductionTempBufferSizes(entry);
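+    // temp_buf_sizes[0]: work buffer bytes, temp_buf_sizes[1]: sync flag bytes.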
+    auto options =
+        at::TensorOptions().dtype(temp_buf_type).device(at::kCUDA, 0);
+    at::Tensor reduction_work_buffer = at::empty(
+        {(long)(temp_buf_sizes[0] / c10::elementSize(temp_buf_type))}, options);
+    kernel_args.push(reduction_work_buffer);
+    at::Tensor sync_flags = at::zeros(
+        {(long)(temp_buf_sizes[1] / c10::elementSize(temp_buf_type))}, options);
+    kernel_args.push(sync_flags);
+  }
+
   // launch kernel;
   AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
       entry->function_,
diff --git a/torch/csrc/jit/codegen/cuda/kernel.h b/torch/csrc/jit/codegen/cuda/kernel.h
index 2b37938..67050d7 100644
--- a/torch/csrc/jit/codegen/cuda/kernel.h
+++ b/torch/csrc/jit/codegen/cuda/kernel.h
@@ -40,7 +40,7 @@
   std::vector<int> dims_;
 };
 
-struct CudaKernel {
+class CudaKernel {
  public:
   CudaKernel() {
     fusion_ = std::make_unique<Fusion>();
@@ -59,6 +59,8 @@
   CUfunction function_;
   int max_blocks_;
   int unroll_factor_ = 1;
+  // Reduction axes; used to pick the launch configuration.
+  std::vector<int> reduction_axes_;
 
   // WARNING:
   // Block and Grid dimension setting is here for testing purposes only
@@ -89,13 +91,14 @@
 TORCH_CUDA_API void runKernel(
     CudaKernel* entry,
     const at::ArrayRef<c10::IValue> inputs,
-    std::vector<at::Tensor> outputs);
+    const std::vector<at::Tensor>& outputs,
+    const std::vector<int64_t>& broadcasted_shape);
 
 // Facility API to run kernel in tests.
 TORCH_CUDA_API void runTestKernel(
     CudaKernel* entry,
     const at::ArrayRef<c10::IValue> inputs,
-    std::vector<at::Tensor> outputs);
+    const std::vector<at::Tensor>& outputs);
 
 } // namespace cuda
 } // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/kernel_arg.h b/torch/csrc/jit/codegen/cuda/kernel_arg.h
index 0541647..984afb0 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_arg.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_arg.h
@@ -20,6 +20,30 @@
   constexpr int nDims() {
     return N;
   }
+  void setSize(int i, int64_t s) {
+    size[i] = s;
+  }
+  void setStride(int i, int64_t s) {
+    stride[i] = s;
+  }
+};
+
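+// 0-dim specialization: a scalar tensor carries only a data pointer;
+// attempting to set a size or stride is an internal error.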
+template <typename T>
+struct TensorArgCodegen<T, 0> {
+  T& operator[](int64_t ind) {
+    return data[ind];
+  };
+
+  T* data;
+  constexpr int nDims() {
+    return 0;
+  }
+  void setSize(int, int64_t) {
+    TORCH_INTERNAL_ASSERT(false, "Tried to set size of a 0-dim tensor");
+  }
+  void setStride(int, int64_t) {
+    TORCH_INTERNAL_ASSERT(false, "Tried to set stride of a 0-dim tensor");
+  }
 };
 
 struct ArgAbstract {
@@ -64,10 +88,10 @@
   TENSOR_TYPE instance_;
 
   void setSize(int i, int64_t size) override {
-    instance_.size[i] = size;
+    instance_.setSize(i, size);
   }
   void setStride(int i, int64_t stride) override {
-    instance_.stride[i] = stride;
+    instance_.setStride(i, stride);
   }
   void setPointer(void* ptr) override {
     instance_.data = static_cast<decltype(TENSOR_TYPE::data)>(ptr);
@@ -81,6 +105,8 @@
 template <typename T>
 TensorArgAbstract* getTensorArg(int nDims) {
   switch (nDims) {
+    case (0):
+      return new TensorArg<TensorArgCodegen<T, 0>>();
     case (1):
       return new TensorArg<TensorArgCodegen<T, 1>>();
     case (2):
diff --git a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
index bfba2ad..0427c8f 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
@@ -20,6 +20,17 @@
   int64_t size[N];
   int64_t stride[N];
 };
+
+// Specialization for the 0-dim case, which does not need size and stride arrays.
+// They would also be an error, since zero-length arrays are not allowed.
+template<typename T>
+struct Tensor<T, 0> {
+  T& operator[](int64_t) {
+    return *data;
+  };
+
+  T* data;
+};
 )";
 
 // Code support for FP16 __half type and intrinsics
@@ -184,25 +195,22 @@
 // may actually be slower.
 template<bool X_REDUCE, bool Y_REDUCE, bool Z_REDUCE, typename T, typename Func>
 __inline__ __device__
-void blockReduce(T& out, const T inp_val, Func reduction_op) {
-
-  // Use worst case for memory.
-  __shared__ T shared_mem[1024];
+void blockReduce(T& out, const T inp_val, Func reduction_op, const dim3& thread_idx, const dim3& block_dim, T* shared_mem) {
 
   unsigned int reduction_size 
-    = (X_REDUCE ? blockDim.x : 1) 
-    * (Y_REDUCE ? blockDim.y : 1)
-    * (Z_REDUCE ? blockDim.z : 1);
+    = (X_REDUCE ? block_dim.x : 1)
+    * (Y_REDUCE ? block_dim.y : 1)
+    * (Z_REDUCE ? block_dim.z : 1);
 
   // If this thread will output a final result
   bool should_write = true;
 
   if (X_REDUCE)
-    should_write = should_write && threadIdx.x == 0;
+    should_write = should_write && thread_idx.x == 0;
   if (Y_REDUCE)
-    should_write = should_write && threadIdx.y == 0;
+    should_write = should_write && thread_idx.y == 0;
   if (Z_REDUCE)
-    should_write = should_write && threadIdx.z == 0;
+    should_write = should_write && thread_idx.z == 0;
 
   unsigned int reduction_stride;
   unsigned int reduction_tid;
@@ -212,23 +220,20 @@
     // Transpose Z and Y in the shared memory so Z and X dims are contiguous in smem
     reduction_stride = 1;
     linear_tid = threadIdx.y * blockDim.z * blockDim.x + threadIdx.z * blockDim.x + threadIdx.x;
-    reduction_tid
-    = threadIdx.y * blockDim.z * blockDim.x
-    + threadIdx.z              * blockDim.x
-    + threadIdx.x;
+    reduction_tid = threadIdx.z * blockDim.x + threadIdx.x;
   } else {
     // Normal reduction in order
     reduction_stride 
     = (X_REDUCE ? 1 
-    : (Y_REDUCE ? blockDim.x
-    : (Z_REDUCE ? blockDim.x * blockDim.y : 0)));
+    : (Y_REDUCE ? block_dim.x
+    : (Z_REDUCE ? block_dim.x * block_dim.y : 0)));
 
-    linear_tid = threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x;
+    linear_tid = thread_idx.z * block_dim.y * block_dim.x + thread_idx.y * block_dim.x + thread_idx.x;
 
     reduction_tid
-    = ( Z_REDUCE ? threadIdx.z : 0 ) * ( Y_REDUCE ? blockDim.y : 1 ) * ( X_REDUCE ? blockDim.x : 1 )
-    + ( Y_REDUCE ? threadIdx.y : 0 )                                 * ( X_REDUCE ? blockDim.x : 1 )
-    + ( X_REDUCE ? threadIdx.x : 0 );
+    = ( Z_REDUCE ? thread_idx.z : 0 ) * ( Y_REDUCE ? block_dim.y : 1 ) * ( X_REDUCE ? block_dim.x : 1 )
+    + ( Y_REDUCE ? thread_idx.y : 0 )                                 * ( X_REDUCE ? block_dim.x : 1 )
+    + ( X_REDUCE ? thread_idx.x : 0 );
   }
 
   assert( reduction_stride != 0 );
@@ -257,7 +262,315 @@
 }
 )";
 
+/**
+  Inter-block reduction.
+
+  Function gridReduce performs point-wise reductions of scalars across thread
+  blocks. Thread blocks are disjointly partitioned into groups of thread blocks,
+  "reduction segments," that are collectively defined by boolean template
+  parameters, X_BLOCK, Y_BLOCK and Z_BLOCK. Each of X/Y/Z_BLOCK determines
+  whether thread blocks along the dimension should be grouped into the same
+  reduction segment. Cross-block reducitons are independently done within each
+  segment and generates distinctive results per segment. For instance, if all of
+  X/Y/Z_BLOCK are true, reductions will be done across all thread blocks since
+  there will be just a single segment consisting of all thread blocks. If none
+  of them are true, each thread block will become a segment by itself, so no
+  reduction will be performed.
+
+  The input scalars to reduce within each segment are a certain subset of
+  thread-private scalars provided as part of the gridReduce function parameters.
+  Boolean template parameters, X_THREAD, Y_THREAD and Z_THREAD, determine which
+  subset of the scalars should be used for inter-block reductions. Specifically,
+  all the input scalars of threads along each dimension will be used when
+  X/Y/Z_THREAD are true. Otherwise, only the value held at offset 0 of each
+  dimension will be used. Thus, for example, if all of X/Y/Z_THREAD are true,
+  the scalars of all threads in each block will participate in inter-block
+  reductions. If all of them are false, only one scalar of the thread at
+  threadIdx.x == threadIdx.y == threadIdx.z == 0 will be used. In the code
+  below, we call the subset of threads a "reduction block."
+
+  Inter-block reductions perform point-wise reductions of scalars of reduction
+  blocks within each reduction segment. More specifically, let rb be a reduction
+  block and rs be a reduction segment. Let IN(thread_idx, block_idx) denote the
+  input scalar of thread at thread_idx and block_idx. The result of each
+  reduction segment, OUT(thread_idx, block_idx_out), is defined only for each
+  thread_idx in thread block block_idx_out in the segment as follows:
+
+    OUT(thread_idx, block_idx_out) = Reduction of IN(thread_idx, block_idx) for
+  all block_idx in a reduction segment
+
+  OUT is not defined for threads that are not in block_idx_out or that fall
+  outside the reduction block.
+
+  See also the function comment of gridReduce.
+*/
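
A concrete instance of the terminology above (numbers are arbitrary):

// Example: gridDim = {8, 4, 2}, blockDim = {32, 4, 1},
// X_BLOCK/Y_BLOCK/Z_BLOCK = true/false/false,
// X_THREAD/Y_THREAD/Z_THREAD = true/false/false.
//   - Reduction segments: gridDim.y * gridDim.z = 8 segments, each containing
//     gridDim.x = 8 thread blocks.
//   - Reduction block: {32, 1, 1}, i.e. the 32 threads with
//     threadIdx.y == threadIdx.z == 0 in each block contribute a value.
//   - Each segment therefore reduces 8 blocks x 32 values point-wise and
//     produces 32 results, held by one surviving block of that segment.
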
+static auto code_template_grid_reduction = R"(
+namespace reduction {
+
+// Utility functions
+__host__ __device__ __forceinline__ size_t size(const dim3& d) {
+  return (size_t)d.x * (size_t)d.y * (size_t)d.z;
+}
+
+__host__ __device__ __forceinline__ int isize(const dim3& d) {
+  return d.x * d.y * d.z;
+}
+
+__host__ __device__ __forceinline__ size_t offset(const dim3& pos, const dim3& dim) {
+  return (size_t)pos.x + (size_t)pos.y * (size_t)dim.x +
+      (size_t)pos.z * (size_t)dim.x * (size_t)dim.y;
+}
+
+__host__ __device__ __forceinline__ size_t ioffset(const dim3& pos, const dim3& dim) {
+  return pos.x + pos.y * dim.x + pos.z * dim.x * dim.y;
+}
+
+// Returns dim3 of each reduction segment.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ dim3 dimension_of_reduction_segment(const dim3& grid_dim) {
+  return dim3{X_BLOCK ? grid_dim.x : 1,
+        Y_BLOCK ? grid_dim.y : 1,
+        Z_BLOCK ? grid_dim.z : 1};
+}
+
+// Returns the number of blocks in each reduction segment.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ size_t size_of_reduction_segment(const dim3& grid_dim) {
+  return size(dimension_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(grid_dim));
+}
+
+// Returns the total number of reduction segments.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ size_t number_of_reduction_segments(const dim3& grid_dim) {
+  return (X_BLOCK ? 1: grid_dim.x) *
+      (Y_BLOCK ? 1 : grid_dim.y) *
+      (Z_BLOCK ? 1 : grid_dim.z);
+}
+
+// Returns the 1-D index of the segment of thread block of block_idx.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ size_t index_of_reduction_segment(const dim3& block_idx,
+                                                      const dim3& grid_dim) {
+  size_t seg_idx = 0;
+  if (!Z_BLOCK)
+    seg_idx += block_idx.z;
+  if (!Y_BLOCK)
+    seg_idx = seg_idx * grid_dim.y + block_idx.y;
+  if (!X_BLOCK)
+    seg_idx = seg_idx * grid_dim.x + block_idx.x;
+  return seg_idx;
+}
+
+// Returns the offset of thread block in its reduction segment.
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
+__host__ __device__ size_t offset_in_reduction_segment(const dim3& block_idx,
+                                                       const dim3& grid_dim) {
+  size_t offset = 0;
+  if (Z_BLOCK)
+    offset = offset * grid_dim.z + block_idx.z;
+  if (Y_BLOCK)
+    offset = offset * grid_dim.y + block_idx.y;
+  if (X_BLOCK)
+    offset = offset * grid_dim.x + block_idx.x;
+  return offset;
+}
+
+// Returns dim3 of each reduction block.
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
+__host__ __device__ dim3 dimension_of_reduction_block(const dim3& block_dim) {
+  return dim3{X_THREAD ? block_dim.x : 1,
+        Y_THREAD ? block_dim.y : 1,
+        Z_THREAD ? block_dim.z : 1};
+}
+
+// Returns the number of threads of each reduction block.
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
+__host__ __device__ int size_of_reduction_block(const dim3& block_dim) {
+  return isize(dimension_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(block_dim));
+}
+
+// Returns the linear offset of a thread in a reduction block.
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
+__host__ __device__ int offset_in_reduction_block(const dim3& thread_idx,
+                                                  const dim3& block_dim) {
+  int offset = 0;
+  if (Z_THREAD)
+    offset += thread_idx.z;
+  if (Y_THREAD)
+    offset = offset * block_dim.y + thread_idx.y;
+  if (X_THREAD)
+    offset = offset * block_dim.x + thread_idx.x;
+  return offset;
+}
+
+/** Reduces all the reduction blocks in each reduction segment.
+
+  This is only used by one thread block per reduction segment. The input
+  reduction blocks of the segment are stored in an intermediate buffer pointed
+  to by the parameter in. Template parameters X/Y/Z_THREAD denote how the
+  reduction block is formed.
+
+  The size of a reduction block is, by definition, smaller than or equal to the size of
+  a thread block. We use the remaining threads to parallelize reductions across
+  reduction blocks. For example, when X/Y/Z_THREAD = {true, false, false}, we
+  use blockDim.y*blockDim.z threads for each output value. This is done first by
+  loading the input values in parallel and then by reducing across threads of
+  dimensions whose XYZ_THREAD are false.
+
+  Note that what is done here after the loading from global memory is similar to
+  what the existing blockReduce function does. The main difference is that the
+  logical block to reduce is a 2D domain where the leading dimension is the size
+  of a reduction block and the second dimension is the remaining factor in each
+  thread block. For example, when X/Y/Z_THREAD = {false, true, false}, the
+  threads are arranged as (blockDim.y, blockDim.x*blockDim.z). We do not reduce
+  along the first dimension but only the second dimension. So, it is possible to
+  reuse the existing blockReduce with dim3{blockDim.y, blockDim.x*blockDim.z}
+  instead of blockDim and with X_THREAD and Y_THREAD being false and true,
+  respectively. Also, the final output values still need to be shuffled to their
+  actual corresponding threads. In the case of X/Y/Z_THREAD = {false, true,
+  false}, after the intra-block reduction, the final results will still be held
+  by the first blockDim.y threads, which need to be transferred to threads at
+  threadIdx.x == 0 and threadIdx.z == 0.
+*/
+template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD,
+          typename T, typename Func>
+__device__ void gridReduceLastBlock(T& out, const T *in, const size_t in_size,
+                                    Func reduction_op, T* shared_buf) {
+  const int tid = ioffset(threadIdx, blockDim);
+  const int block_size = isize(blockDim);
+  const int rblock_size = size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
+
+  T inp = 0;
+  if (tid < in_size) {
+    inp = in[tid];
+  }
+  for (size_t i = tid + block_size; i < in_size; i += block_size) {
+    reduction_op(inp, in[i]);
+  }
+
+  const auto should_write = (X_THREAD || threadIdx.x == 0) &&
+      (Y_THREAD || threadIdx.y == 0) &&
+      (Z_THREAD || threadIdx.z == 0);
+
+  auto rem_size = block_size / rblock_size;
+
+  if (rem_size > 1) {
+    const int rblock_offset = tid % rblock_size;
+    const int rblock_idx = tid / rblock_size;
+    blockReduce<false, true, false>(
+        inp, inp, reduction_op,
+        dim3{(unsigned)rblock_offset, (unsigned)rblock_idx, 0},
+        dim3{(unsigned)rblock_size, (unsigned)rem_size},
+        shared_buf);
+    __syncthreads();
+    if (tid < rblock_size) {
+      shared_buf[tid] = inp;
+    }
+    __syncthreads();
+    if (should_write) {
+      inp = shared_buf[offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(
+          threadIdx, blockDim)];
+    }
+  }
+
+  if (should_write) {
+    out = inp;
+  }
+}
+
+__device__ unsigned atomic_inc(unsigned* sync_flag, unsigned max_val) {
+  return atomicInc(sync_flag, max_val - 1);
+}
+
+/** Reduces per-thread values across thread blocks.
+
+Function parameters:
+- out: Per-thread output location
+- inp_val: Per-thread input value
+- reduction_op: Scalar reduction function
+- work_buf: Temporary buffer for cross-block reductions
+- sync_flags: A vector of integers for synchronizations
+- shared_buf: Shared memory buffer for intra-block reduction
+
+Template parameters:
+- X/Y/Z_BLOCK: When true, reduces across thread blocks along the X/Y/Z
+  dimensions
+- X/Y/Z_THREAD: When true, all threads along the X/Y/Z dimensions participate in
+  the cross-block reduction. Otherwise, only threads at offset 0 do.
+- T: Scalar data type of input/output data
+- Func: Type of scalar reduction function
+
+Template parameters X/Y/Z_BLOCK define a group of thread blocks that are reduced together. We call
+it a reduction segment. Some examples are:
+
+Case 1: X/Y/Z_BLOCK == true/true/true -> There is only one segment, which includes all
+  thread blocks. It is effectively the same as the grid.
+Case 2: X/Y/Z_BLOCK == false/false/false -> Each thread block comprises an individual
+  segment by itself.
+Case 3: X/Y/Z_BLOCK == true/false/false -> Each segment contains thread blocks that have
+  the same blockIdx.y and blockIdx.z. There will be gridDim.y*gridDim.z such segments.
+
+X/Y/Z_THREAD defines a sub region of a thread block that should be reduced with
+the sub regions of other thread blocks. We call it a reduction block. E.g.,
+
+Case 1: X/Y/Z_THREAD == false/false/false -> Only thread 0 participates in the
+  cross-block reductions. The reduction block is 1x1x1 with thread 0.
+Case 2: X/Y/Z_THREAD == true/true/true -> All threads in a thread block participate in
+  the cross-block reductions. The reduction block in this case is equivalent to
+  the thread block.
+
+After the function completes, only one thread block per reduction segment gets
+valid reduction results. There is no guarantee which particular block gets the
+final results.
+*/
+template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK,
+          bool X_THREAD, bool Y_THREAD, bool Z_THREAD,
+          typename T, typename Func>
+__device__ void gridReduce(T& out, T inp_val, Func reduction_op,
+                           volatile T* work_buf,
+                           unsigned* sync_flags,
+                           T* shared_buf) {
+  const auto seg_size =
+      size_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);
+  const auto seg_idx =
+      index_of_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
+  const auto rblock_size =
+      size_of_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);
+
+  // advance to the offset for this segment
+  work_buf += seg_idx * seg_size * rblock_size;
+
+  if ((X_THREAD || threadIdx.x == 0) &&
+      (Y_THREAD || threadIdx.y == 0) &&
+      (Z_THREAD || threadIdx.z == 0)) {
+    auto rblock_offset =
+        offset_in_reduction_segment<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
+    auto thread_offset =
+        offset_in_reduction_block<X_THREAD, Y_THREAD, Z_THREAD>(threadIdx, blockDim);
+    auto work_buf_offset = rblock_size * rblock_offset + thread_offset;
+    work_buf[work_buf_offset] = inp_val;
+  }
+  __syncthreads();
+
+  __shared__ bool last_block;
+  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
+    __threadfence();
+    auto old = atomic_inc(&sync_flags[seg_idx], seg_size);
+    last_block = old == seg_size - 1;
+  }
+  __syncthreads();
+
+  if (last_block) {
+    // final reduction
+    gridReduceLastBlock<X_THREAD, Y_THREAD, Z_THREAD>(
+        out, (T*)work_buf, seg_size * rblock_size,
+        reduction_op, shared_buf);
+  }
+}
+} // namespace reduction
+)";
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
-} // namespace torch
\ No newline at end of file
+} // namespace torch
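
A sketch of what a kernel using gridReduce could look like; the kernel name and the Add functor are illustrative only, since the actual source is emitted by the IR printer and is wired to the work buffer and sync flags pushed in runKernel above.

struct Add {
  __device__ void operator()(float& a, float b) const { a += b; }
};

__global__ void kernel_sketch(const float* in, int n,
                              float* work_buf, unsigned* sync_flags) {
  extern __shared__ float shared_buf[];
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  float thread_val = (i < n) ? in[i] : 0.f;
  float result = 0.f;
  // Reduce across blockIdx.x; every threadIdx.x lane contributes a value.
  reduction::gridReduce<true, false, false, true, false, false>(
      result, thread_val, Add(), work_buf, sync_flags, shared_buf);
  // Only one block per segment holds valid results in `result` here.
}
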
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index ef0ff56..913c39e 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -1,12 +1,11 @@
-#include <torch/csrc/jit/codegen/cuda/arith.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/lower_index.h>
 #include <torch/csrc/jit/codegen/cuda/lower_loops.h>
+#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
+#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
-#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
-#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
-#include <torch/csrc/jit/codegen/cuda/type.h>
+#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
 
 #include <torch/csrc/jit/codegen/cuda/lower2device.h>
 
@@ -14,418 +13,34 @@
 namespace jit {
 namespace fuser {
 
-void GPULower::pushBack(Expr* expr) {
-  if (active_scope == nullptr)
-    lowered_exprs.push_back(expr);
-  else
-    scope_utils::pushBack(active_scope, expr);
-}
-
-Statement* GPULower::mutate(Expr* expr) {
-  Statement* mutated_stmt = OptOutMutator::mutate(expr);
-  TORCH_INTERNAL_ASSERT(
-      mutated_stmt->isExpr(),
-      "Tried to generate a kernel but hit a non expression during lowering: ",
-      mutated_stmt);
-  return mutated_stmt;
-}
-
-Statement* GPULower::mutate(IfThenElse* ite) {
-  Expr* prev_scope = active_scope;
-  active_scope = ite;
-  std::vector<Expr*> mutated_exprs;
-  bool is_mutated = false;
-  for (auto expr : ite->body().exprs()) {
-    Statement* mutated_stmt = mutate(expr);
-    Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
-    mutated_exprs.push_back(mutated_expr);
-    is_mutated = is_mutated | (mutated_expr != expr);
-  }
-
-  std::vector<Expr*> mutated_else_exprs;
-  for (auto expr : ite->elseBody().exprs()) {
-    Statement* mutated_stmt = mutate(expr);
-    Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
-    mutated_else_exprs.push_back(mutated_expr);
-    is_mutated = is_mutated | (mutated_expr != expr);
-  }
-
-  if (is_mutated) {
-    ite->body().clear();
-    for (auto expr : mutated_exprs)
-      ite->body().push_back(expr);
-    ite->elseBody().clear();
-    for (auto expr : mutated_else_exprs)
-      ite->elseBody().push_back(expr);
-  }
-
-  active_scope = prev_scope;
-
-  if (is_mutated) {
-    auto new_ite = new IfThenElse(
-        ite->cond(), mutated_exprs, mutated_else_exprs, ite->parentScope());
-    return new_ite;
-  }
-
-  return ite;
-}
-
-Statement* GPULower::mutate(ForLoop* fl) {
-  Expr* prev_scope = active_scope;
-  active_scope = fl;
-  std::vector<Expr*> mutated_exprs;
-  bool is_mutated = false;
-  for (auto expr : fl->body().exprs()) {
-    Statement* mutated_stmt = mutate(expr);
-    Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
-    mutated_exprs.push_back(mutated_expr);
-    is_mutated = is_mutated | (mutated_expr != expr);
-  }
-
-  active_scope = prev_scope;
-  if (is_mutated) {
-    auto newFL = new ForLoop(
-        fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope());
-    return newFL;
-  }
-
-  return fl;
-}
-
-Statement* GPULower::mutate(UnaryOp* uop) {
-  if (!ir_utils::isTVOp(uop))
-    return OptOutMutator::mutate(uop);
-
-  TensorIndex* out = Index::getConsumerIndex(
-      ir_utils::asTV(uop->out()), scope_utils::getLoops(active_scope));
-  Val* in = uop->in();
-  if (ir_utils::isTV(in))
-    in = Index::getProducerIndex(
-        ir_utils::asTV(in),
-        ir_utils::asTV(uop->out()),
-        scope_utils::getLoops(active_scope));
-  Expr* new_op = new UnaryOp(uop->getUnaryOpType(), out, in);
-
-  return new_op;
-}
-
-Statement* GPULower::mutate(BinaryOp* bop) {
-  if (!ir_utils::isTVOp(bop))
-    return OptOutMutator::mutate(bop);
-
-  TensorIndex* out = Index::getConsumerIndex(
-      ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope));
-
-  Val* lhs = bop->lhs();
-  Val* rhs = bop->rhs();
-
-  if (ir_utils::isTV(lhs))
-    lhs = Index::getProducerIndex(
-        ir_utils::asTV(lhs),
-        ir_utils::asTV(bop->out()),
-        scope_utils::getLoops(active_scope));
-
-  if (ir_utils::isTV(rhs))
-    rhs = Index::getProducerIndex(
-        ir_utils::asTV(rhs),
-        ir_utils::asTV(bop->out()),
-        scope_utils::getLoops(active_scope));
-
-  Expr* new_op = new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs);
-
-  return new_op;
-}
-
-Statement* GPULower::mutate(TernaryOp* top) {
-  if (!ir_utils::isTVOp(top))
-    return OptOutMutator::mutate(top);
-
-  TensorIndex* out = Index::getConsumerIndex(
-      ir_utils::asTV(top->out()), scope_utils::getLoops(active_scope));
-  Val* in1 = top->in1();
-  Val* in2 = top->in2();
-  Val* in3 = top->in3();
-
-  if (ir_utils::isTV(in1))
-    in1 = Index::getProducerIndex(
-        ir_utils::asTV(in1),
-        ir_utils::asTV(top->out()),
-        scope_utils::getLoops(active_scope));
-
-  if (ir_utils::isTV(in2))
-    in2 = Index::getProducerIndex(
-        ir_utils::asTV(in2),
-        ir_utils::asTV(top->out()),
-        scope_utils::getLoops(active_scope));
-
-  if (ir_utils::isTV(in3))
-    in3 = Index::getProducerIndex(
-        ir_utils::asTV(in3),
-        ir_utils::asTV(top->out()),
-        scope_utils::getLoops(active_scope));
-
-  Expr* new_op = new TernaryOp(top->getTernaryOpType(), out, in1, in2, in3);
-
-  return new_op;
-}
-
-Statement* GPULower::mutate(ReductionOp* rop) {
-  TORCH_INTERNAL_ASSERT(
-      ir_utils::isTVOp(rop),
-      "Cannot have a reduction operation on something other than a tensor view.");
-  auto loops = scope_utils::getLoops(active_scope);
-  TORCH_INTERNAL_ASSERT(
-      std::none_of(
-          loops.begin(),
-          loops.end(),
-          [](ForLoop* fl) {
-            return fl->iter_domain()->isBlockDim() &&
-                fl->iter_domain()->isReduction();
-          }),
-      "Reduction on block axes not yet supported.");
-
-  bool is_thread_reduce =
-      std::any_of(loops.begin(), loops.end(), [](ForLoop* fl) {
-        return fl->iter_domain()->isThreadDim() &&
-            fl->iter_domain()->isReduction();
-      });
-
-  TensorIndex* out = Index::getConsumerIndex(ir_utils::asTV(rop->out()), loops);
-
-  Val* in = rop->in();
-  if (ir_utils::isTV(in))
-    in = Index::getProducerIndex(
-        ir_utils::asTV(in),
-        ir_utils::asTV(rop->out()),
-        scope_utils::getLoops(active_scope));
-
-  if (is_thread_reduce)
-    return new ReductionOp(rop->getReductionOpType(), rop->init(), out, in);
-
-  Expr* new_op = new BinaryOp(rop->getReductionOpType(), out, out, in);
-
-  return new_op;
-}
-
-Statement* GPULower::mutate(BroadcastOp* bop) {
-  if (!ir_utils::isTVOp(bop))
-    return OptOutMutator::mutate(bop);
-
-  TensorIndex* out = Index::getConsumerIndex(
-      ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope));
-  Val* in = bop->in();
-  if (ir_utils::isTV(in))
-    in = Index::getProducerIndex(
-        ir_utils::asTV(in),
-        ir_utils::asTV(bop->out()),
-        scope_utils::getLoops(active_scope));
-  Expr* new_op = new BroadcastOp(out, in);
-
-  return new_op;
-}
-
-// TensorViews are all based on symbolic sizes. When we first initialize them we
-// don't know if they're inputs or outputs which would mean that they have
-// runtime shapes. Intermediate tensors (those not going to global memory) do
-// not have this information. Since we need to have the correct information in
-// the kernel being fetched for shapes, we want to replace input and output
-// tensors to reference the runtime structure containing sizes.
-void GPULower::replaceSizes() {
-  Fusion* fusion = FusionGuard::getCurFusion();
-  // Sizes of inputs/outputs -> T.size[...]
-  std::unordered_map<Val*, Val*> size_map;
-
-  // Grab inputs and outputs
-  std::vector<TensorView*> orig_inp_out;
-  std::vector<TensorView*> all_tvs;
-
-  for (auto* val : fusion->inputs())
-    if (ir_utils::isTV(val))
-      orig_inp_out.push_back(ir_utils::asTV(val));
-
-  for (auto* val : fusion->outputs())
-    if (ir_utils::isTV(val))
-      orig_inp_out.push_back(ir_utils::asTV(val));
-
-  for (auto* val : fusion->deterministic_vals()) {
-    if (ir_utils::isTV(val)) {
-      all_tvs.push_back(ir_utils::asTV(val));
-    }
-  }
-
-  // Run through inputs and outputs first. Since we're replacing full
-  // tensorviews their names are going to change. We need  the new referenc
-  // name for the inputs/outputs. This way we won't reference the wrong tensor
-  // view. For example T0 may be translated to T9. We don't want our new
-  // variable to be T0->size[...] we need it to be T9->size[...]
-  //
-  // This could be done in a better way but changing split/merge to be a
-  // TensorDomain focused operation, then we could simply do this process on
-  // domains, instead of tensorviews. This would have the benefit that the
-  // TensorView wouldn't change, so users pointers will remain valid. The other
-  // option which seems less elegant but would also work is build up the domain
-  // on the new tensor, and then simply replace it into the original one.
-  for (TensorView* tv : orig_inp_out) {
-    // Replace the domain with one based on Ti.size[j]
-    std::vector<IterDomain*> new_domain_iters;
-    const std::vector<IterDomain*>& root_td = tv->getRootDomain();
-
-    for (decltype(root_td.size()) i{0}; i < root_td.size(); i++) {
-      // Output sizes could have reduction axes, which isn't what gets output.
-      if (root_td[i]->isReduction())
-        continue;
-
-      Val* orig_size = root_td[i]->extent();
-
-      std::stringstream ss;
-      ss << "T" << tv->name() << ".size[" << i << "]";
-      Val* new_size =
-          new NamedScalar(ss.str(), orig_size->getDataType().value());
-      if (!orig_size->sameAs(new_size) ||
-          size_map.find(orig_size) == size_map.end())
-        size_map[orig_size] = new_size;
-    }
-  }
-
-  // If we already lowered all inputs/outputs we can just return.
-  if (size_map.size() == 0)
-    return;
-
-  // Set domains to be based on symbolic sizes (i.e. Ti.size[...])
-  for (TensorView* tv : all_tvs) {
-    std::vector<IterDomain*> new_domain_iters;
-    const std::vector<IterDomain*>& root_td = tv->getRootDomain();
-
-    for (decltype(root_td.size()) i{0}; i < root_td.size(); i++) {
-      Val* new_size = root_td[i]->extent();
-      if (size_map.find(new_size) != size_map.end())
-        new_size = size_map[new_size];
-
-      new_domain_iters.push_back(new IterDomain(
-          root_td[i]->start(),
-          new_size,
-          root_td[i]->parallel_method(),
-          root_td[i]->isReduction(),
-          root_td[i]->isRFactorProduct(),
-          root_td[i]->isBroadcast()));
-    }
-
-    TensorDomain* old_domain = tv->domain();
-    TensorDomain* new_domain = new TensorDomain(new_domain_iters);
-
-    // We should just be able to replace sizes in place, but mutator is setup to
-    // do that as it set up to replace vals in Exprs, but
-    // IterDomain/TensorDomain are vals.
-
-    new_domain = TransformReplay::fullSelfReplay(new_domain, old_domain);
-
-    TORCH_INTERNAL_ASSERT(
-        old_domain->nDims() == new_domain->nDims(),
-        "Tried to set symbolic sizes through the kernel, but hit a snag, Replayed domain should be the same size as the target domain, but got ",
-        new_domain->nDims(),
-        " and ",
-        old_domain->nDims());
-    // Parallelize all iter domains
-    for (decltype(new_domain->nDims()) i{0}; i < new_domain->nDims(); i++)
-      new_domain->axis(i)->parallelize(old_domain->axis(i)->parallel_method());
-
-    tv->setDomain(new_domain);
-  }
-
-  // Adjust memory types to make sure they are valid
-  for (TensorView* tv : all_tvs) {
-    if (fusion->hasInput(tv) || fusion->hasOutput(tv)) {
-      tv->setMemoryType(MemoryType::Global);
-    } else {
-      if (tv->getMemoryType() == MemoryType::Global)
-        tv->setMemoryType(MemoryType::Local);
-    }
-  }
-}
-
-namespace {
-
-// Some pre-compilation checks
-void validate(Fusion* fusion) {
-  FusionGuard fg(fusion);
-  fusion->validateInputs();
-  for (Val* val : fusion->vals()) {
-    if (ir_utils::isTV(val)) {
-      TensorView* tv = ir_utils::asTV(val);
-      for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
-        IterDomain* id = tv->getComputeAtAxis(i).first;
-
-        if (id->isBlockDim())
-          TORCH_CHECK(
-              !id->isReduction(),
-              "Parallelization across blocks on reduction axes not support at the moment but found on, ",
-              tv,
-              ".");
-      }
-    } // if ir_utils::isTV
-  } // for(Val* val : fusion->vals())
-} // validate
-
-} // namespace
-
-// Remove circular computeAt references
-void GPULower::fixComputeAt(Fusion* fusion) {
-  FusionGuard fg(fusion);
-
-  std::vector<Expr*> exprs = fusion->exprs(true);
-  std::set<TensorView*> visited;
-  for (auto it = exprs.rbegin(); it != exprs.rend(); it++) {
-    Expr* expr = *it;
-    if (!ir_utils::isTVOp(expr))
-      continue;
-
-    TensorView* tv = ir_utils::asTV(expr->output(0));
-    TensorView* ctv = tv->getComputeAtView();
-
-    if (ctv != nullptr && visited.find(ctv) == visited.end()) {
-      ctv->setComputeAt(tv, (int)tv->getThisComputeAtAxis());
-      tv->clearComputeAt();
-    }
-    visited.emplace(tv);
-  }
-}
-
 // Traverse through the fusion and print CUDA code associated with it
 std::vector<Expr*> GPULower::getLoweredExprs() {
   FusionGuard fg(fusion_);
 
-  // Compute at can have some circular references. Before we can call any tv
-  // with tv->getComputeAtAxis(i) we need to break those circular dependencies.
-  fixComputeAt(fusion_);
+  // Validate and make some minor modifications in preparation to generate code.
+  PrepareForLowering(fusion_);
 
-  validate(fusion_);
-  replaceSizes();
-  auto loop_nests =
-      LoopNestGenerator::getLoopNest(fusion_, fusion_->exprs(true));
+  auto preds = ThreadPredicates::compute(fusion_);
 
-  auto unrolled_loops = UnrollPass::runPass(fusion_, loop_nests);
+  // Run our passes keeping the lowered expressions and forwarding them.
+  auto loop_nests = LoopNestGenerator::getLoopNest(
+      fusion_, fusion_->exprs(true, false, true), preds);
 
-  // Run through loop nests and further lower the expressions
-  for (auto* expr : unrolled_loops) {
-    Statement* mutated_stmt = mutate(expr);
-    TORCH_INTERNAL_ASSERT(
-        mutated_stmt->isExpr(),
-        "Tried to generate a kernel but hit a non expression during lowering: ",
-        mutated_stmt);
-    lowered_exprs.push_back(static_cast<Expr*>(mutated_stmt));
-  }
+  auto unrolled_loops = UnrollPass::runPass(fusion_, loop_nests, preds);
 
-  return lowered_exprs;
+  auto indexed_loops = IndexLowering::getIndexedExprs(fusion_, unrolled_loops);
+
+  return indexed_loops;
 }
 
 std::ostream& GPULower::printKernel(
     std::ostream& os,
     const std::string& kernel_name) {
   FusionGuard fg(fusion_);
-  getLoweredExprs();
+  auto exprs = getLoweredExprs();
 
   IRPrinter irp(os);
-  irp.printKernel(lowered_exprs, kernel_name);
+  irp.printKernel(exprs, kernel_name);
   return os;
 }
 
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h
index 80607de..57eadbf 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.h
+++ b/torch/csrc/jit/codegen/cuda/lower2device.h
@@ -4,53 +4,13 @@
 
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
 
-#include <map>
 #include <ostream>
-#include <stack>
 
 namespace torch {
 namespace jit {
 namespace fuser {
 
-// TODO: Change lowering so it can be called multiple times. It would be good to
-// keep user references intact so they can lower it as they describe the kernel.
-// Right now we can only lower once.
-
-struct TORCH_CUDA_API GPULower : public OptOutMutator {
- private:
-  Fusion* const fusion_;
-  std::vector<Expr*> lowered_exprs;
-  Expr* active_scope = nullptr;
-
-  // Wrap pushBack in lower_utils if active_scope is null we want it to go
-  // straight to lower_exprs
-  void pushBack(Expr*);
-
-  // Custom dispatch for Expr, want to find out of it's a TV op
-  Statement* mutate(Expr*) final;
-
-  // Open the for loop.
-  Statement* mutate(ForLoop*) final;
-
-  // Open the for loop.
-  Statement* mutate(IfThenElse*) final;
-
-  // Remake operations with TensorIndex
-  Statement* mutate(UnaryOp*) final;
-  Statement* mutate(BinaryOp*) final;
-  Statement* mutate(TernaryOp*) final;
-  Statement* mutate(ReductionOp*) final;
-  Statement* mutate(BroadcastOp*) final;
-
-  // TensorViews are all based on symbolic sizes. When we first initialize them
-  // we don't know if they're inputs or outputs which would mean that they have
-  // runtime shapes. Intermediate tensors (those not going to global memory) do
-  // not have this information. Since we need to have the correct information in
-  // the kernel being fetched for shapes, we want to replace input and output
-  // tensors to reference the runtime structure containing sizes.
-  void replaceSizes();
-  void fixComputeAt(Fusion* fusion);
-
+class TORCH_CUDA_API GPULower {
  public:
   // Init printer on ostream
   GPULower(Fusion* _fusion) : fusion_(_fusion) {}
@@ -60,6 +20,9 @@
   std::ostream& printKernel(
       std::ostream& _os,
       const std::string& kernel_name = "CUDAGeneratedKernel");
+
+ private:
+  Fusion* const fusion_ = nullptr;
 };
 
 } // namespace fuser
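
A minimal, hypothetical driver for the slimmed-down GPULower interface; the fusion construction and scheduling are elided, and only the constructor and printKernel shown in this header are assumed.

#include <iostream>

void printLoweredKernel(Fusion* fusion) {
  FusionGuard fg(fusion);
  GPULower lower(fusion);
  // Emits the CUDA source for the lowered fusion to the given stream.
  lower.printKernel(std::cout, "CUDAGeneratedKernel");
}
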
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp
new file mode 100644
index 0000000..d506d64
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -0,0 +1,225 @@
+#include <torch/csrc/jit/codegen/cuda/index_compute.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+
+#include <torch/csrc/jit/codegen/cuda/lower_index.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+void IndexLowering::pushBack(Expr* expr) {
+  if (active_scope == nullptr)
+    lowered_exprs.push_back(expr);
+  else
+    scope_utils::pushBack(active_scope, expr);
+}
+
+Statement* IndexLowering::mutate(Expr* expr) {
+  Statement* mutated_stmt = OptOutMutator::mutate(expr);
+  TORCH_INTERNAL_ASSERT(
+      mutated_stmt->isExpr(),
+      "Tried to generate a kernel but hit a non expression during lowering: ",
+      mutated_stmt);
+  return mutated_stmt;
+}
+
+Statement* IndexLowering::mutate(IfThenElse* ite) {
+  Expr* prev_scope = active_scope;
+  active_scope = ite;
+  std::vector<Expr*> mutated_exprs;
+  bool is_mutated = false;
+  for (auto expr : ite->body().exprs()) {
+    Statement* mutated_stmt = mutate(expr);
+    Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
+    mutated_exprs.push_back(mutated_expr);
+    is_mutated = is_mutated | (mutated_expr != expr);
+  }
+
+  std::vector<Expr*> mutated_else_exprs;
+  for (auto expr : ite->elseBody().exprs()) {
+    Statement* mutated_stmt = mutate(expr);
+    Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
+    mutated_else_exprs.push_back(mutated_expr);
+    is_mutated = is_mutated | (mutated_expr != expr);
+  }
+
+  if (is_mutated) {
+    ite->body().clear();
+    for (auto expr : mutated_exprs)
+      ite->body().push_back(expr);
+    ite->elseBody().clear();
+    for (auto expr : mutated_else_exprs)
+      ite->elseBody().push_back(expr);
+  }
+
+  active_scope = prev_scope;
+
+  if (is_mutated) {
+    auto new_ite = new IfThenElse(
+        ite->cond(), mutated_exprs, mutated_else_exprs, ite->parentScope());
+    return new_ite;
+  }
+
+  return ite;
+}
+
+Statement* IndexLowering::mutate(ForLoop* fl) {
+  Expr* prev_scope = active_scope;
+  active_scope = fl;
+  std::vector<Expr*> mutated_exprs;
+  bool is_mutated = false;
+  for (auto expr : fl->body().exprs()) {
+    Statement* mutated_stmt = mutate(expr);
+    Expr* mutated_expr = ir_utils::asExpr(mutated_stmt);
+    mutated_exprs.push_back(mutated_expr);
+    is_mutated = is_mutated | (mutated_expr != expr);
+  }
+
+  active_scope = prev_scope;
+  if (is_mutated) {
+    auto newFL = new ForLoop(
+        fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope());
+    return newFL;
+  }
+
+  return fl;
+}
+
+Statement* IndexLowering::mutate(UnaryOp* uop) {
+  if (!ir_utils::isTVOp(uop))
+    return OptOutMutator::mutate(uop);
+
+  TensorIndex* out = Index::getConsumerIndex(
+      ir_utils::asTV(uop->out()), scope_utils::getLoops(active_scope));
+  Val* in = uop->in();
+  if (ir_utils::isTV(in))
+    in = Index::getProducerIndex(
+        ir_utils::asTV(in),
+        ir_utils::asTV(uop->out()),
+        scope_utils::getLoops(active_scope));
+  Expr* new_op = new UnaryOp(uop->getUnaryOpType(), out, in);
+
+  return new_op;
+}
+
+Statement* IndexLowering::mutate(BinaryOp* bop) {
+  if (!ir_utils::isTVOp(bop))
+    return OptOutMutator::mutate(bop);
+
+  TensorIndex* out = Index::getConsumerIndex(
+      ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope));
+
+  Val* lhs = bop->lhs();
+  Val* rhs = bop->rhs();
+
+  if (ir_utils::isTV(lhs))
+    lhs = Index::getProducerIndex(
+        ir_utils::asTV(lhs),
+        ir_utils::asTV(bop->out()),
+        scope_utils::getLoops(active_scope));
+
+  if (ir_utils::isTV(rhs))
+    rhs = Index::getProducerIndex(
+        ir_utils::asTV(rhs),
+        ir_utils::asTV(bop->out()),
+        scope_utils::getLoops(active_scope));
+
+  Expr* new_op = new BinaryOp(bop->getBinaryOpType(), out, lhs, rhs);
+
+  return new_op;
+}
+
+Statement* IndexLowering::mutate(TernaryOp* top) {
+  if (!ir_utils::isTVOp(top))
+    return OptOutMutator::mutate(top);
+
+  TensorIndex* out = Index::getConsumerIndex(
+      ir_utils::asTV(top->out()), scope_utils::getLoops(active_scope));
+  Val* in1 = top->in1();
+  Val* in2 = top->in2();
+  Val* in3 = top->in3();
+
+  if (ir_utils::isTV(in1))
+    in1 = Index::getProducerIndex(
+        ir_utils::asTV(in1),
+        ir_utils::asTV(top->out()),
+        scope_utils::getLoops(active_scope));
+
+  if (ir_utils::isTV(in2))
+    in2 = Index::getProducerIndex(
+        ir_utils::asTV(in2),
+        ir_utils::asTV(top->out()),
+        scope_utils::getLoops(active_scope));
+
+  if (ir_utils::isTV(in3))
+    in3 = Index::getProducerIndex(
+        ir_utils::asTV(in3),
+        ir_utils::asTV(top->out()),
+        scope_utils::getLoops(active_scope));
+
+  Expr* new_op = new TernaryOp(top->getTernaryOpType(), out, in1, in2, in3);
+
+  return new_op;
+}
+
+Statement* IndexLowering::mutate(ReductionOp* rop) {
+  TORCH_INTERNAL_ASSERT(
+      ir_utils::isTVOp(rop),
+      "Cannot have a reduction operation on something other than a tensor view.");
+  auto loops = scope_utils::getLoops(active_scope);
+
+  bool is_private_reduce =
+      std::none_of(loops.begin(), loops.end(), [](ForLoop* fl) {
+        return fl->iter_domain()->isThread() &&
+            fl->iter_domain()->isReduction();
+      });
+
+  TensorIndex* out = Index::getConsumerIndex(ir_utils::asTV(rop->out()), loops);
+
+  Val* in = rop->in();
+  if (ir_utils::isTV(in))
+    in = Index::getProducerIndex(
+        ir_utils::asTV(in),
+        ir_utils::asTV(rop->out()),
+        scope_utils::getLoops(active_scope));
+
+  if (!is_private_reduce)
+    return new ReductionOp(rop->getReductionOpType(), rop->init(), out, in);
+
+  Expr* new_op = new BinaryOp(rop->getReductionOpType(), out, out, in);
+
+  return new_op;
+}
+
+Statement* IndexLowering::mutate(BroadcastOp* bop) {
+  if (!ir_utils::isTVOp(bop))
+    return OptOutMutator::mutate(bop);
+
+  TensorIndex* out = Index::getConsumerIndex(
+      ir_utils::asTV(bop->out()), scope_utils::getLoops(active_scope));
+  Val* in = bop->in();
+  if (ir_utils::isTV(in))
+    in = Index::getProducerIndex(
+        ir_utils::asTV(in),
+        ir_utils::asTV(bop->out()),
+        scope_utils::getLoops(active_scope));
+  Expr* new_op = new BroadcastOp(out, in);
+
+  return new_op;
+}
+
+void IndexLowering::generate(const std::vector<Expr*>& exprs) {
+  // Run through loop nests and further lower the expressions
+  for (auto* expr : exprs) {
+    Statement* mutated_stmt = mutate(expr);
+    TORCH_INTERNAL_ASSERT(
+        mutated_stmt->isExpr(),
+        "Tried to generate a kernel but hit a non expression during lowering: ",
+        mutated_stmt);
+    lowered_exprs.push_back(static_cast<Expr*>(mutated_stmt));
+  }
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h
new file mode 100644
index 0000000..564cc49
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_index.h
@@ -0,0 +1,52 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+class TORCH_CUDA_API IndexLowering : public OptOutMutator {
+ private:
+  std::vector<Expr*> lowered_exprs;
+  Expr* active_scope = nullptr;
+
+  // Wrapper around pushBack from lower_utils; if active_scope is null we want
+  // the expression to go straight to lowered_exprs
+  void pushBack(Expr*);
+
+  // Custom dispatch for Expr, want to find out if it's a TV op
+  Statement* mutate(Expr*) final;
+
+  // Open the for loop.
+  Statement* mutate(ForLoop*) final;
+
+  // Open the for loop.
+  Statement* mutate(IfThenElse*) final;
+
+  // Remake operations with TensorIndex
+  Statement* mutate(UnaryOp*) final;
+  Statement* mutate(BinaryOp*) final;
+  Statement* mutate(TernaryOp*) final;
+  Statement* mutate(ReductionOp*) final;
+  Statement* mutate(BroadcastOp*) final;
+  void generate(const std::vector<Expr*>& exprs);
+
+ public:
+  static std::vector<Expr*> getIndexedExprs(
+      Fusion* fusion,
+      std::vector<Expr*> incoming_exprs) {
+    FusionGuard fg(fusion);
+    IndexLowering il;
+    il.generate(incoming_exprs);
+    return il.lowered_exprs;
+  }
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
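
For clarity, the pass pipeline that replaces the old monolithic GPULower mutator, written out as a standalone sketch of what GPULower::getLoweredExprs now does (see lower2device.cpp above):

std::vector<Expr*> lowerSketch(Fusion* fusion) {
  FusionGuard fg(fusion);
  PrepareForLowering(fusion);                      // validation + minor IR fixes
  auto preds = ThreadPredicates::compute(fusion);  // per-tensor thread predicates
  auto loop_nests = LoopNestGenerator::getLoopNest(
      fusion, fusion->exprs(true, false, true), preds);
  auto unrolled = UnrollPass::runPass(fusion, loop_nests, preds);
  return IndexLowering::getIndexedExprs(fusion, unrolled);
}
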
diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp
index acd48b1..93f2613 100644
--- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp
@@ -2,304 +2,89 @@
 #include <torch/csrc/jit/codegen/cuda/arith.h>
 #include <torch/csrc/jit/codegen/cuda/lower_utils.h>
 
-#include <torch/csrc/jit/codegen/cuda/index_compute.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
 #include <torch/csrc/jit/codegen/cuda/transform_replay.h>
 
 namespace torch {
 namespace jit {
 namespace fuser {
 
-bool operator==(
-    const std::pair<IterDomain*, TensorView*>& p1,
-    const std::pair<IterDomain*, TensorView*>& p2) {
-  return p1.first->sameAs(p2.first) && p1.second == p2.second;
-}
-
-// all the way in the loop nest, grab predicate
-/*
-for( i : ceil(I/4) ) {
-  for( j : ceil(J/128) ) {
-
-    if( i * 4 + 3 < I && j * 128 + 127 < J ){
-      for( k : 4)
-        for( l : 128 )
-          T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
-    } else {
-      for( k : 4 )
-        for( l : 128 )
-          if( i * 4 + k < I && j * 128 + l < J)
-             T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
-    }
-
-  }
-}
-*/
-
-// Custom dispatch for Expr, want to find out of it's a TV op
-void UnrollPass::handle(Expr* expr) {
-  OptOutDispatch::handle(expr);
-}
-
-namespace {
-Bool* getPredicate(TensorView* tv, std::vector<Val*> inds_) {
-  TORCH_INTERNAL_ASSERT(
-      inds_.size() == tv->nDims() ||
-      inds_.size() == tv->domain()->noReductions().size());
-
-  std::vector<Val*> inds;
-  if (inds_.size() < tv->nDims()) {
-    size_t i_ = 0;
-    for (size_t i = 0; i < tv->nDims() && i_ < inds_.size(); i++) {
-      if (tv->axis(i)->isReduction())
-        inds.push_back(new Int(0));
-      else
-        inds.push_back(inds_[i_++]);
-    }
-  } else {
-    inds = inds_;
-  }
-  if (tv->nDims() > inds.size()) {
-    for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
-      if (tv->axis(i)->isReduction())
-        inds.insert(inds.begin() + i, new Int(0));
-    }
-  }
-  std::vector<Bool*> all_preds = PredicateCompute::computePredicates(
-      new TensorIndex(tv, IndexCompute::get(tv->domain(), inds)));
-
-  std::vector<Bool*> preds;
-
-  for (Bool* pred : all_preds)
-    if (!(pred->isConst()) || !(pred->isConst() && pred->value().value()))
-      preds.push_back(pred);
-
-  if (preds.size() == 0)
-    return new Bool(true);
-
-  Val* cond = preds[0];
-
-  for (decltype(preds.size()) i{1}; i < preds.size(); i++) {
-    cond = andOp(cond, preds[i]);
-  }
-
-  TORCH_INTERNAL_ASSERT(
-      cond->getValType().value() == ValType::Scalar &&
-          cond->getDataType().value() == DataType::Bool,
-      "Error computing predicate, should be returning a Bool, but returning ",
-      cond->getDataType().value());
-
-  return static_cast<Bool*>(cond);
-}
-} // namespace
-
-// This function is one huge mess that should be refactored.
-// It handles the unrolling and predicate generation
-void UnrollPass::handle(ForLoop* fl) {
-  // Setup for loop scoping
-  for_loops.push_back(fl);
-  bool prev_unroll = within_unroll;
-  within_unroll = ir_utils::isUnrolledFor(fl) || within_unroll;
-
-  for (auto expr : fl->body().exprs()) {
-    OptOutDispatch::handle(expr);
-  }
-
-  TensorView* out = nullptr;
-  bool has_global = false;
-  for (Expr* expr : fl->body().exprs())
-    if (ir_utils::isTVOp(expr)) {
-      // Predicate determining op for unroll
-      out = ir_utils::asTV(expr->output(0));
-      has_global = has_global || out->getMemoryType() == MemoryType::Global;
-      for (auto inp : expr->inputs())
-        if (ir_utils::isTV(inp))
-          has_global = has_global ||
-              ir_utils::asTV(inp)->getMemoryType() == MemoryType::Global;
-    }
-
-  bool has_TV_op = out != nullptr;
-
-  if (within_unroll && has_TV_op && has_global) {
-    // Setup unrolled loop information:
-
-    // Indices used to detect when we can unroll a loop safely
-    // For loops outside the unroll, it's just he index, for loops inside
-    // the unroll, if it's a thread it's the thread index, otherwise it's
-    // the size-1
-    std::vector<Val*> unroll_pred_inds;
-    auto it = for_loops.begin();
-    while (it != for_loops.end()) {
-      if (ir_utils::isUnrolledFor(*it))
-        break;
-      unroll_pred_inds.push_back((*it)->index());
-      it++;
-    }
-
-    TORCH_INTERNAL_ASSERT(
-        it != for_loops.end(),
-        "Error unrolling loops, expected an unrolled loop but wasn't found.");
-
-    // This is the outer most loop that needs to be unrolled
-    ForLoop* first_unroll = *it;
-
-    // Indicies inside the unroll
-    while (it != for_loops.end()) {
-      IterDomain* id = (*it)->iter_domain();
-      if (id->isThread())
-        unroll_pred_inds.push_back((*it)->index());
-      else
-        unroll_pred_inds.push_back(sub(id->extent(), new Int(1)));
-      it++;
-    }
-
-    // Make predicates for the unrolling, and the epilogue
-    Bool* unroll_predicate = getPredicate(out, unroll_pred_inds);
-    // Make the IfThenElse controlling the unrolling
-    IfThenElse* unroll_ite =
-        new IfThenElse(unroll_predicate, {}, {}, first_unroll->parentScope());
-
-    // Get the loop nest for the unrolled path
-    ForLoop* unrolled_loop =
-        scope_utils::cloneLoopNest(first_unroll, unroll_ite);
-    unroll_ite->body().push_back(unrolled_loop);
-
-    // Loop nest for inlined path
-    ForLoop* inlined_loop =
-        scope_utils::cloneLoopNest(first_unroll, unroll_ite);
-    unroll_ite->elseBody().push_back(inlined_loop);
-
-    // Inner most inlined loop
-    Expr* inner_most_inlined_loop =
-        scope_utils::firstInnerMostScope(inlined_loop);
-
-    loop_replacement_map.insert({first_unroll, unroll_ite});
-
-    for (auto expr : fl->body().exprs()) {
-      if (!ir_utils::isTVOp(expr))
-        continue;
-
-      // Setup the expressions that need predicates around them.
-      Bool* inline_predicate = getPredicate(out, ir_utils::indices(for_loops));
-
-      IfThenElse* inline_ite =
-          new IfThenElse(inline_predicate, {expr}, {}, inner_most_inlined_loop);
-      std::unordered_map<Expr*, Expr*> inline_replacement_map;
-      inline_replacement_map.emplace(std::pair<Expr*, Expr*>(expr, inline_ite));
-      scope_utils::replaceExprsInScope(
-          inner_most_inlined_loop, inline_replacement_map);
-
-    } // for expr
-  } else { //  if(!within_unroll)
-    // modify in place, so grab a copy of exprs first.
-    std::vector<Expr*> exprs(
-        fl->body().exprs().begin(), fl->body().exprs().end());
-
-    for (auto expr : exprs) {
-      if (!ir_utils::isTVOp(expr))
-        continue;
-
-      TensorView* out = ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]);
-
-      Bool* pred = getPredicate(out, ir_utils::indices(for_loops));
-
-      // If we need a predicate, put expr inside an if then else
-      if (!(pred->isConst()) || !(pred->isConst() && pred->value().value())) {
-        IfThenElse* inline_ite =
-            new IfThenElse(pred, {expr}, {}, for_loops.back());
-        for_loops.back()->body().insert_before(expr, inline_ite);
-        for_loops.back()->body().erase(expr);
-      }
-    }
-  } // else (if(!within_unroll))
-
-  for_loops.pop_back();
-  within_unroll = prev_unroll;
-}
-
-// Generate the loop nest structure and place it in lowered_exprs
-void UnrollPass::computeMap() {
-  FusionGuard fg(fusion_);
-
-  // Run through loop nests and further lower the expressions
-  for (auto* expr : incoming_exprs_) {
-    OptOutDispatch::handle(expr);
-  }
-}
-
-std::vector<Expr*> UnrollPass::runPass(
-    Fusion* fusion,
-    const std::vector<Expr*>& exprs) {
-  FusionGuard fg(fusion);
-  UnrollPass up(fusion, exprs);
-  up.computeMap();
-  std::vector<Expr*> mutated_exprs;
-  for (Expr* expr : exprs) {
-    if (up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()) {
-      mutated_exprs.push_back(up.loop_replacement_map[expr]);
-    } else {
-      if (ir_utils::isScope(expr))
-        scope_utils::replaceExprsInScope(expr, up.loop_replacement_map);
-      mutated_exprs.push_back(expr);
-    }
-  }
-  return mutated_exprs;
-}
-
-void LoopNestGenerator::pushAlloc(TensorView* tv) {
+// Create, place, and return the allocation for tv
+Expr* LoopNestGenerator::pushAlloc(TensorView* tv) {
   TORCH_INTERNAL_ASSERT(
       !(FusionGuard::getCurFusion()->hasInput(tv) ||
         FusionGuard::getCurFusion()->hasOutput(tv)),
       "Tried to allocate an input or output tensor.");
 
-  // Compute at axis can be == tv->nDims() meaning it's inline
-  decltype(tv->nDims()) alloc_pos = 0;
-  // Do we need to close to root and alloc there?
-  bool reset = true;
-  while (alloc_pos <= tv->nDims()) {
+  // First figure out which loop nest this allocation needs to be placed in
+  // Do we need to place the allocation at the root?
+  size_t alloc_pos = 0;
+  // If there's no computeAt, then we want to be allocated at the root
+  while (alloc_pos <= tv->nDims() && tv->hasComputeAt()) {
+    // If we have a computeAt and we reached the computeAt position, that's where it goes
     if (tv->hasComputeAt() && alloc_pos == tv->getThisComputeAtAxis()) {
-      reset = false;
       break;
     }
+
+    // If we found an unroll, we want to place the allocation outside the unroll
     if (alloc_pos < tv->nDims() &&
         tv->getComputeAtAxis(alloc_pos).first->parallel_method() ==
             ParallelType::Unroll) {
-      reset = false;
       break;
     }
     alloc_pos++;
   }
-  alloc_pos = reset ? 0 : alloc_pos;
 
+  // Grab the dimensions the allocation will be based on
   std::vector<Val*> alloc_dims;
   for (auto i = alloc_pos; i < tv->nDims(); i++) {
     IterDomain* dim = tv->getComputeAtAxis(i).first;
-    if (dim->isThreadDim() || dim->isReduction())
+    if (
+        // If shared memory, don't use any IDs bound to a grid dimension
+        (tv->memory_type_ == MemoryType::Shared && dim->isBlockDim()) ||
+        // If local memory, don't use any IDs bound to a grid or block dimension
+        (tv->memory_type_ == MemoryType::Local && dim->isThread()) ||
+        // If we're reducing this dimension, don't use it in the allocation
+        // computation
+        dim->isReduction() ||
+        // If this is a broadcast dimension, don't use it in the allocation
+        // computation
+        dim->isBroadcast())
       continue;
     alloc_dims.push_back(dim->extent());
   }
 
-  Val* size;
+  // Multiply all the dimensions we're going to use for the allocation together
+  // to get the total size
+  Val* size = nullptr;
   if (alloc_dims.size() == 0) {
     size = new Int(1);
   } else {
     size = alloc_dims[0];
-    for (decltype(alloc_dims.size()) i{1}; i < alloc_dims.size(); i++) {
+    for (size_t i = 1; i < alloc_dims.size(); i++) {
       size = mul(size, alloc_dims[i]);
     }
   }
+
+  // Create the allocation node
   Allocate* alloc = new Allocate(tv, size);
 
+  // Place the allocation
   if (alloc_pos == 0) {
+    // If we allocate at the root, insert at the beginning of the lowered
+    // expressions
     lowered_exprs.insert(lowered_exprs.begin(), alloc);
   } else if (alloc_pos == for_loops.size()) {
-    // inline
-    scope_utils::pushBack(for_loops[alloc_pos - 1], alloc);
+    // If we allocate inline, push to the back of the last for loop
+    scope_utils::pushBack(for_loops[for_loops.size() - 1], alloc);
   } else {
+    // Otherwise we allocate in some loop nest that is neither inline nor at the
+    // root, so insert right before the loop we're just outside of
     scope_utils::insertBefore(
         for_loops[alloc_pos - 1], for_loops[alloc_pos], alloc);
   }
+
+  return alloc;
 }
 
 void LoopNestGenerator::openFor(std::pair<IterDomain*, TensorView*> id_pair) {
@@ -329,29 +114,39 @@
     scope_utils::pushBack(for_loops.back(), expr);
 }
 
-// Update for loop structure based on this TensorView
-void LoopNestGenerator::initReduction(TensorView* tv, Val* init_val) {
-  // This logic was taken from allocation placement, as we want to initialize
-  // the reduction buffers right after they're allocated. Compute at axis can be
-  // == tv->nDims() meaning it's inline
-  decltype(tv->nDims()) alloc_pos = 0;
-  // Do we need to close to root and alloc there?
-  bool reset = true;
-  while (alloc_pos <= tv->nDims()) {
+// Update for loop structure based on this TensorView, if there's an allocation
+// stmt, send it in so we can make sure that we insert this initialization after
+// it
+void LoopNestGenerator::initReduction(
+    TensorView* tv,
+    Val* init_val,
+    Expr* alloc_expr) {
+  // This logic was taken from pushAlloc, as the initialization loop nest will
+  // go at the same place.
+
+  // First figure out which loop nest this allocation needs to be placed in
+  // Do we need to place the allocation at the root?
+  size_t alloc_pos = 0;
+  // If there's no computeAt, then we want to be allocated at the root
+  while (alloc_pos <= tv->nDims() && tv->hasComputeAt()) {
+    // If we have a computeAt and we reached the computeAt position, that's where it goes
     if (tv->hasComputeAt() && alloc_pos == tv->getThisComputeAtAxis()) {
-      reset = false;
       break;
     }
+
+    // If we found an unroll, we want to place the allocation outside the unroll
     if (alloc_pos < tv->nDims() &&
         tv->getComputeAtAxis(alloc_pos).first->parallel_method() ==
             ParallelType::Unroll) {
-      reset = false;
       break;
     }
     alloc_pos++;
   }
-  alloc_pos = reset ? 0 : alloc_pos;
 
+  // Grab the IDs that will be involved in the initialization, ignore reduction
+  // dimensions. Everything else will be iterated over to cover the entire
+  // buffer. Index compute will ignore [block, grid]Dims depending on buffer
+  // memory location
   std::vector<IterDomain*> ids;
   for (auto i = alloc_pos; i < tv->nDims(); i++) {
     IterDomain* dim = tv->getComputeAtAxis(i).first;
@@ -360,45 +155,91 @@
     ids.push_back(dim);
   }
 
+  // Unsafe clone, as we want an exact replica of tv so we can create a UnaryOp
+  // to set the buffer to the init_val.
   auto clone = tv->unsafeClone();
+  if (thread_predicates_.find(tv) != thread_predicates_.end()) {
+    thread_predicates_[clone] = thread_predicates_[tv];
+  }
+  // The initialization stmt that will be located inside the loop nest (if there
+  // is one)
   auto init_stmt = new UnaryOp(UnaryOpType::Set, clone, init_val);
 
-  Expr* init = nullptr;
+  // Init a pointer that will become the entirety of the initialization
+  Expr* init_loop_nest = nullptr;
+
+  // The for loop that we will place the initialization within (alloc_pos - 1),
+  // if one exists. Once we're done, this inner_fl will be the inner most loop
+  // containing the init_stmt
   ForLoop* inner_fl = nullptr;
   if (alloc_pos >= 1)
     inner_fl = for_loops[alloc_pos - 1];
+
+  // Work through the iter domains that we need to initialize on, outside to
+  // inside, to construct the loop nest for the initialization.
   for (auto id : ids) {
     ForLoop* new_fl;
+
     if (id->isThread()) {
+      // If based on a thread, make sure we get the named Int right
       std::stringstream ss;
       ss << id->parallel_method();
       new_fl = new ForLoop(
           new NamedScalar(ss.str(), DataType::Int), id, {}, inner_fl);
     } else {
+      // Otherwise it's just a new Int
       new_fl = new ForLoop(new Int(), id, {}, inner_fl);
     }
 
-    if (init == nullptr) {
-      init = new_fl;
-      inner_fl = new_fl;
+    if (init_loop_nest == nullptr) {
+      // If this is our first generated loop, then it will be our outer most
+      // loop nest
+      init_loop_nest = new_fl;
     } else {
+      // Otherwise place it inside the last generated loop
       inner_fl->body().push_back(new_fl);
-      inner_fl = new_fl;
     }
+    // Advance the inner most for loop pointer
+    inner_fl = new_fl;
   }
-  if (init == nullptr) {
-    init = init_stmt;
+
+  if (init_loop_nest == nullptr) {
+    // If no loops were generated, then our init_stmt is all we need
+    init_loop_nest = init_stmt;
   } else {
+    // If there were for loops generated, place the init_stmt in the inner most
+    // for loop.
     inner_fl->body().push_back(init_stmt);
   }
 
+  // Place the allocation
   if (alloc_pos == 0) {
-    lowered_exprs.insert(lowered_exprs.begin(), init);
+    // If we allocate at the root, look for the provided allocation if it
+    // exists, and place the initialization after it.
+    if (alloc_expr != nullptr) {
+      bool found = false;
+      for (auto it = lowered_exprs.begin(); it != lowered_exprs.end(); it++) {
+        if ((*it) == alloc_expr) {
+          lowered_exprs.insert(it + 1, init_loop_nest);
+          found = true;
+          break;
+        }
+      }
+      TORCH_INTERNAL_ASSERT(
+          found,
+          "Could not figure out where to initialize the buffer for ",
+          tv);
+    } else {
+      lowered_exprs.insert(lowered_exprs.begin(), init_loop_nest);
+    }
   } else if (alloc_pos == for_loops.size()) {
-    scope_utils::pushBack(for_loops[alloc_pos - 1], init);
+    // If we allocate inline, push to the back of the last for loop
+    scope_utils::pushBack(for_loops[for_loops.size() - 1], init_loop_nest);
   } else {
+    // Otherwise we allocate in some loop nest that is neither inline nor at
+    // the root, so insert right before the loop we're just outside of
     scope_utils::insertBefore(
-        for_loops[alloc_pos - 1], for_loops[alloc_pos], init);
+        for_loops[alloc_pos - 1], for_loops[alloc_pos], init_loop_nest);
   }
 }
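The loop-nest construction above follows a simple outside-in pattern: the first loop created becomes the outer most, every later loop is pushed into the previous one, and the init statement lands in the inner most loop (or stands alone if no loops are needed). A compact standalone sketch of that pattern with a toy Node type (illustrative only, not the real ForLoop/UnaryOp IR):

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Toy stand-in for ForLoop / the init statement: a label plus nested children.
    struct Node {
      std::string label;
      std::vector<std::unique_ptr<Node>> body;
    };

    // Build loops outside-in, remember the inner most one, then push the init
    // statement into it (or return the statement directly if no loops exist).
    std::unique_ptr<Node> buildInitNest(const std::vector<std::string>& ids,
                                        const std::string& init_stmt) {
      std::unique_ptr<Node> outer;
      Node* inner = nullptr;
      for (const auto& id : ids) {
        auto fl = std::make_unique<Node>();
        fl->label = "for " + id;
        Node* raw = fl.get();
        if (outer == nullptr)
          outer = std::move(fl);                // first loop is the outer most
        else
          inner->body.push_back(std::move(fl)); // otherwise nest in the last one
        inner = raw;
      }
      auto stmt = std::make_unique<Node>();
      stmt->label = init_stmt;
      if (outer == nullptr)
        return stmt;
      inner->body.push_back(std::move(stmt));
      return outer;
    }

    int main() {
      auto nest = buildInitNest({"i0", "i1"}, "T1[...] = 0");
      std::cout << nest->label << " { " << nest->body[0]->label << " { "
                << nest->body[0]->body[0]->label << " } }\n";
    }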
 
@@ -437,15 +278,17 @@
     openFor(out->getComputeAtAxis((int)compute_at_scope.size()));
   }
 
+  Expr* alloc_stmt = nullptr;
   //  3) Allocate the output.
   if (!FusionGuard::getCurFusion()->hasInput(out) &&
       !FusionGuard::getCurFusion()->hasOutput(out))
-    pushAlloc(out);
+    alloc_stmt = pushAlloc(out);
 
   //  4) If this is a reduction, initialize the output (open for loops to inner
-  //  most, predicate, initialize, F predicate, close to computeAt)
+  //  most, predicate, initialize, place right after the allocation if one
+  //  exists, close to computeAt)
   if (out->hasReduction())
-    initReduction(out, static_cast<ReductionOp*>(expr)->init());
+    initReduction(out, static_cast<ReductionOp*>(expr)->init(), alloc_stmt);
 
   //  5) Open to inner most loop
   for (decltype(out->nDims()) i = for_loops.size(); i < out->nDims(); i++)
@@ -472,4 +315,4 @@
 
 } // namespace fuser
 } // namespace jit
-} // namespace torch
\ No newline at end of file
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.h b/torch/csrc/jit/codegen/cuda/lower_loops.h
index 008216e..f6ecb57 100644
--- a/torch/csrc/jit/codegen/cuda/lower_loops.h
+++ b/torch/csrc/jit/codegen/cuda/lower_loops.h
@@ -8,77 +8,81 @@
 namespace jit {
 namespace fuser {
 
-struct UnrollPass : public OptOutDispatch {
+/*
+ * Loop nest generator pass will get IR that looks something like:
+ * T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ * and will generate the loop nest structure for these exprs like:
+ *
+ * for( i : I0o{ceil(I0/4)} ) {
+ *   for( j : I1o{ceil(I1/128)} ) {
+ *     for( k : I0i{4} )
+ *       for( l : I1i{128} )
+ *         T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ *
+ * It does not generate predicates, but it will generate allocations, and loop
+ * nests to initialize reduction buffers.
+ *
+ */
+class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {
  private:
-  std::unordered_map<Expr*, Expr*> loop_replacement_map;
-  Fusion* fusion_;
-  const std::vector<Expr*>& incoming_exprs_;
-
-  // Keep all for loops conveniently to make unrolling easier
-  std::vector<ForLoop*> for_loops;
-
-  // keep track if we're within an unrolled loop
-  bool within_unroll = false;
-
-  // Custom dispatch for Expr, want to find out of it's a TV op
-  void handle(Expr*) final;
-
-  // Open the for loop.
-  void handle(ForLoop*) final;
-
-  UnrollPass(Fusion* _fusion, const std::vector<Expr*>& _incoming_exprs)
-      : fusion_(_fusion), incoming_exprs_(_incoming_exprs) {}
-
-  void computeMap();
-
- public:
-  static std::vector<Expr*> runPass(
-      Fusion* fusion,
-      const std::vector<Expr*>& exprs);
-};
-
-struct TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {
- private:
+  // Lowered exprs to return
   std::vector<Expr*> lowered_exprs;
+  // Fusion pointer for convenience
   Fusion* fusion_;
 
-  // Keep all for loops conveniently to make unrolling easier
+  // Keep all for loops conveniently to make unrolling easier, basically just a
+  // stack of the active for_loops
   std::vector<ForLoop*> for_loops;
-  // computeAT scope is determined by the iterat domain, and the tensor view it
-  // belongs to (the final TensorView when following the computeAt path)
+
+  // Track the active computeAt scope, and what view we're "computeAt-ing" into
   std::vector<std::pair<IterDomain*, TensorView*>> compute_at_scope;
 
-  // Get Register allocation statement for tensorview
-  void pushAlloc(TensorView*);
+  // Predicates from ThreadPredicates that we will extend to reduction buffer
+  // initialization
+  std::unordered_map<const TensorView*, Bool*>& thread_predicates_;
 
-  // Open a new inner most for loop
+  // Create, place, and return the allocation for tv
+  Expr* pushAlloc(TensorView*);
+
+  // Open a new inner most for loop, track which TV it was constructed from
+  // according to the computeAt chain.
   void openFor(std::pair<IterDomain*, TensorView*>);
+
+  // Close the inner most for loop
   void popFor();
 
   // Wrap pushBack in lower_utils; if active_scope is null we want it to go
   // straight to lowered_exprs
   void pushBack(Expr*);
 
-  // Update for loop structure based on this TensorView
+  // Update for loop structure based on this TensorView, see implementation for
+  // more details
   void updateLoopNest(TensorView*);
 
-  // Update for loop structure based on this TensorView
-  void initReduction(TensorView* tv, Val* init_val);
+  // Initialize a buffer to init_val. If this buffer is in smem or registers,
+  // pass in its allocation statement so we can make sure this initialization
+  // comes after the allocation.
+  void initReduction(TensorView* tv, Val* init_val, Expr* alloc_expr = nullptr);
 
-  // Check if a TV op, generate for loop nest around it
+  // Check if expr is a TV op and handle accordingly.
   void handle(Expr*) final;
 
-  // Generate the loop nest structure and place it in lowered_exprs
+  // Run the pass and accumulate output in lowered_exprs
   void generate(const std::vector<Expr*>& exprs);
 
-  LoopNestGenerator(Fusion* _fusion) : fusion_(_fusion) {}
+  LoopNestGenerator(
+      Fusion* _fusion,
+      std::unordered_map<const TensorView*, Bool*>& _thread_predicates)
+      : fusion_(_fusion), thread_predicates_(_thread_predicates) {}
 
  public:
   static std::vector<Expr*> getLoopNest(
       Fusion* fusion,
-      std::vector<Expr*> exprs) {
+      std::vector<Expr*> exprs,
+      std::unordered_map<const TensorView*, Bool*>& thread_predicates) {
     FusionGuard fg(fusion);
-    LoopNestGenerator lng(fusion);
+    LoopNestGenerator lng(fusion, thread_predicates);
     lng.generate(exprs);
     return lng.lowered_exprs;
   }
diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
new file mode 100644
index 0000000..a39a825
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
@@ -0,0 +1,187 @@
+#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+
+#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+const static std::unordered_map<ParallelType, int> pt_to_offset{
+    {ParallelType::BIDx, 0},
+    {ParallelType::BIDy, 1},
+    {ParallelType::BIDz, 2},
+    {ParallelType::TIDx, 3},
+    {ParallelType::TIDy, 4},
+    {ParallelType::TIDz, 5}};
+
+const static std::unordered_map<int, ParallelType> offset_to_pt{
+    {0, ParallelType::BIDx},
+    {1, ParallelType::BIDy},
+    {2, ParallelType::BIDz},
+    {3, ParallelType::TIDx},
+    {4, ParallelType::TIDy},
+    {5, ParallelType::TIDz}};
+
+static constexpr int num_p_type = 6;
+
+namespace {
+
+void flip_true(std::bitset<num_p_type>& bits, const ParallelType p_type) {
+  if (pt_to_offset.find(p_type) == pt_to_offset.end()) {
+    TORCH_INTERNAL_ASSERT(false, "Could not recognize parallel type.");
+  }
+  bits[pt_to_offset.at(p_type)] = true;
+}
+
+Val* threadPredicate(int i) {
+  if (offset_to_pt.find(i) == offset_to_pt.end()) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Invalid int for predicate computation, should be from [0-5], but received, ",
+        i,
+        ".");
+  }
+  return eq(
+      new NamedScalar(stringifyThread(offset_to_pt.at(i)), DataType::Int),
+      new Int(0));
+}
+
+Bool* getThreadPredicate(std::bitset<num_p_type> bits) {
+  if (bits.none())
+    return new Bool(true);
+
+  Val* pred = nullptr;
+
+  for (int i = 0; i < num_p_type; i++) {
+    if (bits[i]) {
+      if (pred == nullptr) {
+        pred = threadPredicate(i);
+      } else {
+        pred = andOp(pred, threadPredicate(i));
+      }
+    }
+  }
+
+  // Should never be hit.
+  TORCH_INTERNAL_ASSERT(pred != nullptr);
+
+  TORCH_INTERNAL_ASSERT(
+      pred->getDataType().value() == DataType::Bool,
+      "Tried to return a predicate that is not a bool val.");
+
+  return pred->as<Bool>();
+}
+
+} // namespace
+
+std::bitset<num_p_type> ThreadPredicates::getThreadPredicates(
+    const TensorView* tv) {
+  TORCH_INTERNAL_ASSERT(
+      thread_predicates.find(tv) != thread_predicates.end(),
+      "Invalid predicate initialization, couldn't find ",
+      tv);
+  return thread_predicates[tv];
+}
+
+// Update the reduction_deps bitset based on provided Expr
+void ThreadPredicates::updateBitSet(Expr* expr) {
+  // Which predicates were set for the inputs
+  std::bitset<num_p_type> input_preds;
+
+  // Which dims are reductions in inputs
+  std::bitset<num_p_type> input_reductions;
+
+  // Which dims are bcast in inputs
+  std::bitset<num_p_type> input_bcasts;
+
+  // Run through inputs and update bitsets
+  for (const auto* inp : expr->inputs()) {
+    if (!ir_utils::isTV(inp))
+      continue;
+
+    auto tv_inp = ir_utils::asConstTV(inp);
+    TORCH_INTERNAL_ASSERT(
+        thread_predicates.find(tv_inp) != thread_predicates.end(),
+        "Thread predicate map was not initialized, couldn't find ",
+        inp);
+
+    input_preds |= thread_predicates[tv_inp];
+
+    std::bitset<num_p_type> id_reductions;
+    std::bitset<num_p_type> id_bcasts;
+    std::bitset<num_p_type> id_ptypes;
+
+    for (auto id : tv_inp->domain()->domain()) {
+      if (id->isThread()) {
+        flip_true(id_ptypes, id->parallel_method());
+        if (id->isReduction())
+          flip_true(id_reductions, id->parallel_method());
+        if (id->isBroadcast())
+          flip_true(id_bcasts, id->parallel_method());
+      }
+    }
+
+    // Validate the combination of ptypes, reductions, bcasts
+    for (size_t i = 0; i < num_p_type; i++) {
+      if (input_reductions[i]) {
+        if (id_ptypes[i]) {
+          TORCH_INTERNAL_ASSERT(
+              id_reductions[i],
+              "Mismatched parallelized reductions found on inputs of expr: ",
+              expr);
+          TORCH_CHECK(
+              !id_bcasts[i],
+              "Invalid broadcast and reduction combination, tried to parallelize both with the same thread dim: ",
+              inp);
+        }
+      }
+    }
+
+    // Accumulate
+    input_reductions |= id_reductions;
+    input_bcasts |= id_bcasts;
+  }
+
+  // Update map for this tv, before accumulating to other inputs
+  // Add any reductions this id has to any input predicates
+  auto output_preds = input_preds | input_reductions;
+
+  // Figure out which dims bcast wants to reset
+  auto bcast_reset_map = output_preds & input_bcasts;
+
+  // Flip it to make a bit mask
+  bcast_reset_map = ~bcast_reset_map;
+
+  // Get rid of any reductions which are bcasted
+  output_preds &= bcast_reset_map;
+
+  // Run through outputs and set bitset predicates
+  for (const auto* out : expr->outputs()) {
+    if (!ir_utils::isTV(out))
+      continue;
+    thread_predicates[ir_utils::asConstTV(out)] = output_preds;
+  }
+}
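The bit manipulation above can be read as one propagation rule per expression: output bits = (input bits OR parallelized reduction dims), minus any dims reset by a matching parallelized broadcast. A standalone restatement with std::bitset, using the same six-slot layout as pt_to_offset (illustrative only):

    #include <bitset>
    #include <iostream>

    constexpr int kNumPTypes = 6; // <BIDx, BIDy, BIDz, TIDx, TIDy, TIDz>

    // One step of the propagation performed by updateBitSet.
    std::bitset<kNumPTypes> propagate(std::bitset<kNumPTypes> input_preds,
                                      std::bitset<kNumPTypes> input_reductions,
                                      std::bitset<kNumPTypes> input_bcasts) {
      auto output_preds = input_preds | input_reductions;   // gain reduction dims
      auto bcast_reset_mask = ~(output_preds & input_bcasts); // broadcast resets
      return output_preds & bcast_reset_mask;
    }

    int main() {
      std::bitset<kNumPTypes> preds;      // nothing predicated yet
      std::bitset<kNumPTypes> reductions; // reduction parallelized on TIDx
      reductions.set(3);
      std::bitset<kNumPTypes> bcasts;     // no broadcasts
      std::cout << propagate(preds, reductions, bcasts) << "\n"; // TIDx bit set
    }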
+ThreadPredicates::ThreadPredicates(Fusion* _fusion) : fusion_(_fusion) {
+  for (auto inp : fusion_->inputs())
+    if (ir_utils::isTV(inp))
+      thread_predicates[ir_utils::asConstTV(inp)] = std::bitset<num_p_type>();
+}
+
+std::unordered_map<const TensorView*, Bool*> ThreadPredicates::compute(
+    Fusion* fusion) {
+  ThreadPredicates tp(fusion);
+  for (auto expr : fusion->exprs(true))
+    tp.updateBitSet(expr);
+  std::unordered_map<const TensorView*, Bool*> preds;
+  for (auto entry : tp.thread_predicates) {
+    preds[entry.first] = getThreadPredicate(entry.second);
+  }
+  return preds;
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
new file mode 100644
index 0000000..df09b72
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
@@ -0,0 +1,44 @@
+#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+#include <bitset>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+class TORCH_CUDA_API ThreadPredicates {
+ private:
+  Fusion* fusion_;
+
+  /*
+   * Map from tensorview to bit set representing <BIDx, BIDy, BIDz, TIDx, TIDy,
+   * TIDz> If any dependency of TV had a parallelized reduction, we will track
+   * it here. This will be used for predicate generation to prevent
+   * parallelization on that axis. This is important if we have a reduction on
+   * for example TIDx, as the reduced value is only valid on threadIdx.x == 0;
+   * therefore, if we use that value later in the kernel we need that predicate.
+   * If we follow a reduction parallelized on TIDx with a broadcast on TIDx we
+   * no longer need the predicate and can reset the bit accordingly
+   */
+  std::unordered_map<const TensorView*, std::bitset<6>> thread_predicates;
+
+  // Update the thread_predicates bitset based on provided Expr
+  void updateBitSet(Expr*);
+
+  // Safety wrapper to access thread_predicates
+  std::bitset<6> getThreadPredicates(const TensorView*);
+
+  ThreadPredicates(Fusion* _fusion);
+
+ public:
+  // Computes any thread predicates that need to be applied when computing a
+  // TensorView.
+  static std::unordered_map<const TensorView*, Bool*> compute(Fusion* fusion);
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
new file mode 100644
index 0000000..91c0d28
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
@@ -0,0 +1,244 @@
+#include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/index_compute.h>
+#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+#include <torch/csrc/jit/codegen/cuda/predicate_compute.h>
+
+#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+Bool* UnrollPass::getThreadPredicate(const TensorView* tv) {
+  TORCH_INTERNAL_ASSERT(
+      thread_predicates_.find(tv) != thread_predicates_.end(),
+      "Invalid predicate initialization, couldn't find ",
+      tv);
+  return thread_predicates_[tv];
+}
+
+// Custom dispatch for Expr, want to find out if it's a TV op
+void UnrollPass::handle(Expr* expr) {
+  OptOutDispatch::handle(expr);
+}
+
+namespace {
+Bool* getPredicate(TensorView* tv, std::vector<Val*> inds_, Bool* thread_pred) {
+  TORCH_INTERNAL_ASSERT(
+      inds_.size() == tv->nDims() ||
+      inds_.size() == tv->domain()->noReductions().size());
+
+  // Do we need to adjust for reduction axes?
+  bool reductions = inds_.size() != tv->nDims();
+
+  std::vector<Val*> inds;
+  if (reductions) {
+    for (size_t ind_i = 0, tv_i = 0; tv_i < tv->nDims();) {
+      if (tv->axis(tv_i++)->isReduction()) {
+        inds.push_back(new Int(0));
+      } else {
+        TORCH_INTERNAL_ASSERT(
+            ind_i < inds_.size(), "Ran out of indices to generate predicate.");
+        inds.push_back(inds_[ind_i++]);
+      }
+    }
+  } else {
+    inds = inds_;
+  }
+
+  if (tv->nDims() > inds.size()) {
+    for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
+      if (tv->axis(i)->isReduction())
+        inds.insert(inds.begin() + i, new Int(0));
+    }
+  }
+  std::vector<Bool*> all_preds = PredicateCompute::computePredicates(
+      new TensorIndex(tv, IndexCompute::get(tv->domain(), inds)));
+
+  all_preds.push_back(thread_pred);
+
+  std::vector<Bool*> preds;
+
+  for (Bool* pred : all_preds)
+    if (!(pred->isConst()) || !(pred->isConst() && pred->value().value()))
+      preds.push_back(pred);
+
+  if (preds.size() == 0)
+    return new Bool(true);
+
+  Val* cond = preds[0];
+
+  for (decltype(preds.size()) i{1}; i < preds.size(); i++) {
+    cond = andOp(cond, preds[i]);
+  }
+
+  TORCH_INTERNAL_ASSERT(
+      cond->getValType().value() == ValType::Scalar &&
+          cond->getDataType().value() == DataType::Bool,
+      "Error computing predicate, should be returning a Bool, but returning ",
+      cond->getDataType().value());
+
+  return static_cast<Bool*>(cond);
+}
+} // namespace
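getPredicate's final step is just filtering out constant-true predicates and AND-ing the rest, falling back to true when nothing is left. A standalone sketch of that reduction over a toy predicate type (the Pred struct is a stand-in, not the real Bool IR node):

    #include <iostream>
    #include <string>
    #include <vector>

    // Simplified predicate: either a known constant or a symbolic expression.
    struct Pred {
      bool is_const;
      bool const_value;  // only meaningful when is_const
      std::string expr;  // symbolic form, e.g. "i * 4 + k < I"
    };

    // Drop predicates that are constant true, AND the rest together, and fall
    // back to "true" if nothing remains.
    std::string combine(const std::vector<Pred>& preds) {
      std::string cond;
      for (const Pred& p : preds) {
        if (p.is_const && p.const_value)
          continue; // constant true contributes nothing
        cond = cond.empty() ? p.expr : "(" + cond + ") && (" + p.expr + ")";
      }
      return cond.empty() ? "true" : cond;
    }

    int main() {
      std::vector<Pred> preds = {{true, true, "true"},
                                 {false, false, "i * 4 + k < I"},
                                 {false, false, "threadIdx.x == 0"}};
      std::cout << combine(preds) << "\n";
    }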
+
+// This function is one huge mess that should be refactored.
+// It handles the unrolling and predicate generation
+void UnrollPass::handle(ForLoop* fl) {
+  // Setup for loop scoping
+  for_loops.push_back(fl);
+  bool prev_unroll = within_unroll;
+  within_unroll = ir_utils::isUnrolledFor(fl) || within_unroll;
+
+  for (auto expr : fl->body().exprs()) {
+    OptOutDispatch::handle(expr);
+  }
+
+  TensorView* out = nullptr;
+  bool has_global = false;
+  for (Expr* expr : fl->body().exprs())
+    if (ir_utils::isTVOp(expr)) {
+      // Predicate determining op for unroll
+      out = ir_utils::asTV(expr->output(0));
+      has_global = has_global || out->getMemoryType() == MemoryType::Global;
+      for (auto inp : expr->inputs())
+        if (ir_utils::isTV(inp))
+          has_global = has_global ||
+              ir_utils::asTV(inp)->getMemoryType() == MemoryType::Global;
+    }
+
+  bool has_TV_op = out != nullptr;
+
+  if (within_unroll && has_TV_op && has_global) {
+    // Setup unrolled loop information:
+
+    // Indices used to detect when we can unroll a loop safely
+    // For loops outside the unroll, it's just the index; for loops inside
+    // the unroll, if it's a thread it's the thread index, otherwise it's
+    // the extent minus 1
+    std::vector<Val*> unroll_pred_inds;
+    auto it = for_loops.begin();
+    while (it != for_loops.end()) {
+      if (ir_utils::isUnrolledFor(*it))
+        break;
+      unroll_pred_inds.push_back((*it)->index());
+      it++;
+    }
+
+    TORCH_INTERNAL_ASSERT(
+        it != for_loops.end(),
+        "Error unrolling loops, expected an unrolled loop but wasn't found.");
+
+    // This is the outer most loop that needs to be unrolled
+    ForLoop* first_unroll = *it;
+
+    // Indices inside the unroll
+    while (it != for_loops.end()) {
+      IterDomain* id = (*it)->iter_domain();
+      if (id->isThread())
+        unroll_pred_inds.push_back((*it)->index());
+      else
+        unroll_pred_inds.push_back(sub(id->extent(), new Int(1)));
+      it++;
+    }
+
+    // Make predicates for the unrolling, and the epilogue
+    Bool* unroll_predicate =
+        getPredicate(out, unroll_pred_inds, getThreadPredicate(out));
+    // Make the IfThenElse controlling the unrolling
+    IfThenElse* unroll_ite =
+        new IfThenElse(unroll_predicate, {}, {}, first_unroll->parentScope());
+
+    // Get the loop nest for the unrolled path
+    ForLoop* unrolled_loop =
+        scope_utils::cloneLoopNest(first_unroll, unroll_ite);
+    unroll_ite->body().push_back(unrolled_loop);
+
+    // Loop nest for inlined path
+    ForLoop* inlined_loop =
+        scope_utils::cloneLoopNest(first_unroll, unroll_ite);
+    unroll_ite->elseBody().push_back(inlined_loop);
+
+    // Inner most inlined loop
+    Expr* inner_most_inlined_loop =
+        scope_utils::firstInnerMostScope(inlined_loop);
+
+    loop_replacement_map.insert({first_unroll, unroll_ite});
+
+    for (auto expr : fl->body().exprs()) {
+      if (!ir_utils::isTVOp(expr))
+        continue;
+
+      // Setup the expressions that need predicates around them.
+      Bool* inline_predicate = getPredicate(
+          out, ir_utils::indices(for_loops), getThreadPredicate(out));
+
+      IfThenElse* inline_ite =
+          new IfThenElse(inline_predicate, {expr}, {}, inner_most_inlined_loop);
+      std::unordered_map<Expr*, Expr*> inline_replacement_map;
+      inline_replacement_map.emplace(std::pair<Expr*, Expr*>(expr, inline_ite));
+      scope_utils::replaceExprsInScope(
+          inner_most_inlined_loop, inline_replacement_map);
+
+    } // for expr
+  } else { //  if(!within_unroll)
+    // modify in place, so grab a copy of exprs first.
+    const std::vector<Expr*> exprs = fl->body().exprs();
+
+    for (auto expr : exprs) {
+      if (!ir_utils::isTVOp(expr))
+        continue;
+
+      TensorView* out = ir_utils::asTV(ir_utils::asExpr(expr)->outputs()[0]);
+
+      Bool* pred = getPredicate(
+          out, ir_utils::indices(for_loops), getThreadPredicate(out));
+
+      // If we need a predicate, put expr inside an if then else
+      if (!(pred->isConst()) || !(pred->isConst() && pred->value().value())) {
+        IfThenElse* inline_ite =
+            new IfThenElse(pred, {expr}, {}, for_loops.back());
+        for_loops.back()->body().insert_before(expr, inline_ite);
+        for_loops.back()->body().erase(expr);
+      }
+    }
+  } // else (if(!within_unroll))
+
+  for_loops.pop_back();
+  within_unroll = prev_unroll;
+}
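The unroll predicate built above checks the worst case of the unrolled region: loops outside the unroll contribute their running index, while loops at or inside the unroll contribute the thread index if they are thread-mapped and extent - 1 otherwise. A standalone sketch of that index selection (toy Loop struct, illustrative only):

    #include <iostream>
    #include <string>
    #include <vector>

    // Stand-in for a ForLoop's iter domain.
    struct Loop {
      std::string index;   // loop index variable name
      std::string extent;  // symbolic extent
      bool is_thread;
      bool is_unrolled;
    };

    // Pick the index used for the unroll ("fast path") predicate of each loop.
    std::vector<std::string> unrollPredIndices(const std::vector<Loop>& loops) {
      std::vector<std::string> inds;
      bool inside_unroll = false;
      for (const Loop& l : loops) {
        inside_unroll = inside_unroll || l.is_unrolled;
        if (!inside_unroll)
          inds.push_back(l.index);          // outside the unroll: running index
        else if (l.is_thread)
          inds.push_back(l.index);          // thread loops: the thread index
        else
          inds.push_back(l.extent + " - 1"); // serial loops: last iteration
      }
      return inds;
    }

    int main() {
      std::vector<Loop> loops = {{"blockIdx.x", "I0o", true, false},
                                 {"i", "4", false, true},
                                 {"threadIdx.x", "128", true, false}};
      for (const auto& s : unrollPredIndices(loops))
        std::cout << s << "\n"; // blockIdx.x, 4 - 1, threadIdx.x
    }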
+
+// Generate the loop nest structure and place it in lowered_exprs
+void UnrollPass::computeMap() {
+  FusionGuard fg(fusion_);
+
+  // Run through loop nests and further lower the expressions
+  for (auto* expr : incoming_exprs_) {
+    OptOutDispatch::handle(expr);
+  }
+}
+
+std::vector<Expr*> UnrollPass::runPass(
+    Fusion* fusion,
+    const std::vector<Expr*>& exprs,
+    std::unordered_map<const TensorView*, Bool*>& thread_predicates) {
+  FusionGuard fg(fusion);
+  UnrollPass up(fusion, exprs, thread_predicates);
+  up.computeMap();
+  std::vector<Expr*> mutated_exprs;
+  for (Expr* expr : exprs) {
+    if (up.loop_replacement_map.find(expr) != up.loop_replacement_map.end()) {
+      mutated_exprs.push_back(up.loop_replacement_map[expr]);
+    } else {
+      if (ir_utils::isScope(expr))
+        scope_utils::replaceExprsInScope(expr, up.loop_replacement_map);
+      mutated_exprs.push_back(expr);
+    }
+  }
+  return mutated_exprs;
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h
new file mode 100644
index 0000000..d421676
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h
@@ -0,0 +1,102 @@
+#pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/dispatch.h>
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
+
+#include <bitset>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+/*
+ * A bit deceptively: UnrollPass adds all predicates, so it needs to be run even
+ * if we don't unroll any loops.
+ *
+ * Unrolling pass will get IR that looks something like:
+ * for( i : I0o{ceil(I0/4)} ) {
+ *   for( j : I1o{ceil(I1/128)} ) {
+ *     for( k : I0i{4} )
+ *       for( l : I1i{128} )
+ *         T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
+ *
+ * And it will return the following:
+ * for( i : I0o{ceil(I0/4)} ) {
+ *   for( j : I1o{ceil(I1/128)} ) {
+ *
+ *     if( i * 4 + 3 < I && j * 128 + 127 < J ){
+ *       for( k : I0i{4} )
+ *         for( l : I1i{128} )
+ *           T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
+ *     } else {
+ *       for( k : I0i{4} )
+ *         for( l : I1i{128} )
+ *           if( i * 4 + k < I && j * 128 + l < J)
+ *              T0[ ( i * 4 + k ) * J + j * 128 + l ] = …
+ *     }
+ *
+ *   }
+ * }
+ *
+ * As can be seen it generates two sets of loops for I0i{4} and I1i{128}. The
+ * first set is protected by a predicate that makes sure there's a full internal
+ * tile we can iterate over. This way we remove the predicate nested in the
+ * inner most loop. There's of course a second set of loops, which has a
+ * predicate still in the inner most loop, making sure that we cover edges and
+ * corners.
+ */
+
+class TORCH_CUDA_API UnrollPass : public OptOutDispatch {
+ private:
+  // Wrapper to access thread_predicates_
+  Bool* getThreadPredicate(const TensorView*);
+
+  // We will track which loops in the incoming IR will be replaced and by what
+  std::unordered_map<Expr*, Expr*> loop_replacement_map;
+  // Hold on to a reference to the fusion for convenience
+  Fusion* fusion_;
+  // Hold on to the incoming exprs, but don't modify them. We don't set the
+  // Expr* to be const as Exprs are const by virtue of their interface design
+  const std::vector<Expr*>& incoming_exprs_;
+
+  // Keep all for loops conveniently to make unrolling easier
+  std::vector<ForLoop*> for_loops;
+
+  // Map from TensorView to its thread predicate
+  std::unordered_map<const TensorView*, Bool*>& thread_predicates_;
+
+  // keep track if we're within an unrolled loop
+  bool within_unroll = false;
+
+  // Custom dispatch for Expr, want to find out if it's a TV op
+  void handle(Expr*) final;
+
+  // Open the for loop.
+  void handle(ForLoop*) final;
+
+  // Constructor
+  UnrollPass(
+      Fusion* _fusion,
+      const std::vector<Expr*>& _incoming_exprs,
+      std::unordered_map<const TensorView*, Bool*>& _thread_predicates)
+      : fusion_(_fusion),
+        incoming_exprs_(_incoming_exprs),
+        thread_predicates_(_thread_predicates) {}
+
+  // Generate the for Expr replacement map
+  void computeMap();
+
+ public:
+  // Take the incoming fusion and exprs and run loop unrolling, returning the
+  // new IR.
+  static std::vector<Expr*> runPass(
+      Fusion* fusion,
+      const std::vector<Expr*>& exprs,
+      std::unordered_map<const TensorView*, Bool*>& thread_predicates);
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
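For orientation, the three passes touched in this file set are meant to be chained by the lowering driver: thread predicates are computed once, handed to the loop nest generator, and reused by the unroll pass. The driver itself is not in this hunk, so the wrapper below is a hypothetical sketch built only from entry points whose signatures appear above (includes and error handling omitted):

    // Hypothetical driver sketch; it only chains the entry points declared above.
    std::vector<Expr*> lowerSketch(Fusion* fusion) {
      FusionGuard fg(fusion);

      // 1) Per-TensorView thread predicates, shared by the later passes.
      auto thread_preds = ThreadPredicates::compute(fusion);

      // 2) Loop nests, allocations, and reduction-buffer initialization.
      auto loop_nests =
          LoopNestGenerator::getLoopNest(fusion, fusion->exprs(true), thread_preds);

      // 3) Unrolling plus predicate insertion (runs even without unrolled loops).
      return UnrollPass::runPass(fusion, loop_nests, thread_preds);
    }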
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
index 208916e..75dcf0e 100644
--- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -9,7 +9,7 @@
 // START SCOPE HELPER SYSTEMS
 namespace {
 
-struct Loops : private OptInDispatch {
+class Loops : private OptInDispatch {
  private:
   std::deque<ForLoop*> loops;
   void handle(ForLoop* fl) final {
@@ -34,7 +34,7 @@
   }
 };
 
-struct forLoopCount : private OptInDispatch {
+class forLoopCount : private OptInDispatch {
  private:
   unsigned int count_ = 0;
 
@@ -60,7 +60,7 @@
   }
 };
 
-struct scopePushBack : private OptInDispatch {
+class scopePushBack : private OptInDispatch {
  private:
   Expr* expr_;
   void handle(ForLoop* fl) final {
@@ -87,7 +87,7 @@
   }
 };
 
-struct scopeInsertBefore : private OptInDispatch {
+class scopeInsertBefore : private OptInDispatch {
  private:
   Expr* ref_;
   Expr* expr_;
@@ -115,7 +115,7 @@
   }
 };
 
-struct parentScope : private OptInDispatch {
+class parentScope : private OptInDispatch {
  private:
   Expr* parent_ = nullptr;
 
@@ -139,7 +139,7 @@
   }
 };
 
-struct scopeClearExprs : private OptInDispatch {
+class scopeClearExprs : private OptInDispatch {
  private:
   void handle(ForLoop* fl) final {
     fl->body().clear();
@@ -169,7 +169,7 @@
       "Assert Scope failed when calling a scope_util function.");
 }
 
-struct CloneLoopNest : public OptOutMutator {
+class CloneLoopNest : public OptOutMutator {
  private:
   Expr* parent_scope_ = nullptr;
   Expr* to_clone_ = nullptr;
@@ -199,50 +199,7 @@
   }
 };
 
-struct ReplaceExprsInScope : public OptOutDispatch {
- private:
-  std::unordered_map<Expr*, Expr*> replacement_map_;
-
-  void handle(Expr* expr) final {
-    OptOutDispatch::handle(expr);
-  }
-
-  void handle(ForLoop* fl) final {
-    for (Expr* expr : fl->body().exprs()) {
-      auto it = replacement_map_.find(expr);
-      if (it == replacement_map_.end()) {
-        handle(expr);
-        continue;
-      }
-      fl->body().insert_before(expr, replacement_map_[expr]);
-      fl->body().erase(expr);
-    }
-  }
-
-  void handle(IfThenElse* ite) final {
-    for (Expr* expr : ite->body().exprs()) {
-      auto it = replacement_map_.find(expr);
-      if (it == replacement_map_.end()) {
-        handle(expr);
-        continue;
-      }
-      ite->body().insert_before(expr, replacement_map_[expr]);
-      ite->body().erase(expr);
-    }
-    for (Expr* expr : ite->elseBody().exprs()) {
-      auto it = replacement_map_.find(expr);
-      if (it == replacement_map_.end()) {
-        handle(expr);
-        continue;
-      }
-      ite->elseBody().insert_before(expr, replacement_map_[expr]);
-      ite->elseBody().erase(expr);
-    }
-  }
-
-  ReplaceExprsInScope(std::unordered_map<Expr*, Expr*> _replacement_map)
-      : replacement_map_(std::move(_replacement_map)) {}
-
+class ReplaceExprsInScope : public OptOutDispatch {
  public:
   static void replace(
       Expr* scope,
@@ -250,9 +207,40 @@
     ReplaceExprsInScope reis(std::move(replacement_map));
     reis.handle(scope);
   }
+
+ private:
+  explicit ReplaceExprsInScope(std::unordered_map<Expr*, Expr*> replacement_map)
+      : replacement_map_(std::move(replacement_map)) {}
+
+  void handleScope(Scope& scope) {
+    for (size_t i = 0; i < scope.size(); ++i) {
+      const auto it = replacement_map_.find(scope[i]);
+      if (it == replacement_map_.end()) {
+        handle(scope[i]);
+        continue;
+      }
+      scope[i] = it->second;
+    }
+  }
+
+  void handle(Expr* expr) final {
+    OptOutDispatch::handle(expr);
+  }
+
+  void handle(ForLoop* fl) final {
+    handleScope(fl->body());
+  }
+
+  void handle(IfThenElse* ite) final {
+    handleScope(ite->body());
+    handleScope(ite->elseBody());
+  }
+
+ private:
+  std::unordered_map<Expr*, Expr*> replacement_map_;
 };
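The rewritten handleScope replaces expressions in place by index instead of the old insert-before/erase dance. The same pattern in miniature, with strings standing in for Expr* (illustrative only):

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Walk a scope's expression list and swap each entry for its replacement in
    // place, if one is registered.
    void replaceInScope(std::vector<std::string>& scope,
                        const std::unordered_map<std::string, std::string>& repl) {
      for (size_t i = 0; i < scope.size(); ++i) {
        auto it = repl.find(scope[i]);
        if (it != repl.end())
          scope[i] = it->second;
        // (the real pass recurses into nested scopes here instead)
      }
    }

    int main() {
      std::vector<std::string> scope = {"for_i", "T1 = T0 + 1", "for_j"};
      replaceInScope(scope, {{"for_i", "if_unroll { for_i }"}});
      for (const auto& e : scope)
        std::cout << e << "\n";
    }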
 
-struct FirstInnerMostScope : private OptInDispatch {
+class FirstInnerMostScope : private OptInDispatch {
  private:
   Expr* active_scope = nullptr;
 
@@ -295,6 +283,10 @@
 
     FirstInnerMostScope fims;
     Expr* inner = fims.getInner(scope);
+
+    if (inner == nullptr)
+      return scope;
+
     while (fims.getInner(inner) != nullptr)
       inner = fims.getInner(inner);
     return inner;
@@ -411,13 +403,13 @@
   return ids;
 }
 
-bool isTV(const Val* const val) {
+bool isTV(const Val* val) {
   return val->getValType().value() == ValType::TensorView;
 }
 
 // Check if we're a TensorView op that we can generate code for.
 bool isTVOp(const Expr* expr) {
-  if (expr->nOutputs() == 1 && isTV(expr->output(0)) &&
+  if (expr->outputs().size() == 1 && isTV(expr->output(0)) &&
       (expr->getExprType().value() == ExprType::BinaryOp ||
        expr->getExprType().value() == ExprType::UnaryOp ||
        expr->getExprType().value() == ExprType::TernaryOp ||
@@ -462,7 +454,7 @@
   return static_cast<ForLoop*>(expr);
 }
 
-const TensorView* asConstTV(const Val* const val) {
+const TensorView* asConstTV(const Val* val) {
   TORCH_INTERNAL_ASSERT(isTV(val));
   return static_cast<const TensorView*>(val);
 }
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
new file mode 100644
index 0000000..31a2fc2
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -0,0 +1,130 @@
+#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
+#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
+#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
+#include <torch/csrc/jit/codegen/cuda/type.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+// Some pre-compilation checks
+static void IrValidate(Fusion* fusion) {
+  fusion->validateInputs();
+  for (Val* val : fusion->vals()) {
+    if (ir_utils::isTV(val)) {
+      TensorView* tv = ir_utils::asTV(val);
+      for (decltype(tv->nDims()) i{0}; i < tv->nDims(); i++) {
+        IterDomain* id = tv->getComputeAtAxis(i).first;
+
+        if (id->isBlockDim()) {
+          TORCH_CHECK(
+              !id->isBroadcast(),
+              "Parallelization across blocks on broadcast axes is not supported, but found on, ",
+              tv,
+              ".");
+        }
+      }
+    }
+  }
+}
+
+// Remove circular computeAt references
+void IrFixComputeAt(Fusion* fusion) {
+  std::vector<Expr*> exprs = fusion->exprs(true);
+  std::set<TensorView*> visited;
+  for (auto it = exprs.rbegin(); it != exprs.rend(); it++) {
+    Expr* expr = *it;
+    if (!ir_utils::isTVOp(expr))
+      continue;
+
+    TensorView* tv = ir_utils::asTV(expr->output(0));
+    TensorView* ctv = tv->getComputeAtView();
+
+    if (ctv != nullptr && visited.find(ctv) == visited.end()) {
+      ctv->setComputeAt(tv, (int)tv->getThisComputeAtAxis());
+      tv->clearComputeAt();
+    }
+    visited.emplace(tv);
+  }
+}
+
+void IrBuildSizesMap(Fusion* fusion) {
+  // Sizes of inputs/outputs -> T.size[...]
+  std::unordered_map<Val*, Val*> size_map;
+
+  // Grab inputs and outputs
+  std::vector<TensorView*> inputs_and_outputs;
+  for (auto val : fusion->inputs()) {
+    if (ir_utils::isTV(val)) {
+      inputs_and_outputs.push_back(val->as<TensorView>());
+    }
+  }
+  for (auto val : fusion->outputs()) {
+    if (ir_utils::isTV(val)) {
+      inputs_and_outputs.push_back(val->as<TensorView>());
+    }
+  }
+
+  // Run through inputs and outputs first. Since we're replacing full
+  // tensorviews, their names are going to change. We need the new reference
+  // name for the inputs/outputs. This way we won't reference the wrong tensor
+  // view. For example T0 may be translated to T9. We don't want our new
+  // variable to be T0->size[...] we need it to be T9->size[...]
+  //
+  // This could be done in a better way by changing split/merge to be a
+  // TensorDomain focused operation; then we could simply do this process on
+  // domains, instead of tensorviews. This would have the benefit that the
+  // TensorView wouldn't change, so users pointers will remain valid. The other
+  // option, which seems less elegant but would also work, is to build up the domain
+  // on the new tensor, and then simply replace it into the original one.
+  for (TensorView* tv : inputs_and_outputs) {
+    // Replace the domain with one based on Ti.size[j]
+    std::vector<IterDomain*> new_domain_iters;
+    const std::vector<IterDomain*>& root_td = tv->getRootDomain();
+
+    size_t dim = 0;
+    for (auto id : root_td) {
+      // Output sizes could have reduction axes, which isn't what gets output.
+      if (id->isReduction())
+        continue;
+
+      Val* orig_size = id->extent();
+
+      std::stringstream ss;
+      ss << "T" << tv->name() << ".size[" << dim++ << "]";
+      Val* new_size =
+          new NamedScalar(ss.str(), orig_size->getDataType().value());
+      if (!orig_size->sameAs(new_size) ||
+          size_map.find(orig_size) == size_map.end())
+        size_map[orig_size] = new_size;
+    }
+  }
+
+  fusion->setValuesMap(size_map);
+}
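IrBuildSizesMap binds each symbolic input/output extent to a named scalar of the form T&lt;name&gt;.size[&lt;dim&gt;], skipping reduction axes since they are not part of the runtime output shape. A standalone sketch of just the name construction (illustrative only):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Build the "T<name>.size[<dim>]" labels that the size map binds symbolic
    // extents to, skipping reduction axes.
    std::vector<std::string> sizeNames(int tv_name,
                                       const std::vector<bool>& is_reduction) {
      std::vector<std::string> names;
      size_t dim = 0;
      for (bool red : is_reduction) {
        if (red)
          continue;
        std::stringstream ss;
        ss << "T" << tv_name << ".size[" << dim++ << "]";
        names.push_back(ss.str());
      }
      return names;
    }

    int main() {
      // A 3D tensor whose middle axis is a reduction -> two runtime sizes.
      for (const auto& n : sizeNames(9, {false, true, false}))
        std::cout << n << "\n"; // T9.size[0], T9.size[1]
    }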
+
+void IrAdjustMemoryTypes(Fusion* fusion) {
+  for (auto val : fusion->deterministic_vals()) {
+    if (ir_utils::isTV(val)) {
+      auto tv = val->as<TensorView>();
+      if (fusion->hasInput(tv) || fusion->hasOutput(tv)) {
+        tv->setMemoryType(MemoryType::Global);
+      } else if (tv->getMemoryType() == MemoryType::Global) {
+        tv->setMemoryType(MemoryType::Local);
+      }
+    }
+  }
+}
+
+void PrepareForLowering(Fusion* fusion) {
+  FusionGuard fg(fusion);
+
+  IrFixComputeAt(fusion);
+  IrValidate(fusion);
+  IrBuildSizesMap(fusion);
+  IrAdjustMemoryTypes(fusion);
+}
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h
new file mode 100644
index 0000000..6990012
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+/*
+ * Currently this does the following:
+ *
+ * (1) Run a validation pass on the IR making sure there are no mistakes or
+ * unsupported scheduling.
+ *
+ * (2) Creates a mapping for symbolic sizes to named scalars
+ *     i.e. T0[i0] -> T0[T0.size[0]]
+ *
+ * (3) Change computeAt structure to make sure computeAt structure follows the
+ * expression structure.
+ *
+ * (4) Adjust TensorView memory types to make sure they are valid
+ */
+
+void TORCH_CUDA_API PrepareForLowering(Fusion* fusion);
+
+// Compute at can have some circular references. Before we can call any tv
+// with tv->getComputeAtAxis(i) we need to break those circular dependencies.
+void IrFixComputeAt(Fusion* fusion);
+
+// TensorViews are all based on symbolic sizes. When we first initialize them we
+// don't know if they're inputs or outputs which would mean that they have
+// runtime shapes. Intermediate tensors (those not going to global memory) do
+// not have this information. Since the generated kernel needs to fetch the
+// correct shape information at runtime, we want to replace input and output
+// tensor sizes with references to the runtime structure containing sizes.
+void IrBuildSizesMap(Fusion* fusion);
+
+// Adjust memory types to make sure they are valid
+void IrAdjustMemoryTypes(Fusion* fusion);
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp
index 486b0b4..e8b029c 100644
--- a/torch/csrc/jit/codegen/cuda/manager.cpp
+++ b/torch/csrc/jit/codegen/cuda/manager.cpp
@@ -3,7 +3,6 @@
 #include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
 #include <torch/csrc/jit/codegen/cuda/parser.h>
 #include <torch/csrc/jit/codegen/cuda/shape_inference.h>
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
 #include <torch/csrc/jit/codegen/cuda/utils.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
@@ -26,15 +25,6 @@
   return req_ptr;
 }
 
-// TODO: contiguity could be used for better kernel launch config.
-TensorContiguity infer_contiguity_from_tensor_type(
-    const std::shared_ptr<c10::TensorType>& tensor_type) {
-  TORCH_INTERNAL_ASSERT(tensor_type->isComplete());
-  return TensorContiguity(
-      *(tensor_type->sizes().concrete_sizes()),
-      *(tensor_type->strides().concrete_sizes()));
-}
-
 // CudaFusionManager holds compiled `CudaKernel` and handles all interfacing
 // including compilation and execution.
 //
@@ -88,7 +78,8 @@
       int32_t kernel_id,
       std::shared_ptr<Graph>& graph,
       const at::ArrayRef<IValue> inputs,
-      std::vector<at::Tensor> outputs) {
+      const std::vector<at::Tensor>& outputs,
+      const std::vector<int64_t>& broadcasted_shape) {
     std::lock_guard<std::mutex> guard(mutex_);
     TORCH_CHECK(
         kernel_cache_.count(kernel_id) != 0, "kernel id not recognized");
@@ -98,7 +89,7 @@
     if (cuda_kernel) {
       // TODO: update launch config for specific sizes;
       //       maybe we should store it in CudaKernel and compute it later
-      runKernel(*cuda_kernel, inputs, outputs);
+      runKernel(*cuda_kernel, inputs, outputs, broadcasted_shape);
     } else {
       // TODO: this should somehow be done after kernel compilation.
       //       we will want compileKernel to return a heuristic
@@ -127,7 +118,7 @@
       // NVRTC compile kernel
       compileKernel(cuda_kernel.value());
 
-      runKernel(*cuda_kernel, inputs, outputs);
+      runKernel(*cuda_kernel, inputs, outputs, broadcasted_shape);
     }
   }
 
@@ -164,7 +155,7 @@
   fusion_node->i_(attr::cache_id, fusion_cache_id);
 }
 
-void runCudaFusionGroup(const Node* const fusion_node, Stack& stack) {
+void runCudaFusionGroup(const Node* fusion_node, Stack& stack) {
   TORCH_CHECK(
       fusion_node->kind() == prim::CudaFusionGroup,
       "prim::CudaFusionGroup expected");
@@ -179,22 +170,28 @@
   std::shared_ptr<Graph> graph = fusion_node->g(attr::Subgraph)->copy();
 
   auto execute_lambda = [&]() {
-    auto nInputs = graph->inputs().size();
+    const auto nInputs = graph->inputs().size();
     at::ArrayRef<IValue> inputs = last(stack, nInputs);
 
     // shape inference in graph
     // update shape information per the new inputs;
     EraseShapeInformation(graph);
-    for (decltype(nInputs) i = 0; i < nInputs; i++) {
+    for (size_t i = 0; i < nInputs; i++) {
       graph->inputs()[i]->setType(inputs[i].type());
     }
     // shape inference
     ShapeTypePropagate(graph);
 
+    // TODO: temporary WAR that allows us to handle fusion with uniform output
+    // shape and consistent broadcast scheme. The definition is loose and the
+    // implementation is risky. We'll do this properly when we integrate proper
+    // broadcast support.
+    std::vector<int64_t> broadcasted_shape;
+
     // we need to construct outputs;
     std::vector<at::Tensor> outputs;
-    for (const auto* const output : graph->outputs()) {
-      auto type = output->type()->expect<TensorType>();
+    for (const auto* output : graph->outputs()) {
+      const auto type = output->type()->expect<TensorType>();
       // Expect output to be tensor;
       TORCH_CHECK(
           type && type->isComplete(),
@@ -213,11 +210,34 @@
       const auto sizes = extractSizes(type);
       const auto strides = extractStrides(type);
 
-      auto tensor = at::empty_strided(sizes, strides, options);
+      const auto tensor = at::empty_strided(sizes, strides, options);
       outputs.push_back(tensor);
+
+      // TODO: unsafe broadcast assumption. We assume all output from fusion has
+      //       identical size when broadcasting.
+      if (broadcasted_shape.empty()) {
+        if (!hasReductionNode(graph->block())) {
+          broadcasted_shape = sizes;
+        } else if (isReductionNode(output->node())) {
+          auto i_type =
+              output->node()->inputs()[0]->type()->expect<TensorType>();
+          TORCH_CHECK(
+              i_type && i_type->sizes().isComplete(),
+              "Complete TensorType for output is expected.");
+          broadcasted_shape = extractSizes(i_type);
+        } else {
+          // TODO: this assert is not fool proof. We could have ignored
+          // pre-reduction tensor marked as output after we first encountered
+          // reduction output tensor.
+          TORCH_INTERNAL_ASSERT(
+              false,
+              "pre-reduction tensor output for reduction fusion is not properly supported yet.");
+        }
+      }
     }
+
     CudaFusionManager::getManager().runFusionNode(
-        kernel_id, graph, inputs, outputs);
+        kernel_id, graph, inputs, outputs, broadcasted_shape);
     drop(stack, inputs.size());
     stack.insert(
         stack.end(),
diff --git a/torch/csrc/jit/codegen/cuda/manager.h b/torch/csrc/jit/codegen/cuda/manager.h
index f5ce2d5..9fb1bd1 100644
--- a/torch/csrc/jit/codegen/cuda/manager.h
+++ b/torch/csrc/jit/codegen/cuda/manager.h
@@ -28,9 +28,7 @@
 // Current protocol is that the function allocates output tensors and appends
 // them to `stack` after execution.
 // TODO: support shape inferencing. Right now we only handle static shapes
-TORCH_CUDA_API void runCudaFusionGroup(
-    const Node* const fusion_node,
-    Stack& stack);
+TORCH_CUDA_API void runCudaFusionGroup(const Node* fusion_node, Stack& stack);
 
 TORCH_CUDA_API void CudaFuseGraph(std::shared_ptr<Graph>& graph);
 
diff --git a/torch/csrc/jit/codegen/cuda/mutator.h b/torch/csrc/jit/codegen/cuda/mutator.h
index c8b4c6f..822fba5 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.h
+++ b/torch/csrc/jit/codegen/cuda/mutator.h
@@ -11,7 +11,7 @@
 namespace jit {
 namespace fuser {
 
-struct Fusion;
+class Fusion;
 
 /*
  * Mutators are the mechanism used to modify IR nodes. Since most nodes are
@@ -25,7 +25,7 @@
 
 // Search through "within" and replace all instances of "instance" with the
 // value "with".
-struct TORCH_CUDA_API ReplaceAll : public OptOutMutator {
+class TORCH_CUDA_API ReplaceAll : public OptOutMutator {
  private:
   // Will look in fusion and if we're replacing an input or output we register
   // those changes
diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp
index a70640b..e266f0a 100644
--- a/torch/csrc/jit/codegen/cuda/parser.cpp
+++ b/torch/csrc/jit/codegen/cuda/parser.cpp
@@ -19,23 +19,74 @@
 namespace fuser {
 namespace cuda {
 
-constexpr auto NUM_UNARY_OPS = 31;
-constexpr auto NUM_BINARY_OPS = 24;
-constexpr auto NUM_BINARY_OPS_WITH_ALPHA = 4;
-constexpr auto NUM_LERP_OPS = 2;
+constexpr auto kNumUnaryOps = 31;
+constexpr auto kNumBinaryOps = 24;
+constexpr auto kNumBinaryOpsWithAlpha = 4;
+constexpr auto kNumLerpOps = 2;
 
 namespace {
 
 typedef Val* CgValue;
 typedef Expr* CgOp;
 
-typedef void (
-    *ParseFuncPtr)(const Node* const, std::unordered_map<size_t, CgValue>&);
+typedef void (*ParseFuncPtr)(const Node*, std::unordered_map<size_t, CgValue>&);
+typedef bool (*MergeQueryFuncPtr)(const Node*);
+
+std::vector<int> reductionAxes(TensorView* tv) {
+  size_t n_dims = tv->nDims();
+  std::vector<int> reduction_axes;
+  for (size_t i = 0; i < n_dims; i++) {
+    if (tv->axis(i)->isReduction()) {
+      reduction_axes.emplace_back(i);
+    }
+  }
+  return reduction_axes;
+}
+
+// coalesces all reduction axes to the right side and returns the total
+// number of reduction axes
+size_t coalescReduction(TensorView* tv) {
+  auto reduction_axes = reductionAxes(tv);
+  size_t n_dims = tv->nDims();
+  std::unordered_map<int, int> coalesc_permute;
+  for (size_t i = 0; i < reduction_axes.size(); i++) {
+    size_t new_pos = i + n_dims - reduction_axes.size();
+    if (new_pos == size_t(reduction_axes[i])) {
+      break;
+    } else {
+      coalesc_permute[reduction_axes[i]] = new_pos;
+    }
+  }
+  if (!coalesc_permute.empty()) {
+    tv->reorder(coalesc_permute);
+  }
+  return reduction_axes.size();
+}
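coalescReduction builds a reorder map that shifts every reduction axis to the tail, stopping early once the remaining reduction axes are already in place. The same computation in a standalone form (illustrative only):

    #include <iostream>
    #include <unordered_map>
    #include <vector>

    // Compute the old->new axis map that moves all reduction axes to the right;
    // axes already in their target slot (and everything after them) are skipped.
    std::unordered_map<int, int> coalescePermutation(
        const std::vector<bool>& is_reduction) {
      std::vector<int> reduction_axes;
      for (size_t i = 0; i < is_reduction.size(); ++i)
        if (is_reduction[i])
          reduction_axes.push_back((int)i);

      std::unordered_map<int, int> permute;
      size_t n_dims = is_reduction.size();
      for (size_t i = 0; i < reduction_axes.size(); ++i) {
        size_t new_pos = i + n_dims - reduction_axes.size();
        if (new_pos == (size_t)reduction_axes[i])
          break; // remaining reductions are already trailing
        permute[reduction_axes[i]] = (int)new_pos;
      }
      return permute;
    }

    int main() {
      // [R, I, I, R] -> reduction axis 0 must move to position 2.
      for (const auto& kv : coalescePermutation({true, false, false, true}))
        std::cout << kv.first << " -> " << kv.second << "\n"; // 0 -> 2
    }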
 
 // TODO: add a mutex to make it thread safe.
 class IrParser {
+  class RegistrationEntry {
+   public:
+    RegistrationEntry(ParseFuncPtr parse_f, MergeQueryFuncPtr merge_f = nullptr)
+        : parse_f_(parse_f), merge_f_(merge_f) {}
+
+    void parse(const Node* node, std::unordered_map<size_t, CgValue>& values) {
+      parse_f_(node, values);
+    }
+
+    bool is_compatible(const Node* node) {
+      if (merge_f_ == nullptr) {
+        return true;
+      }
+      return merge_f_(node);
+    }
+
+   private:
+    ParseFuncPtr parse_f_;
+    MergeQueryFuncPtr merge_f_;
+  };
+
  private:
-  static const int nthreads = 128;
   static const int unroll_factor = 4;
 
  public:
@@ -52,18 +103,45 @@
     FusionGuard fg(cuda_kernel_->fusion_.get());
     auto block = graph_->block();
 
-    // in case of broadcast, we don't support explicit broadcast, so we need to
-    // convert/expand all inputs tensors to comply to the broadcasted size.
-    // This supports very limited case, which we try to accomodate in graph
-    // partition, that we only merge nodes with identical output shapes.
-    int broadcast_dim =
-        block->outputs()[0]->type()->cast<TensorType>()->dim().value();
+    // [ Note - broadcast support in integration ]
+    //
+    // in case of broadcast, we don't support explicit broadcast:
+    // 1. for point-wise fusion, we need to convert/expand all input
+    // tensors to comply with the broadcasted size. This supports a very
+    // limited case, which we try to accommodate in graph partition: we only
+    // merge nodes with identical output shapes.
+    // 2. in case of reduction-at-end fusion, right now we only support a
+    // single reduction operation in the fusion, hence we can use the same
+    // logic as PW fusion and convert/expand all inputs to the input tensor of
+    // the reduction op.
+
+    // TODO: proper broadcast support in integration
+    int broadcast_dim = -1;
+    // broadcast support hack is disabled for reduction.
+    if (hasReductionNode(graph_->block())) {
+      // reduction-at-end fusion, broadcast all inputs to tensor before
+      // reduction
+      // TODO: Not perfectly safe! We could have an intermediate output that is
+      // not part of the outputs of reduction operations. But we have a similar
+      // limitation for broadcast support in PW fusion. We should properly fix
+      // this after
+      // broadcast integration.
+      broadcast_dim = block->outputs()[0]
+                          ->node()
+                          ->inputs()[0]
+                          ->type()
+                          ->cast<TensorType>()
+                          ->dim()
+                          .value();
+    } else {
+      // point-wise fusion, broadcast all inputs to output size.
+      broadcast_dim =
+          block->outputs()[0]->type()->cast<TensorType>()->dim().value();
+    }
 
     // register all inputs;
     // shape propagation during parsing is effectively done in parsing rules, as
     // we only explicitly register inputs in the graph.
     for (auto val : block->inputs()) {
-      TORCH_CHECK(registerValue(val, broadcast_dim));
+      TORCH_INTERNAL_ASSERT(registerValue(val, broadcast_dim));
       cuda_kernel_->fusion_->addInput(value_map_[val->unique()]);
 
       auto opt_dtype = value_map_[val->unique()]->getDataType();
@@ -78,12 +156,17 @@
     // TODO: disable unroll to ensure rand_like generates identical output as
     // with eager mode
     bool disable_unroll = false;
+    bool has_reduction = false;
+    bool fcd_reduction = false;
     // compose nodes in topo order;
     for (const JitOp* node : block->nodes()) {
       processJitNode(node);
       if (node->kind() == aten::rand_like) {
         disable_unroll = true;
       }
+      if (node->kind() == aten::sum) {
+        has_reduction = true;
+      }
     }
 
     // mark output;
@@ -102,51 +185,134 @@
 
       cuda_kernel_->fusion_->addOutput(out);
 
-      // Merge all dimensions because we're only supporting pointwise
-      while (out->nDims() > 1)
-        out->merge(0, 1);
-      // Split into 128 which will be bockDim.x
-      out->split(0, nthreads);
-      // Split by another 4 which will be our unroll factor
-      auto ur_factor = disable_unroll ? 1 : unroll_factor;
-      if (!disable_unroll) {
-        out->split(0, ur_factor);
-        cuda_kernel_->unroll_factor_ = ur_factor;
-      }
-    }
+      // TODO: has_reduction for scheduling should be done on a per output
+      //       tensor basis.
+      if (has_reduction) {
+        // TODO: this scheduling only works for a single reduction operation in
+        //       the fusion, in this case we can coalesce all reduction axes and
+        //       merge them together. (same applies to iteration axes)
+        // TODO: does this work for multiple outputs?
 
-    // Run through outputs, grab all inputs of outputs
-    // squeeze with computeAt to set overall structure.
-    for (auto output : cuda_kernel_->fusion_->outputs()) {
-      if (output->getValType() != ValType::TensorView)
-        continue;
-      TensorView* out_tv = static_cast<TensorView*>(output);
-      for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) {
-        if (inp->getValType().value() == ValType::TensorView)
-          static_cast<TensorView*>(inp)->computeAt(out_tv, 1);
-      }
-      out_tv->axis(0)->parallelize(ParallelType::BIDx);
-    }
+        // query if fastest changing dimension (FCD) is a reduction
+        fcd_reduction = out->axis((int)out->nDims() - 1)->isReduction();
 
-    // Run through intermediates, unroll, and bind their axes
-    for (auto val : cuda_kernel_->fusion_->vals()) {
-      if (val->getValType().value() != ValType::TensorView)
-        continue;
-      TensorView* tv = static_cast<TensorView*>(val);
+        // TODO: could really use evaluation here. Launch configuration is
+        //       imposed by transformation and the information should be
+        //       embedded in codegen IR.
+        cuda_kernel_->reduction_axes_ = reductionAxes(out);
 
-      // Should be true for all intermediates, but if one isn't hooked
-      // up right, skip it and hope for the best for now
-      if (!disable_unroll && tv->nDims() == 3) {
-        tv->axis(-2)->parallelize(ParallelType::Unroll);
-        tv->axis(-1)->parallelize(ParallelType::TIDx);
+        // We coalesce all reduction axes to the right;
+        size_t num_reduction_axes = coalescReduction(out);
+
+        // Merge all iteration dimensions
+        while (out->nDims() > num_reduction_axes + 1) {
+          out->merge(0, 1);
+        }
+        // Merge all reduction dimensions
+        while (out->nDims() > 2) {
+          out->merge(1, 2);
+        }
+
       } else {
-        if (tv->nDims() == 2)
+        // Merge all dimensions because we're only supporting pointwise
+        while (out->nDims() > 1)
+          out->merge(0, 1);
+        // Split into 128 which will be blockDim.x
+        out->split(0, kPwThreadX);
+        // Split by another 4 which will be our unroll factor
+        auto ur_factor = disable_unroll ? 1 : unroll_factor;
+        if (!disable_unroll) {
+          out->split(0, ur_factor);
+          cuda_kernel_->unroll_factor_ = ur_factor;
+        }
+      }
+    }
+
+    if (has_reduction) {
+      // Run through outputs, grab all inputs of outputs
+      // squeeze with computeAt to set overall structure.
+      for (auto output : cuda_kernel_->fusion_->outputs()) {
+        if (output->getValType() != ValType::TensorView)
+          continue;
+        TensorView* out_tv = static_cast<TensorView*>(output);
+
+        // fcd_reduction could be queried later via
+        // cuda_kernel_->reduction_axes_, which would ensure we have proper
+        // launch configuration.
+        TensorView* intermediate;
+        if (fcd_reduction) {
+          out_tv->split(-1, kFcdReductionThreadX);
+          // necessary to avoid dynamic allocation on intermediates;
+          intermediate = out_tv->rFactor({-2});
+        } else {
+          // TODO: we don't need a full warp here, this should be determined by
+          //       element data type
+          out_tv->split(0, kNonFcdReductionThreadX);
+          out_tv->split(
+              -1, kNonFcdReductionThreadY); // necessary to avoid dynamic
+                                            // allocation on intermediates;
+          intermediate = out_tv->rFactor({-2});
+        }
+        for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) {
+          // scheduling of inputs shouldn't change with different fcd_reduction
+          if (inp->getValType().value() == ValType::TensorView) {
+            static_cast<TensorView*>(inp)->computeAt(intermediate, -1);
+          }
+        }
+        // scheduling of inputs shouldn't change with different fcd_reduction
+        intermediate->computeAt(out_tv, -2);
+        if (fcd_reduction) {
+          out_tv->axis(0)->parallelize(ParallelType::BIDx);
+        } else {
+          out_tv->axis(0)->parallelize(ParallelType::BIDx);
+          out_tv->axis(1)->parallelize(ParallelType::TIDx);
+        }
+      }
+      // Run through all values and bind their innermost axis to TIDx/TIDy
+      for (auto val : cuda_kernel_->fusion_->vals()) {
+        if (val->getValType().value() != ValType::TensorView)
+          continue;
+        TensorView* tv = static_cast<TensorView*>(val);
+        if (fcd_reduction) {
           tv->axis(-1)->parallelize(ParallelType::TIDx);
+        } else {
+          tv->axis(-1)->parallelize(ParallelType::TIDy);
+        }
+      }
+    } else {
+      // Run through outputs, grab all inputs of outputs
+      // squeeze with computeAt to set overall structure.
+      for (auto output : cuda_kernel_->fusion_->outputs()) {
+        if (output->getValType() != ValType::TensorView)
+          continue;
+        TensorView* out_tv = static_cast<TensorView*>(output);
+        for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) {
+          if (inp->getValType().value() == ValType::TensorView)
+            static_cast<TensorView*>(inp)->computeAt(out_tv, 1);
+        }
+        out_tv->axis(0)->parallelize(ParallelType::BIDx);
+      }
+
+      // Run through all values, unroll, and bind their axes
+      for (auto val : cuda_kernel_->fusion_->vals()) {
+        if (val->getValType().value() != ValType::TensorView)
+          continue;
+        TensorView* tv = static_cast<TensorView*>(val);
+
+        // Should be true for all intermediates, but if one isn't hooked
+        // up right, skip it and hope for the best for now
+        if (!disable_unroll && tv->nDims() == 3) {
+          tv->axis(-2)->parallelize(ParallelType::Unroll);
+          tv->axis(-1)->parallelize(ParallelType::TIDx);
+        } else {
+          if (tv->nDims() == 2)
+            tv->axis(-1)->parallelize(ParallelType::TIDx);
+        }
       }
     }
   }
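To make the scheduling above easier to follow, here is a minimal standalone sketch (plain C++; the struct and function names are illustrative and not part of nvFuser) of the axis bookkeeping it performs: reduction axes are coalesced to the right, the remaining iteration axes collapse into a single dimension and the reduction axes into another, and the fcd_reduction query on the innermost axis picks between the FCD and non-FCD thread-binding paths.

#include <cstdio>
#include <vector>

struct CollapsedDomain {
  bool fcd_reduction;     // innermost axis was a reduction before coalescing
  int num_iter_axes;      // axes that merge into the single iteration dim
  int num_reduction_axes; // axes that merge into the single reduction dim
};

// Classify a root domain given per-axis reduction flags (outermost first).
CollapsedDomain collapseForReduction(const std::vector<bool>& is_reduction) {
  CollapsedDomain out{false, 0, 0};
  if (!is_reduction.empty())
    out.fcd_reduction = is_reduction.back();
  for (bool r : is_reduction)
    (r ? out.num_reduction_axes : out.num_iter_axes)++;
  return out;
}

int main() {
  // A [I0, R1, I2] domain: reductions coalesce to the right -> [I0*I2, R1];
  // the FCD (I2) is not a reduction, so the non-FCD reduction path applies.
  CollapsedDomain d = collapseForReduction({false, true, false});
  std::printf(
      "fcd_reduction=%d iter_axes=%d reduction_axes=%d\n",
      (int)d.fcd_reduction,
      d.num_iter_axes,
      d.num_reduction_axes);
  return 0;
}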
 
-  static bool canParseNode(const Node* const node) {
+  static bool canParseNode(const Node* node) {
     if (init_registry_) {
       // TODO: mutex this guy;
       registerJitOperator();
@@ -160,17 +326,39 @@
     }
     for (auto& pair_op_func : iter->second) {
       if (node->matches(pair_op_func.first->schema())) {
-        return true;
+        return pair_op_func.second.is_compatible(node);
       }
     }
     return false;
   }
 
+  static bool isReductionNode(const Node* node) {
+    if (init_registry_) {
+      // TODO: mutex this guy;
+      registerJitOperator();
+      init_registry_ = false;
+    }
+
+    return jit_reduction_op_registry_.count(node->kind());
+  }
+
+  // TODO: is_reduction is too hacky here. We should categorize operation types
+  //       based on their memory access pattern, which would affect fusion
+  //       strategy and partition logic.
   static void registerParseRule(
       std::shared_ptr<Operator>& op,
-      ParseFuncPtr fn) {
+      ParseFuncPtr parse_fn,
+      MergeQueryFuncPtr merge_query_fn = nullptr,
+      bool is_reduction = false) {
     jit_operator_registry_[Symbol::fromQualString(op->schema().name())]
-        .emplace_back(std::make_pair(op, fn));
+        .emplace_back(
+            std::piecewise_construct,
+            std::forward_as_tuple(op),
+            std::forward_as_tuple(parse_fn, merge_query_fn));
+    if (is_reduction) {
+      jit_reduction_op_registry_.emplace(
+          Symbol::fromQualString(op->schema().name()));
+    }
   }
 
  private:
@@ -179,7 +367,7 @@
     // This is a one-time look up, our hash registry indexes on the pointer in
     // OperatorRegistry.
 
-    std::array<const char*, NUM_BINARY_OPS_WITH_ALPHA> BinaryOpWithAlpha = {
+    std::array<const char*, kNumBinaryOpsWithAlpha> BinaryOpWithAlpha = {
         "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
         "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
         "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
@@ -188,7 +376,7 @@
       auto ptr_op = getOperatorForLiteral(signature);
       registerParseRule(
           ptr_op,
-          [](const Node* const node,
+          [](const Node* node,
              std::unordered_map<size_t, CgValue>& value_map) -> void {
             using BinaryOpWithAlphaType = Val* (*)(Val*, Val*, Val*);
             static std::unordered_map<
@@ -218,7 +406,7 @@
           });
     }
 
-    std::array<const char*, NUM_BINARY_OPS> BinaryOp = {
+    std::array<const char*, kNumBinaryOps> BinaryOp = {
         "aten::div(Tensor self, Tensor other) -> Tensor",
         "aten::div(Tensor self, Scalar other) -> Tensor",
         "aten::mul(Tensor self, Tensor other) -> Tensor",
@@ -247,7 +435,7 @@
       auto ptr_op = getOperatorForLiteral(signature);
       registerParseRule(
           ptr_op,
-          [](const Node* const node,
+          [](const Node* node,
              std::unordered_map<size_t, CgValue>& value_map) -> void {
             static std::unordered_map<Symbol, BinaryOpType> op_mapping(
                 {{aten::div, BinaryOpType::Div},
@@ -275,7 +463,7 @@
     }
 
     // TODO: cast operations should be merged in.
-    std::array<const char*, NUM_UNARY_OPS> UnaryOp = {
+    std::array<const char*, kNumUnaryOps> UnaryOp = {
         "aten::neg(Tensor self) -> Tensor",
         "aten::abs(Tensor self) -> Tensor",
         "aten::log(Tensor self) -> Tensor",
@@ -312,7 +500,7 @@
       auto ptr_op = getOperatorForLiteral(signature);
       registerParseRule(
           ptr_op,
-          [](const Node* const node,
+          [](const Node* node,
              std::unordered_map<size_t, CgValue>& value_map) -> void {
             static std::unordered_map<Symbol, UnaryOpType> op_mapping({
                 {aten::neg, UnaryOpType::Neg},
@@ -359,7 +547,7 @@
           "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor");
       registerParseRule(
           ptr_op,
-          [](const Node* const node,
+          [](const Node* node,
              std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto operand = value_map[node->inputs()[0]->unique()];
 
@@ -373,7 +561,7 @@
           "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor");
       registerParseRule(
           ptr_op,
-          [](const Node* const node,
+          [](const Node* node,
              std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto operand = value_map[node->inputs()[0]->unique()];
             auto th = value_map[node->inputs()[1]->unique()];
@@ -389,7 +577,7 @@
           "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor");
       registerParseRule(
           ptr_op,
-          [](const Node* const node,
+          [](const Node* node,
              std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto operand = value_map[node->inputs()[0]->unique()];
             // TODO: we need to get a proper lower bound per dtype in operand.
@@ -410,7 +598,7 @@
           "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor");
       registerParseRule(
           ptr_op,
-          [](const Node* const node,
+          [](const Node* node,
              std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto condition = value_map[node->inputs()[0]->unique()];
             auto x = value_map[node->inputs()[1]->unique()];
@@ -422,14 +610,14 @@
     }
 
     {
-      std::array<const char*, NUM_LERP_OPS> LerpOp = {
+      std::array<const char*, kNumLerpOps> LerpOp = {
           "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
           "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor"};
       for (auto signature : LerpOp) {
         auto ptr_op = getOperatorForLiteral(signature);
         registerParseRule(
             ptr_op,
-            [](const Node* const node,
+            [](const Node* node,
                std::unordered_map<size_t, CgValue>& value_map) -> void {
               auto self = value_map[node->inputs()[0]->unique()];
               auto end = value_map[node->inputs()[1]->unique()];
@@ -446,7 +634,7 @@
           "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor");
       registerParseRule(
           ptr_op,
-          [](const Node* const node,
+          [](const Node* node,
              std::unordered_map<size_t, CgValue>& value_map) -> void {
             auto self = value_map[node->inputs()[0]->unique()];
             auto tensor1 = value_map[node->inputs()[1]->unique()];
@@ -457,6 +645,48 @@
             value_map.emplace(node->output()->unique(), out);
           });
     }
+
+    {
+      auto ptr_op = getOperatorForLiteral(
+          "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
+      registerParseRule(
+          ptr_op,
+          [](const Node* node,
+             std::unordered_map<size_t, CgValue>& value_map) -> void {
+            auto self = value_map[node->input(0)->unique()];
+            auto dims_list = constant_as<c10::List<int64_t>>(node->input(1));
+            TORCH_INTERNAL_ASSERT(
+                dims_list.has_value(), "requires static reduce axes");
+            auto keepdim = constant_as<bool>(node->input(2));
+            std::vector<int> dims;
+            for (const auto dim : dims_list->vec()) {
+              dims.emplace_back(static_cast<int>(dim));
+            }
+            TORCH_INTERNAL_ASSERT(
+                keepdim.has_value() && !keepdim.value(),
+                "Keep dim in reduction is not a const false");
+            auto out = sum(self->as<TensorView>(), dims);
+            value_map.emplace(node->output()->unique(), out);
+          },
+          [](const Node* node) -> bool {
+            // we don't support cast of output types yet;
+            if (!node->inputs()[3]->type()->isSubtypeOf(
+                    static_cast<c10::TypePtr>(NoneType::get()))) {
+              return false;
+            }
+            // we don't support dynamic reduction axes;
+            if (node->inputs()[1]->node()->kind() != prim::Constant) {
+              return false;
+            }
+            // we don't support keepdim yet;
+            if (node->inputs()[2]->node()->kind() != prim::Constant ||
+                *constant_as<bool>(node->input(2))) {
+              return false;
+            }
+            return true;
+          },
+          true);
+    }
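As a companion to the compatibility query registered above, this is a minimal sketch of the same acceptance rules over a hypothetical SumCall struct (not the JIT Node API): reduction axes must be compile-time constants, keepdim must be a constant false, and no output dtype cast may be requested.

#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Hypothetical, simplified view of an aten::sum call site; the real query
// above inspects the JIT Node's inputs directly.
struct SumCall {
  std::optional<std::vector<int64_t>> constant_dims; // nullopt => dynamic dims
  std::optional<bool> constant_keepdim; // nullopt => non-constant keepdim
  bool dtype_is_none; // true => no output dtype cast requested
};

bool isFusibleSum(const SumCall& call) {
  if (!call.dtype_is_none)
    return false; // casting the output type is not supported yet
  if (!call.constant_dims.has_value())
    return false; // dynamic reduction axes are not supported
  if (!call.constant_keepdim.has_value() || *call.constant_keepdim)
    return false; // keepdim must be a constant false
  return true;
}

int main() {
  SumCall ok{std::vector<int64_t>{1}, false, true};
  SumCall keepdim_true{std::vector<int64_t>{1}, true, true};
  SumCall dynamic_dims{std::nullopt, false, true};
  std::printf(
      "%d %d %d\n",
      (int)isFusibleSum(ok),            // 1
      (int)isFusibleSum(keepdim_true),  // 0
      (int)isFusibleSum(dynamic_dims)); // 0
  return 0;
}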
   }
 
   void processJitNode(const JitOp* node) {
@@ -464,22 +694,27 @@
       // partition doesn't take constant node explicitly, but it does and copy
       // constant into subgraph. So we need to register constants in codegen IR;
       for (auto output : node->outputs()) {
-        TORCH_CHECK(registerScalar(output));
+        TORCH_INTERNAL_ASSERT(
+            registerScalar(output),
+            "registration of output failed at index ",
+            output->offset(),
+            " for node ",
+            *node);
       }
     } else {
       auto iter = IrParser::jit_operator_registry_.find(node->kind());
       // make sure we have a parser for the op;
-      TORCH_CHECK(
+      TORCH_INTERNAL_ASSERT(
           iter != IrParser::jit_operator_registry_.end(),
           "CudaFusionGroup Parser doesn't handle operator kind(): ",
           node->kind().toDisplayString());
       for (auto& pair_op_func : iter->second) {
         if (node->matches(pair_op_func.first->schema())) {
-          pair_op_func.second(node, value_map_);
+          pair_op_func.second.parse(node, value_map_);
           return;
         }
       }
-      TORCH_CHECK(
+      TORCH_INTERNAL_ASSERT(
           false,
           "CudaFusionGroup Parser doesn't recognize operator overload:",
           canonicalSchemaString(node->schema()));
@@ -511,9 +746,24 @@
       value_map_.emplace(val->unique(), cg_val);
       return true;
     } else if (val->type()->isSubtypeOf(
+                   static_cast<c10::TypePtr>(BoolType::get()))) {
+      CgValue cg_val;
+      if (auto ival = constant_as<bool>(val)) {
+        cg_val = new Bool(ival.value());
+      } else {
+        cg_val = new Bool();
+      }
+      value_map_.emplace(val->unique(), cg_val);
+      return true;
+    } else if (val->type()->isSubtypeOf(
                    static_cast<c10::TypePtr>(NoneType::get()))) {
       // TODO: should we consider adding support for NoneType;
       return true;
+    } else if (val->type()->cast<ListType>()) {
+      // TODO: we don't support list type in codegen yet;
+      // This is a WAR to allow axes of reduction to be passed as constant list;
+      // We simply ignore conversion if the scalar value is a constant;
+      return toIValue(val).has_value();
     }
     return false;
   }
@@ -524,6 +774,9 @@
       // TODO: make this a static function in Tensor class;
       // create tensor;
       if (broadcast_dim >= 0) {
+        TORCH_INTERNAL_ASSERT(
+            broadcast_dim >= (int)*tensor_type->dim(),
+            "attempt to broadcast a tensor to shrinked dimension is invalid");
         tensor_type = tensor_type->withDim(broadcast_dim);
       }
       // TODO: make this a static function in Tensor class;
@@ -543,20 +796,41 @@
   // parsing rule registry.
   static std::unordered_map<
       Symbol,
-      std::vector<std::pair<std::shared_ptr<Operator>, ParseFuncPtr>>>
+      std::vector<std::pair<std::shared_ptr<Operator>, RegistrationEntry>>>
       jit_operator_registry_;
+  static std::unordered_set<Symbol> jit_reduction_op_registry_;
   static bool init_registry_;
 };
 
 std::unordered_map<
     Symbol,
-    std::vector<std::pair<std::shared_ptr<Operator>, ParseFuncPtr>>>
+    std::vector<
+        std::pair<std::shared_ptr<Operator>, IrParser::RegistrationEntry>>>
     IrParser::jit_operator_registry_;
+std::unordered_set<Symbol> IrParser::jit_reduction_op_registry_;
 bool IrParser::init_registry_ = true;
 
 } // namespace
 
-bool isNodeParsible(const Node* const node) {
+bool hasReductionNode(const Block* block) {
+  for (auto node : block->nodes()) {
+    if (isReductionNode(node)) {
+      return true;
+    }
+    for (auto block : node->blocks()) {
+      if (hasReductionNode(block)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool isReductionNode(const Node* node) {
+  return IrParser::isReductionNode(node);
+}
+
+bool isNodeParsible(const Node* node) {
   return IrParser::canParseNode(node);
 }
 
diff --git a/torch/csrc/jit/codegen/cuda/parser.h b/torch/csrc/jit/codegen/cuda/parser.h
index 53eebef..38c0cb2 100644
--- a/torch/csrc/jit/codegen/cuda/parser.h
+++ b/torch/csrc/jit/codegen/cuda/parser.h
@@ -26,8 +26,17 @@
 namespace fuser {
 namespace cuda {
 
+constexpr int kPwThreadX = 128;
+constexpr int kFcdReductionThreadX = 128;
+constexpr int kNonFcdReductionThreadX = 32;
+constexpr int kNonFcdReductionThreadY = 32;
+
+TORCH_CUDA_API bool hasReductionNode(const Block* block);
+
+TORCH_CUDA_API bool isReductionNode(const Node* node);
+
 // returns whether or not a parsing function exists for the given node type.
-TORCH_CUDA_API bool isNodeParsible(const Node* const node);
+TORCH_CUDA_API bool isNodeParsible(const Node* node);
 
 // lowers PyTorch jit graph to `Fusion`.
 TORCH_CUDA_API void parseJitIR(
diff --git a/torch/csrc/jit/codegen/cuda/partition.cpp b/torch/csrc/jit/codegen/cuda/partition.cpp
index e96d97f..09eb585 100644
--- a/torch/csrc/jit/codegen/cuda/partition.cpp
+++ b/torch/csrc/jit/codegen/cuda/partition.cpp
@@ -13,7 +13,7 @@
 //   1. TensorType
 //   2. on the same device;
 // TODO: update this when codegen can output scalar
-static c10::optional<c10::Device> getDevice(const Value* const value) {
+static c10::optional<c10::Device> getDevice(const Value* value) {
   if (!value->type()->isSubtypeOf(TensorType::get())) {
     // not tensor type, return false as the op is not outputing scalar.
     return c10::nullopt;
@@ -21,7 +21,7 @@
   return value->type()->expect<TensorType>()->device();
 }
 
-static c10::optional<c10::Device> getDevice(const Node* const node) {
+static c10::optional<c10::Device> getDevice(const Node* node) {
   auto outputs = node->outputs();
   for (auto output : outputs) {
     auto device = getDevice(output);
@@ -51,25 +51,39 @@
   return device->is_cuda();
 }
 
-inline bool isFusableNode(const Node* const node) {
+inline bool isFusableNode(const Node* node) {
   // checks if node is compatible with parser:
   // 1. if we have a parsing rule; or 2. if the node is already a fusion group.
   return (isNodeParsible(node) || node->kind() == prim::CudaFusionGroup);
 }
 
+bool hasReductionOperation(const Node* node) {
+  if (isReductionNode(node)) {
+    return true;
+  }
+  if (node->kind() == prim::CudaFusionGroup) {
+    for (auto n : node->g(attr::Subgraph)->nodes()) {
+      if (hasReductionOperation(n)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 } // namespace
 
-bool isFusableCudaFusionGroup(const Node* const node) {
+bool isFusableCudaFusionGroup(const Node* node) {
   if (isFusableNode(node)) {
     return isFusableDevice(node);
   }
   return false;
 }
 
-bool isFusableCudaFusionGroup(
-    const Node* const fusion,
-    const Node* const node) {
-  if (isFusableCudaFusionGroup(node)) {
+bool isFusableCudaFusionGroup(const Node* fusion, const Node* node) {
+  // TODO: lift the restriction on fusing a producer that contains a reduction
+  //       once we have proper scheduling.
+  if (isFusableCudaFusionGroup(node) && !hasReductionOperation(node)) {
     // TODO: ensure legit fusion.
     // issue 0: currently codegen doesn't support broadcasting, except in the
     //          form of stride 0.
diff --git a/torch/csrc/jit/codegen/cuda/partition.h b/torch/csrc/jit/codegen/cuda/partition.h
index e46ad8e..21a44d8 100644
--- a/torch/csrc/jit/codegen/cuda/partition.h
+++ b/torch/csrc/jit/codegen/cuda/partition.h
@@ -19,12 +19,12 @@
 namespace fuser {
 namespace cuda {
 
-TORCH_CUDA_API bool isFusableCudaFusionGroup(const Node* const node);
+TORCH_CUDA_API bool isFusableCudaFusionGroup(const Node* node);
 
 // consider if `node` could be fused into `fusion`
 TORCH_CUDA_API bool isFusableCudaFusionGroup(
-    const Node* const fusion,
-    const Node* const node);
+    const Node* fusion,
+    const Node* node);
 
 } // namespace cuda
 } // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.h b/torch/csrc/jit/codegen/cuda/predicate_compute.h
index 620bdc2..f671d31 100644
--- a/torch/csrc/jit/codegen/cuda/predicate_compute.h
+++ b/torch/csrc/jit/codegen/cuda/predicate_compute.h
@@ -32,7 +32,8 @@
 namespace jit {
 namespace fuser {
 
-struct PredicateCompute {
+class PredicateCompute {
+ public:
   // Return if there are any predicates
   static bool hasPredicates(const TensorIndex*);
 
diff --git a/torch/csrc/jit/codegen/cuda/register_interface.cpp b/torch/csrc/jit/codegen/cuda/register_interface.cpp
index 663518b..cedf44c 100644
--- a/torch/csrc/jit/codegen/cuda/register_interface.cpp
+++ b/torch/csrc/jit/codegen/cuda/register_interface.cpp
@@ -12,7 +12,8 @@
 namespace cuda {
 
 namespace {
-struct RegisterInterface {
+class RegisterInterface {
+ public:
   RegisterInterface() {
     auto ptr = getFuserInterface();
     ptr->fn_compile_n_ = &compileCudaFusionGroup;
diff --git a/torch/csrc/jit/codegen/cuda/shape_inference.cpp b/torch/csrc/jit/codegen/cuda/shape_inference.cpp
index f6f0643..8038b2b 100644
--- a/torch/csrc/jit/codegen/cuda/shape_inference.cpp
+++ b/torch/csrc/jit/codegen/cuda/shape_inference.cpp
@@ -1,5 +1,6 @@
 #include <torch/csrc/jit/codegen/cuda/shape_inference.h>
 #include <c10/core/ScalarType.h>
+#include <torch/csrc/jit/ir/constants.h>
 #include <torch/csrc/jit/runtime/operator.h>
 
 #include <ATen/ExpandUtils.h>
@@ -105,7 +106,7 @@
       // to neither type promoteion nor shape.
       case aten::add:
       case aten::sub: {
-        auto promoted_type = binary_broadcast_type(
+        const auto promoted_type = binary_broadcast_type(
             node->input(0)->type()->cast<TensorType>(),
             node->input(1)->type()->cast<TensorType>());
         node->output()->setType(promoted_type);
@@ -118,7 +119,7 @@
       case aten::ge:
       case aten::ne:
       case aten::eq: {
-        auto promoted_type = binary_broadcast_type(
+        const auto promoted_type = binary_broadcast_type(
             node->input(0)->type()->cast<TensorType>(),
             node->input(1)->type()->cast<TensorType>(),
             at::ScalarType::Bool);
@@ -126,7 +127,7 @@
         break;
       }
       case aten::where: {
-        auto promoted_type = binary_broadcast_type(
+        const auto promoted_type = binary_broadcast_type(
             node->input(1)->type()->cast<TensorType>(),
             node->input(2)->type()->cast<TensorType>());
         node->output()->setType(promoted_type);
@@ -141,8 +142,21 @@
         node->output()->setType(promoted_type);
         break;
       }
+      case aten::sum: {
+        const auto out_type = node->input(0)->type()->cast<TensorType>();
+        const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
+        const auto keepdim = constant_as<bool>(node->input(2));
+        TORCH_CHECK(
+            dims.has_value() && keepdim.has_value(),
+            "Shape inference cannot handle options.");
+        node->output()->setType(
+            unary_reduce_type(out_type, dims->vec(), keepdim.value()));
+        break;
+      }
       default:
-        TORCH_CHECK(false, "shape/type inference failed.");
+        TORCH_CHECK(
+            false,
+            "shape/type inference failed, unrecognized operation encountered.");
         // TODO: generate a proper error log, as this probably means something
         //       went unexpected.
         break;
@@ -154,6 +168,29 @@
   }
 
  protected:
+  TensorTypePtr unary_reduce_type(
+      const TensorTypePtr& op,
+      const std::vector<int64_t>& dims,
+      bool keepdim) {
+    TORCH_CHECK(
+        op->scalarType().has_value() && op->device().has_value() &&
+            op->sizes().isComplete(),
+        "requires complete shape on input");
+    std::vector<int64_t> output_size;
+    std::vector<int64_t> input_size = *op->sizes().concrete_sizes();
+    for (size_t i = 0; i < input_size.size(); i++) {
+      if (std::find(dims.begin(), dims.end(), i) == dims.end()) {
+        output_size.emplace_back(input_size[i]);
+      } else if (keepdim) {
+        // Pushing size 1 here to maintain the reduction dimension because
+        // keepdim is true;
+        output_size.emplace_back(1);
+      }
+    }
+    return TensorType::createContiguous(
+        *op->scalarType(), *op->device(), output_size);
+  }
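For reference, a standalone sketch of the same output-shape rule over plain std::vector sizes (rather than TensorTypePtr), with a worked example: reduced axes are dropped, or kept as size-1 axes when keepdim is true.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Drop reduced axes, or keep them as size-1 axes when keepdim is true.
std::vector<int64_t> reducedSizes(
    const std::vector<int64_t>& input_size,
    const std::vector<int64_t>& dims,
    bool keepdim) {
  std::vector<int64_t> output_size;
  for (int64_t i = 0; i < (int64_t)input_size.size(); i++) {
    if (std::find(dims.begin(), dims.end(), i) == dims.end()) {
      output_size.push_back(input_size[i]);
    } else if (keepdim) {
      output_size.push_back(1);
    }
  }
  return output_size;
}

int main() {
  // sum over dim 1 of a [4, 8, 16] tensor:
  //   keepdim=false -> [4, 16]; keepdim=true -> [4, 1, 16]
  for (bool keepdim : {false, true}) {
    for (int64_t s : reducedSizes({4, 8, 16}, {1}, keepdim))
      std::printf("%lld ", (long long)s);
    std::printf("\n");
  }
  return 0;
}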
+
   // TODO: we should comply to codegen type promotion.
   TensorTypePtr binary_broadcast_type(
       TensorTypePtr const& op0,
diff --git a/torch/csrc/jit/codegen/cuda/tensor_meta.cpp b/torch/csrc/jit/codegen/cuda/tensor_meta.cpp
deleted file mode 100644
index 275afc4..0000000
--- a/torch/csrc/jit/codegen/cuda/tensor_meta.cpp
+++ /dev/null
@@ -1,290 +0,0 @@
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
-#include <algorithm>
-#include <numeric>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-
-//#define TC_DEBUG
-
-/*
- * [Note - TensorContiguity implementation]
- *
- * contiguity_
- *   stores contiguity information for each dimension:
- *   number stored should be between 0 to N+2;
- *     0   - this axis requires broadcasting;
- *     X   - where X belongs [-N, -1] U [1, N] means current axis is immediately
- *           outside of axis `abs(X)-1`. If X < 0, The two axis are contiguous
- *           and can be collapsed;
- *     N+1 - the fastest changing dimension, If -N-1, means its contiguous in
- *           storage (stride == 1);
- *     N+2 - Unknown;
- *
- * TODO sorted_axes_ is something hard to maintain in a meaningful way during
- *      merge, maybe we'll drop it if it's not of much help for kernel
- *      generation
- * sorted_axes_
- *   This is a helper field, list of axes sorted by their stride;
- *   Default would be `[0, 1, 2, ..., N-1]`
- *   Given sorted_axes_[i] == X means: the i-th axis should be X (if X belongs
- *   [0, N-1]). If X == -1, that means it's unknown. This could happen when we
- *   merge two TensorContiguity and their order of axes are not consistent.
- *
- * The design of TensorContiguity is to handle two things:
- *   1. Contiguity check - whether or not the contiguity information has
- *      changed. To do this, we can simply compare TensorContiguity::contiguity_
- *      between two instances;
- *   2. Kernel generation
- *      By looking at contiguity_ flag, we can make correct decision like:
- *        a. collpasing dimensions;
- *        b. kernel binding;
- *
- * merging two TensorContiguity would check their contiguity_ flag and mark
- * accordinly.
- *
- * Issues with current implementation [definitely not complete]:
- *   1. stride for size-1 dimension.
- *     Because of PyTorch implementation, stride for size-1 dimension is ill
- *     defined and can't properly express intended layout.
- */
-
-// debug print. remove this guy!
-#ifdef TC_DEBUG
-template <typename T>
-std::ostream& operator<<(std::ostream& os, const std::vector<T>& data) {
-  os << "(";
-  for (auto i = data.begin(); i != data.end(); i++) {
-    os << (*i);
-    os << " ";
-  }
-  return os << ")";
-}
-#endif
-
-TensorContiguity::TensorContiguity(
-    const std::vector<int64_t>& sizes,
-    const std::vector<int64_t>& strides) {
-#ifdef TC_DEBUG
-  std::cout << "==== contiguity ====" << std::endl;
-  std::cout << "sizes: " << sizes << std::endl;
-  std::cout << "strides: " << strides << std::endl;
-#endif
-
-  // assert consistent dimensionality;
-  assert(sizes.size() == strides.size());
-
-  // check for size 0 tensor;
-  // size 0 tensor is not handled yet and we should not treat it as broadcast
-  assert(std::none_of(
-      sizes.begin(), sizes.end(), [](int64_t size) { return size == 0; }));
-
-  int dim = sizes.size();
-  contiguity_.resize(dim);
-  sorted_axes_.resize(dim);
-
-  // TODO:
-  if (dim <= 0) {
-    return;
-  }
-
-  std::iota(sorted_axes_.begin(), sorted_axes_.end(), 0);
-
-  // sort axes per their strides
-  // It's important that we use stable sort here, as higher
-  std::stable_sort(
-      sorted_axes_.begin(),
-      sorted_axes_.end(),
-      [&strides](int64_t s_a, int64_t s_b) {
-        return strides[s_a] > strides[s_b];
-      });
-
-#ifdef TC_DEBUG
-  std::cout << "sorted index: " << sorted_axes_ << std::endl;
-#endif
-
-  // Update contiguity flag all the way until the second to the last;
-  for (int i = 0; i < dim; i++) {
-    // decending strides: strides[axis_p_1] <= stride[axis_p];
-    int axis_p = sorted_axes_[i];
-    int stride_p = strides[axis_p];
-    if (stride_p == 0) {
-      contiguity_[axis_p] = 0; // mark axis_p as broadcast
-    } else {
-      if (i + 1 == dim) {
-        contiguity_[axis_p] = dim + 1;
-        if (stride_p == 1) {
-          contiguity_[axis_p] *= -1; // we mark axis_p as contiguous in memory;
-        }
-        break;
-      }
-      // Check if we should skip the check for collapsing, if:
-      //   1. we are at the fastest changing dimension already.
-      //      (i == dim-1)
-      //   or
-      //   2. the next dimension is a broadcast dimension.
-      //      (strides[sorted_axes_[i+1]] == 0))
-      if ((i == dim - 1) || (strides[sorted_axes_[i + 1]] == 0)) {
-        // axis_p is the fastest changing dimension.
-        //   dim+1 is out of range for next axis, so we would know it's the last
-        //   dimension.
-        contiguity_[axis_p] = dim + 1;
-        if (stride_p == 1) {
-          // we mark axis_p as contiguous in memory by setting it to negative.
-          contiguity_[axis_p] *= -1;
-        }
-      } else {
-        int axis_p_1 = sorted_axes_[i + 1];
-        // mark axis_p_1 as the neighboring axis;
-        // Notice the compensation for 1-based indexing.
-        contiguity_[axis_p] = axis_p_1 + 1;
-
-        // Check if axis_p could collapse down to axis_p_1;
-        // [Note] Do NOT specialize on size-1 dimension! Two issues:
-        //   Although size-1 could collapse with any dimension, that's going to
-        //   be a specialization to the static shape information -> hence not an
-        //   intended protocol.
-        //   size-1 collapsing could also be misleading, as the size-1 axis
-        //   could falsely give the impression that its neighboring axes are
-        //   collapsible, while they are not.
-        //   i.e. size[4, 1, 4]; stride[8, 1, 1];
-        //   both axis 0 and 2 could fuse with axis 1 separately. but we cannot
-        //   fuse them all together.
-        if (stride_p == sizes[axis_p_1] * strides[axis_p_1]) {
-          // negative number to specify it's collapsable.
-          contiguity_[axis_p] *= -1; // mark axis_p as broadcast
-        }
-      }
-    }
-  }
-
-#ifdef TC_DEBUG
-  std::cout << "contiguity flag: " << contiguity_ << std::endl;
-  std::cout << "==== done contiguity ====" << std::endl;
-#endif
-}
-
-bool TensorContiguity::isBroadcastDim(int axis) const {
-  assert(axis >= 0 && axis < rank());
-  return contiguity_[axis] == 0;
-}
-
-std::vector<int> TensorContiguity::getBroadcastDims() const {
-  std::vector<int> ret;
-  for (decltype(contiguity_.size()) i{0}; i < contiguity_.size(); i++) {
-    if (contiguity_[i] == 0) {
-      ret.emplace_back(static_cast<int>(i));
-    }
-  }
-  return ret;
-}
-
-// we are checking if axis can merge to right.
-// axis_right == axis + 1,
-bool TensorContiguity::canCollapseToHigher(int axis) const {
-  // not necessary as to check `assert(axis < rank()-1);` as
-  // canCollapseLowerHigher would assert on that;
-  return canCollapseLowerHigher(axis, axis + 1);
-}
-
-int TensorContiguity::rank() const {
-  return contiguity_.size();
-}
-
-bool TensorContiguity::canCollapseLowerHigher(int lower_axis, int higher_axis)
-    const {
-  // we are checking if axis can merge to right.
-  // we mark contiguity_ as -(target_axis + 1), if it's collapsible;
-  assert(
-      lower_axis >= 0 && lower_axis < rank() && higher_axis >= 0 &&
-      higher_axis < rank());
-  return contiguity_[lower_axis] == -(higher_axis + 1);
-}
-
-int TensorContiguity::getFCD() const {
-  for (decltype(contiguity_.size()) i{0}; i < contiguity_.size(); i++) {
-    if (contiguity_[i] == (-((int)contiguity_.size()) - 1))
-      return i;
-  }
-  return -1;
-}
-
-bool TensorContiguity::isIdentical(const TensorContiguity& tc) const {
-  for (int i = 0; i < rank(); i++) {
-    if (tc.contiguity_[i] != contiguity_[i]) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool TensorContiguity::isCompatible(const TensorContiguity& tc) const {
-  assert(false); // not yet implemented;
-  return false;
-}
-bool TensorContiguity::hasContiguousFCD() const {
-  for (decltype(contiguity_.size()) i{0}; i < contiguity_.size(); i++) {
-    if (contiguity_[i] == (-((int)contiguity_.size()) - 1))
-      return true;
-  }
-  return false;
-}
-
-int TensorContiguity::getAxisByStride(int order) const {
-  assert(order >= 0 && order < rank());
-  return sorted_axes_[order];
-}
-
-const std::vector<int>& TensorContiguity::getAxesOrderedByStride() const {
-  return sorted_axes_;
-}
-
-const std::vector<int>& TensorContiguity::getContiguityTag() const {
-  return contiguity_;
-}
-
-const std::vector<int>& TensorContiguity::getSortedAxesTag() const {
-  return sorted_axes_;
-}
-
-void TensorContiguity::merge(const TensorContiguity& tc) {
-  // TODO: different rank not supported yet; This could be done if we follow
-  //       numpy broadcasting rule across multiple operands. We simply insert
-  //       dummy dimensions at the left for tc with lower rank()
-  // see [Note - TensorContiguity implementation]
-  int dim = rank();
-  assert(dim == tc.rank());
-
-  for (int i = 0; i < dim; i++) {
-    int cont_flag = tc.contiguity_[i];
-
-    if (cont_flag != contiguity_[i]) {
-      if (cont_flag == -contiguity_[i]) {
-        // If sorting should remain, we preserve the information but only relax
-        // the contiguity information;
-        contiguity_[i] = std::abs(cont_flag);
-      } else {
-        // mark contiguity as unknown otherwise.
-        contiguity_[i] = dim + 2;
-      }
-      cont_flag = contiguity_[i];
-    }
-
-    // TODO: can we update sorted_axes_ information via contiguity flag?
-    if (tc.sorted_axes_[i] != sorted_axes_[i]) {
-      // mark sorted_axes_ as unknown;
-      sorted_axes_[i] = -1;
-    }
-  }
-
-#ifdef TC_DEBUG
-  std::cout << "merging" << std::endl;
-  std::cout << "sorted index: " << sorted_axes_ << std::endl;
-  std::cout << "contiguity flag: " << contiguity_ << std::endl;
-#endif
-}
-
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/tensor_meta.h b/torch/csrc/jit/codegen/cuda/tensor_meta.h
deleted file mode 100644
index 0f85ad2..0000000
--- a/torch/csrc/jit/codegen/cuda/tensor_meta.h
+++ /dev/null
@@ -1,85 +0,0 @@
-#pragma once
-
-#include <c10/core/Device.h>
-#include <c10/core/DeviceType.h>
-#include <c10/util/Exception.h>
-#include <torch/csrc/WindowsTorchApiMacro.h> // TORCH_CUDA_API
-
-#include <torch/csrc/jit/codegen/cuda/utils.h>
-
-#include <cstdint>
-#include <vector>
-
-namespace torch {
-namespace jit {
-namespace fuser {
-/*
- Issues that we are not solving:
-   1. strides for trivial dimensions (size-1 dimension);
-   2. memory overlap / interleave;
- TODO: docs explaining the protocol; what is stored and how does merge work
- */
-struct TORCH_CUDA_API TensorContiguity {
-  TensorContiguity(
-      const std::vector<int64_t>& size,
-      const std::vector<int64_t>& stride);
-
-  // gives broadcast information per axis;
-  bool isBroadcastDim(int axis) const;
-
-  // returns all axes that requires broadcast;
-  std::vector<int> getBroadcastDims() const;
-
-  // gives contiguity information per axis;
-  // This basically calls to canCollapseLowerHigher(axis, axis+1);
-  bool canCollapseToHigher(int axis) const;
-
-  // return the rank of the tensor;
-  int rank() const;
-
-  // check contiguity
-  bool isIdentical(const TensorContiguity& tc) const;
-
-  // TODO: check for compatiblity
-  // I need feedback on this one. What do we check? order + broadcast rule?
-  bool isCompatible(const TensorContiguity& tc) const;
-
-  /*******************************************************************************
-   * Future proof support
-   *   we don't need these yet.
-   * TODO: we probably won't need this until much later, but let's try solve
-   * the problem that doesn't exist yet;
-   ******************************************************************************/
-
-  // [NOTE] the order of the argument matters:
-  // canCollapseLowerHigher(x, y) differs from canCollapseLowerHigher(y, x)
-  bool canCollapseLowerHigher(int lower_axis, int higher_axis) const;
-
-  // FCD: Fast changing dimension, the dimension with smallest stride (>0).
-  //   returns -1 if FCD doesn't exist (e.g. fully broadcast)
-  int getFCD() const;
-  // Check if FCD exist and has stride == 1.
-  bool hasContiguousFCD() const;
-
-  // This is used to support rational binding;
-  // similarly return -1 means it's unknown;
-  int getAxisByStride(int order) const;
-  const std::vector<int>& getAxesOrderedByStride() const;
-
-  // TODO: we should encode this to a single integer with restricted rank.
-  const std::vector<int>& getContiguityTag() const;
-  const std::vector<int>& getSortedAxesTag() const;
-
-  void merge(const TensorContiguity& tc);
-
- protected:
-  // contiguity_  : contiguity and broadcast;
-  std::vector<int> contiguity_;
-
-  // sorted_axes_ : axes ordered by strides (slow dimension to fast dimension).
-  std::vector<int> sorted_axes_;
-};
-
-} // namespace fuser
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
index 44dd66a..dd8ac6f 100644
--- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp
+++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp
@@ -1,9 +1,14 @@
 #include <torch/csrc/jit/codegen/cuda/arith.h>
+#include <torch/csrc/jit/codegen/cuda/compute_at.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
+#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
 #include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
-#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/mutator.h>
+// #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
+#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
+
+// Cleanup
+// #include <torch/csrc/jit/codegen/cuda/mutator.h>
 #include <torch/csrc/jit/codegen/cuda/transform_iter.h>
 #include <torch/csrc/jit/codegen/cuda/transform_replay.h>
 
@@ -38,10 +43,26 @@
   this->name_ = fusion_->registerVal(this);
 }
 
+TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
+    : Val(src, ir_cloner),
+      domain_(ir_cloner->clone(src->domain_)),
+      compute_at_view_(ir_cloner->clone(src->compute_at_view_)),
+      relative_compute_at_axis_(src->relative_compute_at_axis_),
+      this_compute_at_axis_(src->this_compute_at_axis_),
+      memory_type_(src->memory_type_) {}
+
 bool TensorView::hasReduction() const {
   return domain()->hasReduction();
 }
 
+bool TensorView::hasBlockReduction() const {
+  return domain()->hasBlockReduction();
+}
+
+bool TensorView::hasGridReduction() const {
+  return domain()->hasGridReduction();
+}
+
 bool TensorView::hasBroadcast() const {
   return domain()->hasBroadcast();
 }
@@ -55,6 +76,8 @@
 }
 
 IterDomain* TensorView::axis(int pos) const {
+  TORCH_INTERNAL_ASSERT(
+      nDims() > 0, "Tried to access an axis in a 0-dim TensorView");
   if (pos < 0)
     pos += domain()->nDims();
   TORCH_CHECK(
@@ -98,6 +121,17 @@
       "Invalid computeAt, reduction domain inside computeAt axis.");
 }
 
+void TensorView::setComputeAt(
+    TensorView* computeAtView,
+    int thisPos,
+    int relPos) {
+  compute_at_view_ = computeAtView;
+  relative_compute_at_axis_ = relPos;
+  this_compute_at_axis_ = thisPos;
+  TORCH_INTERNAL_ASSERT(
+      this_compute_at_axis_ <= nDims(), "Manually set an invalid computeAt.");
+}
+
 void TensorView::copyDomain(const TensorDomain* td) {
   std::vector<IterDomain*> idv;
   for (decltype(td->nDims()) i = 0; i < td->nDims(); i++)
@@ -116,7 +150,8 @@
   size_t pos_cav = 0, pos_this = 0;
   while ((int)pos_this < pos) {
     TORCH_INTERNAL_ASSERT(
-        pos_cav < nDims(), "Error computing relative position in computeAt.");
+        pos_cav < compute_at_view_->nDims(),
+        "Error computing relative position in computeAt.");
     if (compute_at_view_->axis(pos_cav)->isBroadcast() &&
         !(axis(pos_this)->isBroadcast())) {
       pos_cav++;
@@ -160,239 +195,27 @@
   this_compute_at_axis_ = pos_this;
 }
 
-// Actually applies transformation
-void TensorView::computeAt_impl(
-    TensorView* consumer,
-    int consumer_compute_at_axis) {
-  // Reset view otherwise will conflict with replay.
-  clearComputeAt();
-  // replay this as consumer / producer as consumer
-  TransformReplay::replayPasC(this, consumer, consumer_compute_at_axis);
-  setComputeAt(consumer, consumer_compute_at_axis);
-}
-
-// Actually applies transformation
-void TensorView::forwardComputeAt_impl(
-    TensorView* producer,
-    int producer_compute_at_axis) {
-  // Reset view otherwise will conflict with replay.
-  producer->clearComputeAt();
-  TransformReplay::replayCasP(this, producer, producer_compute_at_axis);
-  producer->setComputeAt(this, producer_compute_at_axis);
-}
-
-namespace {
-// Wrapper around set_intersection
-template <typename T>
-std::set<T> set_intersection(const std::set<T>& set1, const std::set<T>& set2) {
-  std::set<T> intersection;
-  std::set_intersection(
-      set1.begin(),
-      set1.end(),
-      set2.begin(),
-      set2.end(),
-      std::inserter(intersection, intersection.begin()));
-  return intersection;
-}
-
-// convert an iterable of Val* to be an iterable of TensorView*
-template <typename T1, typename T2>
-T1 tv_iterable(const T2& val_iterable) {
-  T1 tv_iterable = T1();
-  std::transform(
-      val_iterable.begin(),
-      val_iterable.end(),
-      std::back_inserter(tv_iterable),
-      [](Val* v) {
-        TORCH_INTERNAL_ASSERT(
-            v->getValType().value() == ValType::TensorView,
-            "When following the computeAt dependency chain, a non TensorView value was found.");
-        return static_cast<TensorView*>(v);
-      });
-  return tv_iterable;
-}
-} // namespace
-
 TensorView* TensorView::computeAt(TensorView* consumer, int axis) {
-  TORCH_CHECK(
-      this->fusion() == consumer->fusion(),
-      this,
-      " and ",
-      consumer,
-      " are not in the same fusion.");
-
-  FusionGuard fg(this->fusion());
-
+  // Make sure this and consumer are not the same tensor, that's illegal
   TORCH_CHECK(
       !this->sameAs(consumer), "Cannot call this->computeAt(this, ...)");
 
+  // We support negative axes, so a negative axis is incremented by
+  // consumer->nDims() + 1; the result must lie in [0, consumer->nDims()].
+  // An axis equal to consumer->nDims() means the producer is computed inline
+  // with the consumer, hence the +1.
   if (axis < 0)
-    // Compute at is a bit strange where size is the maximum acceptable value
-    // instead of size-1
     axis += int(consumer->nDims()) + 1;
-
   TORCH_CHECK(
       axis >= 0 && (unsigned int)axis < consumer->nDims() + 1,
       "Compute at called on an axis outside valid range.");
 
-  // If not direct relationship follow dependency chain from consumer to
-  // producer.
-  auto dep_chains = DependencyCheck::getAllDependencyChains(this, consumer);
-
-  std::deque<Val*> dep_chain;
-  if (!dep_chains.empty())
-    dep_chain = dep_chains.front();
-
-  // Make sure there is a dependency chain, if not it's an invalid computeAt.
-  // We could do indirect computeAts, but it's not supported at this time.
-  TORCH_CHECK(
-      !dep_chain.empty(),
-      "Compute At expects ",
-      this,
-      " is a dependency of ",
-      consumer,
-      ", however it is not.");
-
-  // Validate dependency chain returned as expected
-  TORCH_INTERNAL_ASSERT(
-      dep_chain.back() == consumer && dep_chain[0] == this,
-      "Error computing dependency chain.");
-
-  // Start the replay going from consumer, through the dependency chain to
-  // producer. After this section, producer should look like consumer, and there
-  // should be a computeAt chain going from producer to consumer. Proper
-  // computeAts are setup, though they will be over-written in a later stage.
-  while (dep_chain.size() > 1) {
-    Val* consumer_val = dep_chain.back();
-    dep_chain.pop_back();
-    Val* producer_val = dep_chain.back();
-
-    TORCH_INTERNAL_ASSERT(
-        consumer_val->getValType().value() == ValType::TensorView &&
-            producer_val->getValType().value() == ValType::TensorView,
-        "When following the computeAt dependency chain, a non TensorView value was found.");
-
-    TensorView* running_consumer = static_cast<TensorView*>(consumer_val);
-    TensorView* running_producer = static_cast<TensorView*>(producer_val);
-    // Axis is relative to consumer, however as we propagate computeAt, it may
-    // move. This is why we have TensorView->getThisComputeAtAxis() which
-    // returns where in a TensorView does the computeAt (relative to consumer)
-    // line up. Mismatch is due to broadcast.
-    int compute_at_axis = axis;
-    if (running_consumer != consumer)
-      compute_at_axis = (int)running_consumer->getThisComputeAtAxis();
-    running_producer->computeAt_impl(running_consumer, compute_at_axis);
-  }
-
-  /*
-   * Compute At has now worked from consumer to producer, transforming producer
-   * to match computeAt selected in consumer We now need to work from producer
-   * up to its consumers (including indirect consumption) so their use also
-   * matches. If we can find a TV that contains all uses of producer (common
-   * consumer), we can terminate this propagation there. If not, we need to
-   * propagate all the way to outputs.
-   */
-
-  // Start looking for a common consumer of producer
-
-  // Grab all uses of producer in fusion
-  auto val_all_consumer_chains =
-      DependencyCheck::getAllDependencyChainsTo(this);
-
-  // Convert dep chains to tensor view chains
-  std::deque<std::deque<TensorView*>> all_consumer_chains;
-  for (const auto& val_dep_chain : val_all_consumer_chains)
-    all_consumer_chains.push_back(
-        tv_iterable<std::deque<TensorView*>>(val_dep_chain));
-
-  // Set arith to find a common consumer, start with first use chain of producer
-  std::set<TensorView*> common_consumers(
-      all_consumer_chains.front().begin(), all_consumer_chains.front().end());
-
-  // Run through all use chains of producer, and intersect them
-  for (auto dep_chain : all_consumer_chains)
-    common_consumers = set_intersection(
-        common_consumers,
-        std::set<TensorView*>(dep_chain.begin(), dep_chain.end()));
-
-  // Remove all TVs between producer and consumer as we don't want a common
-  // consumer placed logically before consumer provided in computeAt
-  for (const auto& dep_chain : dep_chains) {
-    auto tv_chain = tv_iterable<std::deque<TensorView*>>(dep_chain);
-    for (auto tv : tv_chain) {
-      if (tv != consumer)
-        common_consumers.erase(tv);
-    }
-  }
-
-  // If there is a common consumer, grab the first one (topologically)
-  TensorView* common_consumer = nullptr;
-  if (!common_consumers.empty()) {
-    for (TensorView* tv : all_consumer_chains.front())
-      if (common_consumers.find(tv) != common_consumers.end()) {
-        common_consumer = tv;
-        break;
-      }
-  }
-
-  // Forward propagate the transformationthrough all use chains until
-  // common_consumer if there is one otherwise until we hit all output TVs
-  std::set<TensorView*> output_set;
-  // computeAt axis in outputs don't necessarily match up, make sure to keep the
-  // relative computeAt position in each output
-  std::vector<std::pair<TensorView*, int>> ordered_outputs;
-  for (auto dep_chain : all_consumer_chains) {
-    // All dep chains start with this.
-    TORCH_INTERNAL_ASSERT(
-        dep_chain.front() == this,
-        "Invalid dependency chain found during computeAt, ",
-        dep_chain.front(),
-        " should be ",
-        this);
-    TORCH_INTERNAL_ASSERT(
-        this->hasComputeAt(),
-        "Error detected during computeAt, ",
-        this,
-        ", should have a computeAt set at this point even though we will over-write it.");
-    int running_producer_compute_at = (int)this->getThisComputeAtAxis();
-    while (dep_chain.size() > 1) {
-      TensorView* running_producer = dep_chain.front();
-      dep_chain.pop_front();
-      TensorView* running_consumer = dep_chain.front();
-
-      if (running_producer == common_consumer)
-        break;
-      // Axis is relative to consumer, and may not necessarily apply to all
-      // intermediate steps. Fortunately producer is guarenteed to have a valid
-      // computeAt set, so we can use the compute at axis relative to producer.
-      running_consumer->forwardComputeAt_impl(
-          running_producer, running_producer_compute_at);
-      running_producer_compute_at =
-          (int)running_producer->getThisComputeAtAxis();
-      int consumer_compute_at =
-          (int)running_producer->getRelativeComputeAtAxis();
-
-      if (dep_chain.size() == 1) { // last one
-        if (output_set.find(running_consumer) == output_set.end()) {
-          output_set.emplace(running_consumer);
-          ordered_outputs.emplace_back(std::pair<TensorView*, int>(
-              running_consumer, consumer_compute_at));
-        }
-      }
-    }
-  }
-
-  if (!ordered_outputs.empty())
-    for (auto it = ordered_outputs.begin(); it + 1 != ordered_outputs.end();
-         it++)
-      (*it).first->computeAt_impl(
-          (*(it + 1)).first,
-          (*(it + 1)).second); // use recorded position, not axis.
+  ComputeAt::run(this, consumer, (unsigned int)axis);
 
   return this;
 }
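A small standalone sketch of the negative-axis normalization described above (normalizeComputeAtAxis is a hypothetical helper, not part of the TensorView API): a computeAt position may range over [-nDims()-1, nDims()], where nDims() means the producer is computed fully inline with the consumer.

#include <stdexcept>

// E.g. axis == -1 with a 3-D consumer normalizes to 3, i.e. fully inlined.
int normalizeComputeAtAxis(int axis, int consumer_ndims) {
  if (axis < 0)
    axis += consumer_ndims + 1;
  if (axis < 0 || axis > consumer_ndims)
    throw std::out_of_range(
        "Compute at called on an axis outside valid range.");
  return axis;
}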
 
 TensorView* TensorView::split(int axis, unsigned int factor) {
+  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim TensorView");
   if (axis < 0)
     axis += domain()->nDims();
 
@@ -411,6 +234,7 @@
 
 // Merge "axis" and "axis+1" into 1 dimension
 TensorView* TensorView::merge(int axis_o, int axis_i) {
+  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim TensorView");
   if (axis_o < 0)
     axis_o += domain()->nDims();
 
@@ -434,19 +258,15 @@
 }
 
 TensorView* TensorView::reorder(const std::unordered_map<int, int>& old2new_) {
+  TORCH_INTERNAL_ASSERT(
+      !(nDims() == 0 && old2new_.size() > 0),
+      "Tried to reorder a 0-dim TensorView");
   domain()->reorder(old2new_);
   return this;
 }
 
-/*
- * Take reduction axes out of this domain, and create a new domain. New domain
- * will be used to create this domain. For example: TV1[I0, I1] = TV0[I0, R0,
- * R1, I1] TV0->rfactor({1}) TV0 is transformed to -> TV0[I0, R1, I1] The
- * TensorView returned is: TV2[I0, R0, I3, I1] The reduction will now beset
- * as: TV1[I0, R1, I1] = TV2[I0, R0, I3, I1] TV0[I0, I1] = TV1[I0, R1, I1]
- */
-
 TensorView* TensorView::rFactor(const std::vector<int>& axes) {
+  TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
   FusionGuard fg(this->fusion());
   Expr* origin_expr = this->fusion()->origin(this);
   TORCH_CHECK(
diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp
index 42ecdbe..20cf158 100644
--- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp
@@ -26,7 +26,7 @@
   // going to replay the split on
   auto it = id_map_.find(id_in);
   if (it == id_map_.end()) {
-    if (check_all_ops_run_) {
+    if (error_on_failure_) {
       TORCH_INTERNAL_ASSERT(
           false, "Transform traversal failed, dependencies not met.");
     } else {
@@ -67,18 +67,42 @@
   // we're going to replay the merge on
   auto it_outer = id_map_.find(id_outer);
   auto it_inner = id_map_.find(id_inner);
-  if (it_outer == id_map_.end() || it_inner == id_map_.end()) {
-    if (check_all_ops_run_) {
-      TORCH_INTERNAL_ASSERT(
-          false, "Transform traversal failed, dependencies not met.");
-    } else {
-      return;
+
+  const bool outer_found = it_outer != id_map_.end();
+  const bool outer_bcast = id_outer->isBroadcast();
+  const bool inner_found = it_inner != id_map_.end();
+  const bool inner_bcast = id_inner->isBroadcast();
+
+  // If either ID is not found in the map
+  if (!outer_found || !inner_found) {
+    // If both aren't found, it's a failure
+    // If outer is found && inner is bcast it is not a failure
+    // If inner is found && outer is bcast it is not a failure
+    if (!(outer_found || inner_found) || (outer_found && !inner_bcast) ||
+        (inner_found && !outer_bcast)) {
+      if (error_on_failure_) {
+        TORCH_INTERNAL_ASSERT(
+            false, "Transform traversal failed, dependencies not met.");
+      } else {
+        return;
+      }
     }
   }
 
+  // If we merge a broadcast dim with a non-broadcast dim, just remap the output
+  // to the non-broadcast dim.
+  if (inner_found && !outer_found && outer_bcast) {
+    id_map_[m->out()] = it_inner->second;
+    return;
+  }
+  if (outer_found && !inner_found && inner_bcast) {
+    id_map_[m->out()] = it_outer->second;
+    return;
+  }
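The branches above amount to a small decision table; the following sketch (illustrative names, not the ReplayTransformations API) captures it: a merge is replayed only when both inputs are mapped, a single missing input is tolerated only when it is a broadcast (the output is then remapped to the found input), and anything else is a failure.

enum class MergeReplayAction { Fail, RemapToInner, RemapToOuter, Replay };

// Mirrors the branch structure above: a merge can be replayed only when both
// inputs are mapped; a missing input is tolerated only when it is a broadcast,
// in which case the merge output is remapped to the other (found) input.
// E.g. classifyMerge(true, false, false, true) == MergeReplayAction::RemapToOuter.
MergeReplayAction classifyMerge(
    bool outer_found,
    bool outer_bcast,
    bool inner_found,
    bool inner_bcast) {
  if (!outer_found || !inner_found) {
    if (!(outer_found || inner_found) || (outer_found && !inner_bcast) ||
        (inner_found && !outer_bcast)) {
      return MergeReplayAction::Fail;
    }
    return inner_found ? MergeReplayAction::RemapToInner
                       : MergeReplayAction::RemapToOuter;
  }
  return MergeReplayAction::Replay;
}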
+
   // Grab the IDs we're going to replay this merge on
-  auto id_outer_mapped = (*it_outer).second;
-  auto id_inner_mapped = (*it_inner).second;
+  const auto id_outer_mapped = it_outer->second;
+  const auto id_inner_mapped = it_inner->second;
 
   // Make sure these IDs are leaf IDs (meaning they have no uses we generated)
   TORCH_INTERNAL_ASSERT(
@@ -107,15 +131,15 @@
 ReplayTransformations::ReplayTransformations(
     const std::vector<IterDomain*>& _target_domain,
     std::unordered_map<IterDomain*, IterDomain*> _id_map,
-    bool _check_all_ops_run)
+    bool _error_on_failure)
     : target_domain_(_target_domain),
       id_map_(std::move(_id_map)),
-      check_all_ops_run_(_check_all_ops_run) {
+      error_on_failure_(_error_on_failure) {
   // Make sure id_map has all the inputs needed to replay target_domain
   auto inps = IterVisitor::getInputsTo(
       std::vector<Val*>(target_domain_.begin(), target_domain_.end()));
 
-  if (check_all_ops_run_)
+  if (error_on_failure_)
     std::for_each(inps.begin(), inps.end(), [this](Val* val) {
       TORCH_INTERNAL_ASSERT(
           val->getValType().value() == ValType::IterDomain,
@@ -149,7 +173,7 @@
       target_domain_.begin(), target_domain_.end());
   traverseFrom(traversal_vals[0]->fusion(), traversal_vals);
 
-  if (check_all_ops_run_)
+  if (error_on_failure_)
     TORCH_INTERNAL_ASSERT(
         leaf_ids_.size() >= target_domain_.size(),
         "Transform traversal failed, did not find enough output IterDomains.");
@@ -158,7 +182,7 @@
   for (auto out : target_domain_) {
     auto it_replayed = id_map_.find(out);
     if (it_replayed == id_map_.end()) {
-      if (check_all_ops_run_) {
+      if (error_on_failure_) {
         TORCH_INTERNAL_ASSERT(
             false,
             "Transform traversal failed, could not find expected output.");
@@ -170,7 +194,9 @@
     auto it_leaf = leaf_ids_.find(id_replayed);
     TORCH_INTERNAL_ASSERT(
         it_leaf != leaf_ids_.end(),
-        "Transform Traversal failed, expected matched output to be a leaf of the replay, but was not.");
+        "Transform Traversal failed, expected a replayed dim for ",
+        out,
+        " but one was not created.");
   }
 
   // Populate leaf_vec_ in a deterministic manner. This is deterministic
@@ -275,7 +301,7 @@
       continue;
     }
 
-    if (t_expr->nOutputs() != r_expr->nOutputs()) {
+    if (t_expr->outputs().size() != r_expr->outputs().size()) {
       TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
       continue;
     }
@@ -318,7 +344,7 @@
     }
 
     // Add outputs to map.
-    for (size_t i = 0; i < t_expr->nOutputs(); i++) {
+    for (size_t i = 0; i < t_expr->outputs().size(); i++) {
       auto t_out = t_expr->output(i);
       auto r_out = r_expr->output(i);
       if (t_out->getValType() == ValType::IterDomain &&
diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h
index 645d702..d1aaddd 100644
--- a/torch/csrc/jit/codegen/cuda/transform_iter.h
+++ b/torch/csrc/jit/codegen/cuda/transform_iter.h
@@ -24,7 +24,7 @@
 };
 
 // Simply grabs all exprs needed to produce provided outputs.
-struct Exprs : public IterVisitor {
+class Exprs : public IterVisitor {
  private:
   std::vector<Expr*> exprs;
   void handle(Expr* e) override {
@@ -44,14 +44,27 @@
 
 } // namespace
 
-struct TORCH_CUDA_API ReplayTransformations : public IterVisitor {
+// Uses the history of _target_domain, and replays that history using the
+// provided map.
+//
+// target_domain contains the history we want replayed.
+//
+// id_map maps IterDomains in that history to the IterDomains we want it
+// replayed on.
+//
+// error_on_failure = true will cause the replay to error if we can't replay any
+// operation in target_domain's history due to missing IDs in the id_map.
+//
+// If error_on_failure = false, replay will replay everything it can, and ignore
+// operations it can't.
+class TORCH_CUDA_API ReplayTransformations : public IterVisitor {
  protected:
   const std::vector<IterDomain*>& target_domain_;
   std::unordered_map<IterDomain*, IterDomain*> id_map_;
   std::unordered_map<IterDomain*, size_t> leaf_ids_;
   std::vector<IterDomain*> leaf_vec_;
   size_t counter = 0;
-  bool check_all_ops_run_ = true;
+  bool error_on_failure_ = true;
   bool ran_replay = false; // Mark if replay has been run
   using IterVisitor::handle;
 
@@ -60,23 +73,16 @@
 
   // TODO: HANDLE RFACTOR DOMAINS
   // We're going to replay this split operation on the corresponding ID
-  virtual void handle(Split* s) override;
+  void handle(Split* s) override;
 
   // We're going to replay this merge operation on the corresponding IDs
-  virtual void handle(Merge* m) override;
+  void handle(Merge* m) override;
 
  public:
-  // Uses the history of _target_domain, and replays that history using the
-  // provided map target_domain contains the history we want replayed, and
-  // id_map maps IterDomains in that history to the IterDomains we want it
-  // replayed on. check_all_ops_run will cause the replay to error if we can't
-  // play any operation in target_domain's history because the IDs are not in
-  // the id_map. If check_all_ops_run = false, replay will replay everything it
-  // can, and ignore operations it can't.
   ReplayTransformations(
       const std::vector<IterDomain*>& _target_domain,
       std::unordered_map<IterDomain*, IterDomain*> _id_map,
-      bool _check_all_ops_run = true);
+      bool _error_on_failure = true);
 
   // Replays outputs that were generated from ids.first on ids.second
   void runReplay();
@@ -155,7 +161,7 @@
  * to the output of the equivalent expr's outputs in replay_domain's history.
  */
 
-struct TORCH_CUDA_API BestEffortReplay {
+class TORCH_CUDA_API BestEffortReplay {
  private:
   std::unordered_map<IterDomain*, IterDomain*> id_map_;
   std::unordered_map<IterDomain*, size_t> leaf_ids_;
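
A minimal usage sketch of the contract the new class comment spells out. The wrapper function and its name are hypothetical, and getReplay() is assumed to be the accessor used by the callers in transform_replay.cpp (replay_PasC.getReplay()):

```cpp
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>

#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {

// Replay target_domain's split/merge history onto the IDs given by root_map.
// With error_on_failure = false, any operation whose inputs are missing from
// the map is simply skipped instead of asserting.
std::unordered_map<IterDomain*, IterDomain*> bestEffortReplayOnto(
    const std::vector<IterDomain*>& target_domain,
    std::unordered_map<IterDomain*, IterDomain*> root_map) {
  ReplayTransformations replay(
      target_domain, std::move(root_map), /*_error_on_failure=*/false);
  replay.runReplay();
  return replay.getReplay(); // map from target IDs to their replayed IDs
}

} // namespace fuser
} // namespace jit
} // namespace torch
```
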
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
index 2da65bd..ccfa63e 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
@@ -15,7 +15,7 @@
 
 namespace {
 
-struct ReplaySelf : public ReplayTransformations {
+class ReplaySelf : public ReplayTransformations {
  private:
   // Took a good bit of this from ReplayTransformations::handle(Split...)
   void handle(Split* s) override {
@@ -129,8 +129,8 @@
 
 // Self replay.
 TensorDomain* TransformReplay::fullSelfReplay(
-    TensorDomain* new_self_root,
-    TensorDomain* self) {
+    const TensorDomain* new_self_root,
+    const TensorDomain* self) {
   TORCH_INTERNAL_ASSERT(
       new_self_root->nDims() == self->rootDomain().size(),
       "Invalid number of IterDomains provided.");
@@ -174,15 +174,14 @@
   return new TensorDomain(new_self_root->domain(), new_domain);
 }
 
-// Replay producer as consumer.
 // Producer could have rfactor axes which consumer may want replayed. We can
 // "replay" them as long as it doesn't modify the root rfactor axes. What we
 // really want to do is validate that, if we replayed these axes onto the ones
 // they map to in the consumer, the operations would all be the same. Then we
 // want to start the replay of the producer from the rfactor root axes, not the
 // root.
-TensorDomain* TransformReplay::replayPasC(
-    TensorDomain* producer,
-    TensorDomain* consumer,
+std::pair<TensorDomain*, unsigned int> TransformReplay::replayPasC(
+    const TensorDomain* producer,
+    const TensorDomain* consumer,
     int consumer_compute_at_axis) {
   if (consumer_compute_at_axis < 0)
     consumer_compute_at_axis += (int)consumer->nDims() + 1;
@@ -192,17 +191,9 @@
       "Invalid axis in transform replayPasC.");
 
   // consumer ids we need to match in producer
-  std::vector<IterDomain*> consumer_CA_ids;
-  {
-    int itc = 0;
-    while (itc < consumer_compute_at_axis) {
-      if (consumer->axis(itc)->isBroadcast()) {
-        itc++;
-      } else {
-        consumer_CA_ids.emplace_back(consumer->axis(itc++));
-      }
-    }
-  }
+  std::vector<IterDomain*> consumer_CA_ids(
+      consumer->domain().begin(),
+      consumer->domain().begin() + consumer_compute_at_axis);
 
   // Figure out all inputs required to generate the compute_at dimensions
   std::unordered_set<Val*> consumer_CA_root_ids = IterVisitor::getInputsTo(
@@ -228,7 +219,8 @@
   {
     size_t itc = 0, itp = 0;
     while (itc < consumer_root.size() || itp < producer_root.size()) {
-      if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast()) {
+      if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast() &&
+          (itp >= producer_root.size() || !producer_root[itp]->isBroadcast())) {
         itc++;
         continue;
       }
@@ -273,11 +265,14 @@
   // rest
   for (auto c_id : consumer_CA_ids) {
     auto it = replay_PasC.getReplay().find(c_id);
-    TORCH_INTERNAL_ASSERT(
-        it != replay_PasC.getReplay().end(),
-        "Could not find axis, ",
-        c_id,
-        ", requested in replay.");
+    if (it == replay_PasC.getReplay().end()) {
+      TORCH_INTERNAL_ASSERT(
+          c_id->isBroadcast(),
+          "Could not find axis, ",
+          c_id,
+          ", requested in replay.");
+      continue;
+    }
     if (leaf_ids.find(it->second) != leaf_ids.end())
       leaf_ids.erase(it->second);
   }
@@ -328,15 +323,19 @@
   // Add axes in (1)
   for (auto c_id : consumer_CA_ids) {
     auto it = replay_PasC.getReplay().find(c_id);
-    TORCH_INTERNAL_ASSERT(
-        it != replay_PasC.getReplay().end(),
-        "Could not find axis, ",
-        c_id,
-        ", requested in replay.");
+    if (it == replay_PasC.getReplay().end()) {
+      TORCH_INTERNAL_ASSERT(
+          c_id->isBroadcast(),
+          "Could not find axis, ",
+          c_id,
+          ", requested in replay.");
+      continue;
+    }
     new_IDs.push_back(it->second);
     used_IDs.emplace(it->second);
   }
 
+  unsigned int producer_compute_at_axis = (unsigned int)new_IDs.size();
   // Add axes in (2)
   std::unordered_set<IterDomain*> consumer_CA_ids_set(
       consumer_CA_ids.begin(), consumer_CA_ids.end());
@@ -369,16 +368,16 @@
 
   TensorDomain* replayed = new TensorDomain(
       producer->rootDomain(), producer->rfactorDomain(), new_IDs);
-  return replayed;
+  return {replayed, producer_compute_at_axis};
 }
 
-// Replay consumer as producer.
-TensorDomain* TransformReplay::replayCasP(
-    TensorDomain* consumer,
-    TensorDomain* producer,
+std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
+    const TensorDomain* consumer,
+    const TensorDomain* producer,
     int producer_compute_at_axis) {
   if (producer_compute_at_axis < 0)
     producer_compute_at_axis += (int)producer->nDims() + 1;
+
   TORCH_INTERNAL_ASSERT(
       producer_compute_at_axis >= 0 &&
           (unsigned int)producer_compute_at_axis <= producer->nDims(),
@@ -397,21 +396,31 @@
     }
   }
 
-  // Figure out all inputs required to generate the compute_at dimensions
-  std::unordered_set<Val*> producer_CA_root_ids = IterVisitor::getInputsTo(
-      std::vector<Val*>(producer_CA_ids.begin(), producer_CA_ids.end()));
-
   // Map of producer_CA_root_ids to related producer_CA_ids
   id_map replay_root_map;
 
   // Grab root domains of producer and consumer
   std::vector<IterDomain*> consumer_root = consumer->rootDomain();
   std::vector<IterDomain*> producer_root = producer->rootDomain();
+
   // If producer has an rfactor root, that's the one that will match the
   // consumer
   if (producer->hasRFactor())
     producer_root = producer->rfactorDomain();
 
+  // Figure out all inputs required to generate the compute_at dimensions
+  std::unordered_set<Val*> all_CA_id_deps = DependencyCheck::getAllValsBetween(
+      std::unordered_set<Val*>(
+          producer->rootDomain().begin(), producer->rootDomain().end()),
+      std::vector<Val*>(producer_CA_ids.begin(), producer_CA_ids.end()));
+
+  // Figure out which root IDs we need:
+  std::unordered_set<Val*> producer_CA_root_ids;
+  for (Val* val : producer_root) {
+    if (all_CA_id_deps.find(val) != all_CA_id_deps.end())
+      producer_CA_root_ids.emplace(val);
+  }
+
   // Track which root axes in consumer we send to replay
   std::unordered_set<IterDomain*> consumer_roots4replay;
   // Map related axes from producer and consumer roots. Make sure we go to the
@@ -419,7 +428,8 @@
   {
     size_t itc = 0, itp = 0;
     while (itc < consumer_root.size() || itp < producer_root.size()) {
-      if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast()) {
+      if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast() &&
+          (itp >= producer_root.size() || !producer_root[itp]->isBroadcast())) {
         itc++;
         continue;
       }
@@ -561,11 +571,11 @@
   TensorDomain* replayed = new TensorDomain(
       consumer->rootDomain(), consumer->rfactorDomain(), new_IDs);
 
-  return replayed;
+  return {replayed, producer_CA_ids.size()};
 }
 
 // replay Producer as Consumer
-TensorView* TransformReplay::replayPasC(
+std::pair<TensorView*, unsigned int> TransformReplay::replayPasC(
     TensorView* producer,
     TensorView* consumer,
     int compute_at_axis) {
@@ -573,26 +583,26 @@
 
   // tensor view. When this happens, just return the target view.
   if (producer == consumer)
-    return producer;
+    return {producer, 0};
 
-  TensorDomain* td =
+  std::pair<TensorDomain*, unsigned int> replay =
       replayPasC(producer->domain(), consumer->domain(), compute_at_axis);
-  producer->setDomain(td);
-  return producer;
+  producer->setDomain(replay.first);
+  return {producer, replay.second};
 }
 
-TensorView* TransformReplay::replayCasP(
+std::pair<TensorView*, unsigned int> TransformReplay::replayCasP(
     TensorView* consumer,
     TensorView* producer,
     int compute_at_axis) {
   // If this is a reduction operation, we may call transform_replay on the same
   // tensor view. When this happens, just return the target view.
   if (consumer == producer)
-    return consumer;
-  TensorDomain* td =
+    return {consumer, 0};
+  std::pair<TensorDomain*, unsigned int> replay =
       replayCasP(consumer->domain(), producer->domain(), compute_at_axis);
-  consumer->setDomain(td);
-  return consumer;
+  consumer->setDomain(replay.first);
+  return {consumer, replay.second};
 }
 
 } // namespace fuser
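
The root-mapping loops above (in both replayPasC and replayCasP) now skip a consumer broadcast axis only when the producer axis at the same position is not itself a broadcast; otherwise the two broadcasts are paired. A self-contained toy illustration of that pairing rule (plain structs; reduction handling and the trailing-axis cases from the real loops are omitted):

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Axis {
  std::string name;
  bool is_broadcast;
};

int main() {
  // Hypothetical root domains: the consumer has an extra broadcast axis "b".
  std::vector<Axis> consumer = {{"c0", false}, {"b", true}, {"c1", false}};
  std::vector<Axis> producer = {{"p0", false}, {"p1", false}};

  size_t itc = 0, itp = 0;
  while (itc < consumer.size() && itp < producer.size()) {
    // Consumer-only broadcast: nothing on the producer side to map it to.
    if (consumer[itc].is_broadcast && !producer[itp].is_broadcast) {
      itc++;
      continue;
    }
    std::cout << consumer[itc].name << " <-> " << producer[itp].name << "\n";
    itc++;
    itp++;
  }
  // Prints "c0 <-> p0" and "c1 <-> p1"; the consumer-only "b" is skipped.
  return 0;
}
```
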
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h
index 65118b7..d43d71f 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.h
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.h
@@ -116,40 +116,39 @@
  *
  */
 
-struct TensorDomain;
-struct TensorView;
+class TensorDomain;
+class TensorView;
 
-struct TORCH_CUDA_API TransformReplay {
- private:
+class TORCH_CUDA_API TransformReplay {
  public:
-  // Replay producer as consumer.
-  static TensorDomain* replayPasC(
-      TensorDomain* producer,
-      TensorDomain* consumer,
+  // Replay producer as consumer, returns {producer, producer_compute_at_axis}.
+  static std::pair<TensorDomain*, unsigned int> replayPasC(
+      const TensorDomain* producer,
+      const TensorDomain* consumer,
       int consumer_compute_at_axis);
 
-  // Replay producer as consumer.
-  static TensorView* replayPasC(
+  // Replay producer as consumer, returns {producer, producer_compute_at_axis}.
+  static std::pair<TensorView*, unsigned int> replayPasC(
       TensorView* producer,
       TensorView* consumer,
       int consumer_compute_at_axis);
 
-  // Replay producer as consumer.
-  static TensorDomain* replayCasP(
-      TensorDomain* consumer,
-      TensorDomain* producer,
+  // Replay consumer as producer, returns {consumer, consumer_compute_at_axis}.
+  static std::pair<TensorDomain*, unsigned int> replayCasP(
+      const TensorDomain* consumer,
+      const TensorDomain* producer,
       int producer_compute_at_axis);
 
-  // Replay producer as consumer.
-  static TensorView* replayCasP(
+  // Replay consumer as producer, returns {consumer, consumer_compute_at_axis}.
+  static std::pair<TensorView*, unsigned int> replayCasP(
       TensorView* consumer,
       TensorView* producer,
       int producer_compute_at_axis);
 
   // Self replay.
   static TensorDomain* fullSelfReplay(
-      TensorDomain* new_self_root,
-      TensorDomain* self);
+      const TensorDomain* new_self_root,
+      const TensorDomain* self);
 };
 
 } // namespace fuser
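
A minimal call-site sketch (hypothetical wrapper, not code from this PR) of the new pair-returning API: the first element is the tensor whose domain was replaced, the second is the compute_at position it was replayed up to.

```cpp
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>

#include <utility>

namespace torch {
namespace jit {
namespace fuser {

// Replay producer's domain to match consumer up to consumer_compute_at_axis
// and report how many leading producer axes now correspond to that point.
unsigned int replayProducerAsConsumer(
    TensorView* producer,
    TensorView* consumer,
    int consumer_compute_at_axis) {
  std::pair<TensorView*, unsigned int> replay =
      TransformReplay::replayPasC(producer, consumer, consumer_compute_at_axis);
  // replay.first == producer, now holding the replayed TensorDomain.
  return replay.second;
}

} // namespace fuser
} // namespace jit
} // namespace torch
```
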
diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
index d4152bc..ac10289 100644
--- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
@@ -11,7 +11,7 @@
 
 namespace {
 
-struct ReplayRFactor : public ReplayTransformations {
+class ReplayRFactor : public ReplayTransformations {
  private:
   // Took a good bit of this from ReplayTransformations::handle(Split...)
   void handle(Split* s) override {
@@ -258,7 +258,7 @@
             id->extent(),
             id->parallel_method(),
             false,
-            true,
+            false,
             false);
       } else {
         new_root[i] = id->clone();
diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.h b/torch/csrc/jit/codegen/cuda/transform_rfactor.h
index 9fc2e15..6d0977f 100644
--- a/torch/csrc/jit/codegen/cuda/transform_rfactor.h
+++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.h
@@ -14,7 +14,7 @@
 
 // TODO: Only replay dispatch is really borrowed from TransformIter, we should
 // reevaluate the reuse of dispatch for classes that inherit TransformIter.
-struct TORCH_CUDA_API TransformRFactor {
+class TORCH_CUDA_API TransformRFactor {
  public:
   // Create a copy of td, change its history by preserving axes so they appear in
   // the root domain
diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp
index dbdcb33..fee6910 100644
--- a/torch/csrc/jit/codegen/cuda/type.cpp
+++ b/torch/csrc/jit/codegen/cuda/type.cpp
@@ -444,7 +444,7 @@
 TORCH_CUDA_API std::ostream& operator<<(
     std::ostream& out,
     const ParallelType ptype) {
-  return out << parallel_type2string(ptype);
+  return out << stringifyThread(ptype);
 }
 
 TORCH_CUDA_API std::ostream& operator<<(
@@ -471,6 +471,10 @@
   return thread_size2string(ptype);
 }
 
+std::string stringifyThread(const ParallelType ptype) {
+  return parallel_type2string(ptype);
+}
+
 TORCH_CUDA_API c10::optional<std::string> cast_func_str(
     const std::pair<DataType, DataType>& cast) {
   const char* str = supported_casts2string(cast);
@@ -478,6 +482,21 @@
                         : c10::nullopt;
 }
 
+size_t dataTypeSize(DataType type) {
+  switch (type) {
+    case DataType::Bool:
+      return sizeof(bool);
+    case DataType::Float:
+      return 4;
+    case DataType::Half:
+      return 2;
+    case DataType::Int:
+      return 4;
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Size undefined for data type, ", type);
+  }
+}
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
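
A small usage sketch (assumed call site, not code from this PR) for the two helpers added above, dataTypeSize() and stringifyThread():

```cpp
#include <torch/csrc/jit/codegen/cuda/type.h>

#include <iostream>

int main() {
  using namespace torch::jit::fuser;

  // Bytes needed for a 1024-element float buffer: 1024 * 4.
  const auto bytes = 1024 * dataTypeSize(DataType::Float);

  std::cout << bytes << " bytes, bound to "
            << stringifyThread(ParallelType::TIDx) << "\n";

  // operator<<(ParallelType) now routes through stringifyThread, so printing
  // the enum directly gives the same name.
  std::cout << ParallelType::TIDx << "\n";
  return 0;
}
```
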
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index 760ff58..afabfa9 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -26,6 +26,7 @@
 enum class DataType { Bool, Float, Half, Int, Null };
 
 enum class ExprType {
+  Invalid,
   UnaryOp,
   BinaryOp,
   TernaryOp,
@@ -93,7 +94,7 @@
   // TypeAs,
 
   // Logical Ops
-  // Int operations, leave position oif Mod we depend on its location of first
+  // Int operations; leave Mod first, we depend on its location
   Mod,
   CeilDiv,
   And,
@@ -136,6 +137,7 @@
 TORCH_CUDA_API std::ostream& operator<<(std::ostream&, const ParallelType);
 
 std::string stringifyThreadSize(const ParallelType);
+std::string stringifyThread(const ParallelType);
 
 TORCH_CUDA_API c10::optional<std::string> inline_op_str(const UnaryOpType);
 TORCH_CUDA_API c10::optional<std::string> inline_op_str(const BinaryOpType);
@@ -143,6 +145,8 @@
 TORCH_CUDA_API c10::optional<std::string> cast_func_str(
     const std::pair<DataType, DataType>&);
 
+size_t dataTypeSize(DataType type);
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp
index cce431e..6c49b7f 100644
--- a/torch/csrc/jit/codegen/cuda/utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/utils.cpp
@@ -2,8 +2,6 @@
 
 #include <c10/util/Exception.h>
 
-#include <torch/csrc/jit/codegen/cuda/tensor_meta.h>
-
 #include <algorithm>
 #include <ostream>