WIP Jiterator reduction

This PR enables jit-compiled reductions and moves `prod` to be jit-compiled.
Currently, only reductions that can use `func_wrapper` for the automatic implementation of the `reduce`/`project`/`translate_idx` ops are supported. There are a few TODOs for more complex reductions, such as norms and max, which typically require a full-fledged ReduceOps functor. Similarly, only reductions with a single input are supported.
The number of inputs is hardcoded to 1, which holds for our current reductions but can be relaxed in the future.
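
As an illustration of the resulting pattern, here is a condensed sketch of the `prod_functor` change in `ReduceSumProdKernel.cu` below (not a standalone program): the reduction supplies only its `combine` step as a stringified function plus the identity value, while `reduce`/`project`/`translate_idx` come from the reduction template, and it falls back to the eager `func_wrapper` path when the jiterator is disabled.

```cpp
// Condensed from the prod_functor change in this PR.
const char op_name[] = "prod";

template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct prod_functor {
#if AT_USE_JITERATOR()
  void operator()(TensorIterator& iter) {
    // The combine step is shipped to NVRTC as a string; the reduction
    // template generates reduce/project/translate_idx around it.
    std::string func = jiterator_stringify(
      arg_t combine(arg_t a, arg_t b) { return a * b; }
    );
    jitted_gpu_reduce_kernel<op_name, scalar_t, out_t>(iter, func, /*ident=*/1.);
  }
#else
  // Eager fallback: the same combine step wrapped by func_wrapper.
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, out_t>(
        iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
          return a * b;
        }), 1.);
  }
#endif
};
```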

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74446
Approved by: https://github.com/mruberry
diff --git a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
index b5b1cd5..61417b9 100644
--- a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
@@ -71,7 +71,8 @@
   std::tuple<Args...> extra_args) {
 
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
-  const int64_t grid = (N + block_work_size() - 1) / block_work_size();
+  // casting the result to int is always safe; the intermediate is int64 and won't overflow
+  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
 
   static std::mutex _jiterator_mutex;
   static std::vector<at::cuda::jit::NvrtcFunction> fns(c10::cuda::device_count());
@@ -115,8 +116,8 @@
     args[i + 7] = extra_args_array[i];
   }
 
-  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1, 1},
+  {num_threads(), 1, 1});
 }
 
 template<
@@ -129,7 +130,8 @@
 static inline void launch_jitted_vectorized_kernel(DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
 at::opmath_type<f_inputs_type> scalar_val, std::tuple<Args...> extra_args) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
-  const int64_t grid = (N + block_work_size() - 1) / block_work_size();
+  // N is still int64_t for the computation, but it's always safe to cast the result to int
+  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
   const int vec_size = memory::jitted_can_vectorize_up_to<result_type, f_inputs_type, arity>(data);
 
   // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
@@ -196,8 +198,7 @@
       args[i + 3] = extra_args_array[i];
     }
 
-    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1, 1}, {num_threads(), 1, 1});
   } else {
     auto ic = TrivialOffsetCalculator<arity>();
     auto oc = TrivialOffsetCalculator<1>();
@@ -219,7 +220,7 @@
       // since 7 slots are already filled in `args`
       args[i + 7] = extra_args_array[i];
     }
-    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
+    at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1, 1}, {num_threads(), 1, 1});
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 }
diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index 5ee3757..57fa55f 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -9,6 +9,7 @@
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/thread_constants.h>
 #include <ATen/native/cuda/MemoryAccess.cuh>
+#include <ATen/OpMathType.h>
 #include <c10/macros/Macros.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <functional>
@@ -17,6 +18,9 @@
 #include <utility>
 #include <thrust/pair.h>
 
+#include <ATen/native/cuda/jit_utils.h>
+#include <iostream>
+
 namespace at { namespace native {
 
 using at::detail::Array;
@@ -272,6 +276,65 @@
   return func_wrapper_t<scalar_t, func_t> { op };
 }
 
+template <typename scalar_t, typename out_scalar_t=scalar_t>
+struct ReduceJitOp {
+//ReduceJitOp is almost like ReduceOp, but it doesn't have an ops functor that specifies the reduction operations
+//Maybe we can find a way to unify ReduceOp and ReduceJitOp
+  using InputCalculator = OffsetCalculator<1, uint32_t>;
+  using OutputCalculator = OffsetCalculator<2, uint32_t>;
+  //TODO for now arg_t is always opmath_t of the input, later we'll need to change it
+  using arg_t = at::opmath_type<scalar_t>;
+
+  static constexpr int input_vec_size = ReduceConfig::input_vec_size;
+  //TODO - ReduceJitOp will probably need to be changed for reductions that need full functor,
+  //not just wrapper
+  arg_t ident;
+  ReduceConfig config;
+  InputCalculator input_calc;
+  OutputCalculator output_calc;
+  const void* src;
+  const char* dst[2]; //it accepts at most two destinations
+  // acc_buf used for accumulation among sub Tensor Iterator when accumulation on
+  // output is not permissible
+  void* acc_buf;
+  // cta_buf used for accumulation between blocks during global reduction
+  void* cta_buf;
+  int* semaphores;
+  int64_t base_idx;
+  bool accumulate;
+  bool final_output;
+  int noutputs;
+
+  ReduceJitOp(
+      ReduceConfig config,
+      InputCalculator input_calc,
+      OutputCalculator output_calc,
+      const void* src,
+      char* dst0,
+      optional<char*> dst1,
+      void* acc_buf,
+      void* cta_buf,
+      int* semaphores,
+      arg_t ident,
+      int noutputs,
+      int64_t base_idx)
+      : ident(ident),
+        config(config),
+        input_calc(input_calc),
+        output_calc(output_calc),
+        src(src),
+        acc_buf(acc_buf),
+        cta_buf(cta_buf),
+        semaphores(semaphores),
+        base_idx(base_idx),
+        noutputs(noutputs) {
+    dst[0] = dst0;
+    if (dst1.has_value()) {
+      dst[1] = dst1.value();
+    }
+  }
+};
+
 template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t, int vt0=4>
 struct ReduceOp {
   using traits = function_traits<decltype(&ops_t::reduce)>;
@@ -284,8 +347,6 @@
     std::is_convertible<arg_t, out_scalar_t>::value
     && std::is_convertible<out_scalar_t, arg_t>::value;
 
-  static constexpr float acc_buffer_multiplier = (float)sizeof(arg_t) / sizeof(out_scalar_t);
-
   static constexpr int input_vec_size = ReduceConfig::input_vec_size;
 
   ops_t ops;
@@ -837,6 +898,47 @@
   }
 }
 
+template<char const *name, typename scalar_t, typename out_scalar_t,
+int vt0, typename R>
+static void launch_jitted_reduce_kernel(DeviceIndex idx, const ReduceConfig& config,
+R& reduction, const std::string& func) {
+  constexpr int max_threads = mnt_wrapper<scalar_t>::MAX_NUM_THREADS;
+  dim3 block = config.block();
+  dim3 grid = config.grid();
+
+  static std::mutex _jiterator_mutex;
+  static std::vector<std::array<at::cuda::jit::NvrtcFunction, 3>> fns(c10::cuda::device_count());
+  int shared_memory = config.shared_memory_size();
+  at::cuda::jit::NvrtcFunction* fn_ptr;
+  switch(config.output_vec_size) {
+  case 4:
+    fn_ptr = &fns[idx][0];
+    break;
+  case 2:
+    fn_ptr = &fns[idx][1];
+    break;
+  default:
+    fn_ptr = &fns[idx][2];
+  }
+  if (!fn_ptr->function) {
+    std::string f_inputs_type_str = at::cuda::jit::typeName<scalar_t>();
+    std::string accum_type_str = at::cuda::jit::typeName<at::opmath_type<scalar_t>>();
+    std::string result_type_str = at::cuda::jit::typeName<out_scalar_t>();
+    int max_threads_codegen = max_threads/config.output_vec_size;
+    auto code = at::cuda::jit::generate_reduction_code(1, func, name, vt0,
+                                               f_inputs_type_str, accum_type_str, result_type_str,
+                                               true, false, config.output_vec_size, max_threads_codegen);
+
+    *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_"+std::string(name));
+
+  }
+  constexpr int kernel_args = 1;
+  void* args[kernel_args];
+  args[0] = static_cast<void*>(&reduction);
+  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory);
+}
+
+
 class AccumulationBuffer {
  public:
   AccumulationBuffer() {}
@@ -874,7 +976,7 @@
 };
 
 template <typename scalar_t>
-int get_output_vec_size(TensorIterator &iter) {
+int get_output_vec_size(const TensorIterator &iter) {
   int vec_size = 4;
   auto update_vec_size = [&vec_size](uint64_t n) {
     while(n % vec_size != 0) {
@@ -898,61 +1000,8 @@
   return vec_size;
 }
 
-template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
-inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
-                              AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
-  AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
-
-  using traits = function_traits<decltype(&ops_t::reduce)>;
-  using arg_t = typename traits::template arg<0>::type;
-  static constexpr bool can_accumulate_in_output =
-    std::is_convertible<arg_t, out_scalar_t>::value;
-
-  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
-  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
-
-  // The acc_buf_ptr is a shared pointer. It is create at the first entrance and
-  // reused by all recursive function calls.
-  if (acc_buf_ptr == NULL) {
-    // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter
-    // when accumulation in output is not possible.
-    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
-      int64_t output_memory_size = iter.element_size(0);
-      for (int dim = 0; dim < iter.ndim(); dim++) {
-        output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
-      }
-      output_memory_size /= iter.element_size(0); //iter.strides is in bytes
-      owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t),
-                                                 sizeof(out_scalar_t),
-                                                 (char*) iter.data_ptr(0),
-                                                 output_memory_size * sizeof(arg_t)));
-    } else {
-      owned_buf_ptr.reset(new AccumulationBuffer());
-    }
-    acc_buf_ptr = owned_buf_ptr.get();
-  }
-
-  if (!can_use_32bit_indexing) {
-    for (auto& sub_iter : iter.with_32bit_indexing()) {
-      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
-
-      gpu_reduce_kernel<scalar_t, out_scalar_t, vt0>(sub_iter, ops, ident,
-          acc_buf_ptr, sub_iter_base_idx);
-    }
-    return;
-  }
-
-  const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
-  char* out_data = (char*)iter.data_ptr(0);
-  const auto noutputs = iter.noutputs();
-  optional<char*> out_data_extra;
-  if (noutputs > 1) {
-    out_data_extra = (char*)iter.data_ptr(1);
-  } else {
-    out_data_extra = nullopt;
-  }
-  char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
-
+template<typename arg_t, typename scalar_t, int vt0>
+ReduceConfig setReduceConfig(const TensorIterator& iter){
   // Start by assuming that each thread handles a single output and all
   // the inputs for that output.
   int64_t num_outputs = iter.num_output_elements();
@@ -1080,7 +1129,64 @@
       config.input_mult[2] = config.split_input(config.ctas_per_output);
     }
   }
+  return config;
+};
 
+template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
+inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
+                              AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
+  AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
+
+  using traits = function_traits<decltype(&ops_t::reduce)>;
+  using arg_t = typename traits::template arg<0>::type;
+  static constexpr bool can_accumulate_in_output =
+    std::is_convertible<arg_t, out_scalar_t>::value;
+
+  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
+  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
+  // The acc_buf_ptr is a shared pointer. It is created on the first call and
+  // reused by all recursive function calls.
+  if (acc_buf_ptr == NULL) {
+    // acc_buf_ptr holds a buffer used for accumulation among multiple sub_iters
+    // when accumulation in the output is not possible.
+    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
+      int64_t output_memory_size = iter.element_size(0);
+      for (int dim = 0; dim < iter.ndim(); dim++) {
+        output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
+      }
+      output_memory_size /= iter.element_size(0); //iter.strides is in bytes
+      owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t),
+                                                 sizeof(out_scalar_t),
+                                                 (char*) iter.data_ptr(0),
+                                                 output_memory_size * sizeof(arg_t)));
+    } else {
+      owned_buf_ptr.reset(new AccumulationBuffer());
+    }
+    acc_buf_ptr = owned_buf_ptr.get();
+  }
+
+  if (!can_use_32bit_indexing) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
+
+      gpu_reduce_kernel<scalar_t, out_scalar_t, vt0>(sub_iter, ops, ident,
+          acc_buf_ptr, sub_iter_base_idx);
+    }
+    return;
+  }
+
+  const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
+  char* out_data = (char*)iter.data_ptr(0);
+  const auto noutputs = iter.noutputs();
+  optional<char*> out_data_extra;
+  if (noutputs > 1) {
+    out_data_extra = (char*)iter.data_ptr(1);
+  } else {
+    out_data_extra = nullopt;
+  }
+  char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
+
+  ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
   at::DataPtr buffer;
   at::DataPtr semaphores;
   if (config.should_global_reduce()) {
@@ -1115,4 +1221,101 @@
   launch_reduce_kernel<mnt_wrapper<scalar_t>::MAX_NUM_THREADS>(config, reduce);
 }
 
+//TODO this is 100 lines of almost-copy-paste, because we have to have different template args for this function
+//try unifying with gpu_reduce_kernel
+template <char const* name, typename scalar_t, typename out_scalar_t, int vt0=4, typename ident_t=double>
+inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& func, ident_t ident=0,
+                              AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
+  AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
+
+  //TODO - this will be different for more complicated reductions, but for now reductions using
+  //func_wrapper all have arg_t = opmath
+  using arg_t = at::opmath_type<scalar_t>;
+  static constexpr bool can_accumulate_in_output =
+    std::is_convertible<arg_t, out_scalar_t>::value;
+  static_assert(can_accumulate_in_output == true, "unsupported arg_t for jitted reduction");
+
+  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
+  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
+
+  // The acc_buf_ptr is a shared pointer. It is created on the first call and
+  // reused by all recursive function calls.
+  if (acc_buf_ptr == NULL) {
+    // acc_buf_ptr holds a buffer used for accumulation among multiple sub_iters
+    // when accumulation in the output is not possible.
+    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
+      int64_t output_memory_size = iter.element_size(0);
+      for (int dim = 0; dim < iter.ndim(); dim++) {
+        output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
+      }
+      output_memory_size /= iter.element_size(0); //iter.strides is in bytes
+      owned_buf_ptr.reset(new AccumulationBuffer(sizeof(out_scalar_t), //TODO
+                                                 sizeof(out_scalar_t),
+                                                 (char*) iter.data_ptr(0),
+                                                 output_memory_size * sizeof(out_scalar_t))); //TODO
+    } else {
+      owned_buf_ptr.reset(new AccumulationBuffer());
+    }
+    acc_buf_ptr = owned_buf_ptr.get();
+  }
+
+  if (!can_use_32bit_indexing) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
+
+      jitted_gpu_reduce_kernel<name, scalar_t, out_scalar_t, vt0>(sub_iter, func, ident,
+          acc_buf_ptr, sub_iter_base_idx);
+    }
+    return;
+  }
+
+  //TODO - for now we support a single input, we may be able to relax this constraint
+  const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
+  char* out_data = (char*)iter.data_ptr(0);
+  const auto noutputs = iter.noutputs();
+  optional<char*> out_data_extra;
+  if (noutputs > 1) {
+    out_data_extra = (char*)iter.data_ptr(1);
+  } else {
+    out_data_extra = nullopt;
+  }
+  char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
+
+  ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
+
+  at::DataPtr buffer;
+  at::DataPtr semaphores;
+  if (config.should_global_reduce()) {
+    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+    buffer = allocator.allocate(config.global_memory_size());
+    semaphores = allocator.allocate(config.semaphore_size());
+
+    auto stream = at::cuda::getCurrentCUDAStream();
+    AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
+  }
+
+  AT_ASSERT(can_use_32bit_indexing);
+  auto output_calc = make_output_calculator<uint32_t>(iter);
+  auto input_calc = make_input_calculator<uint32_t>(iter);
+  auto reduce = ReduceJitOp<scalar_t, out_scalar_t>(
+      config,
+      input_calc,
+      output_calc,
+      in_data,
+      out_data,
+      out_data_extra,
+      acc_data,
+      buffer.get(),
+      (int*)semaphores.get(),
+      ident,
+      noutputs,
+      base_idx);
+  reduce.accumulate = iter.should_accumulate();
+  reduce.final_output = iter.is_final_output();
+
+  launch_jitted_reduce_kernel<name, scalar_t,
+  out_scalar_t, vt0>(iter.device().index(),
+  config, reduce, func);
+}
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
index bf81ed5..9faeae9 100644
--- a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
+++ b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
@@ -5,6 +5,7 @@
 #include <ATen/native/SharedReduceOps.h>
 #include <ATen/Dispatch.h>
 #include <ATen/native/ReduceOps.h>
+#include <ATen/jit_macros.h>
 
 namespace at { namespace native {
 
@@ -26,14 +27,28 @@
   }
 };
 
+const char op_name[] = "prod";
+
 template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
 struct prod_functor {
+  #if AT_USE_JITERATOR()
+  void operator()(TensorIterator& iter) {
+    std::string func = jiterator_stringify(
+    arg_t combine(arg_t a, arg_t b) {
+      return a * b;
+    }
+    );
+    jitted_gpu_reduce_kernel<op_name, scalar_t, out_t>(
+        iter, func, 1.);
+  }
+  #else
   void operator()(TensorIterator& iter) {
     gpu_reduce_kernel<scalar_t, out_t>(
         iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
           return a * b;
-        }), 1);
+        }), 1.);
   }
+  #endif
 };
 
 // Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp
index d88a39e..14d1a92 100644
--- a/aten/src/ATen/native/cuda/jit_utils.cpp
+++ b/aten/src/ATen/native/cuda/jit_utils.cpp
@@ -10,6 +10,7 @@
 #include <ATen/code_template.h>
 #include <ATen/native/cuda/jit_utils.h>
 #include <ATen/cuda/llvm_jit_strings.h>
+#include <ATen/native/cuda/reduction_template.cuh>
 
 #include <sstream>
 #include <fstream>
@@ -118,6 +119,11 @@
   Array() = default;
   Array(const Array&) = default;
   Array& operator=(const Array&) = default;
+  __device__ Array(T x) {
+    for (int i = 0; i < size; i++) {
+      data[i] = x;
+    }
+  }
   };
 
   ${half_string}
@@ -322,10 +328,7 @@
 
 )ESCAPE";
 
-const std::string jit_code_template = R"ESCAPE(
-
-  ${dynamic_casting_string}
-
+const std::string offset_calc_template = R"ESCAPE(
   template <typename T>
   struct DivMod {
   T div;
@@ -409,6 +412,14 @@
     ${index_type} strides_[25][NARGS];
   };
 
+
+)ESCAPE";
+
+const std::string jit_code_template = R"ESCAPE(
+
+  ${dynamic_casting_string}
+
+
   ${functor}
 
   // TODO: setup grid-stride loop
@@ -769,7 +780,7 @@
                   << ">(out[j], data[0], output_offsets[0]);\n";
     env.s("store_outputs", store_outputs.str());
 
-    static auto cuda_template = at::jit::CodeTemplate(jit_common_types + jit_code_template);
+    static auto cuda_template = at::jit::CodeTemplate(jit_common_types + offset_calc_template + jit_code_template);
     const auto code = cuda_template.format(env);
     return code;
   }
@@ -865,8 +876,69 @@
   }
 
   return _r_mkdir(base+dir);
+
 }
 
+std::string load_code_template(const std::string& path) {
+  std::ifstream ifs{path};
+  std::string s{
+    std::istreambuf_iterator<char>(ifs),
+    std::istreambuf_iterator<char>()};
+  return s;
+}
+
+std::string generate_reduction_code(
+    int nOutputs,
+    const std::string& func,
+    const std::string& name,
+    const int vt0,
+    const std::string& f_inputs_type,
+    const std::string& reduction_accum_type,
+    const std::string& result_type,
+    bool contiguous,
+    bool vectorized,
+    int vec_size,
+    int max_threads_codegen) {
+      at::jit::TemplateEnv env;
+      env.s("index_type", "unsigned int");
+      env.s("scalar_type", f_inputs_type);
+      env.s("result_type", result_type);
+      env.s("reduction_accum_type", reduction_accum_type);
+      env.s("vt0", std::to_string(vt0));
+      env.s("name", name);
+      env.s("max_threads_lb", std::to_string(max_threads_codegen));
+      // reductions don't support dynamic casting, so the only way to get nonstandard types
+      // is through input
+      if (f_inputs_type == "at::Half") {
+        env.s("half_string", jiterator_half_support_literal);
+      } else {
+        env.s("half_string", "");
+      }
+      if (f_inputs_type == "at::BFloat16") {
+        env.s("bfloat16_string", jiterator_bfloat16_support_literal);
+      } else {
+        env.s("bfloat16_string", "");
+      }
+      if (f_inputs_type == "std::complex<float>" ||
+          f_inputs_type == "std::complex<double>" ) {
+        env.s("traits_string", get_traits_string());
+        env.s("complex_body_string", get_complex_body_string());
+        env.s("complex_math_string", get_complex_math_string());
+        env.s("complex", std::to_string(1));
+      } else {
+        env.s("traits_string", "");
+        env.s("complex_body_string", "");
+        env.s("complex_math_string", "");
+        env.s("complex", std::to_string(0));
+      }
+      env.s("cmath_string", get_cmath_string());
+      env.s("functor", func);
+      env.s("output_vec_size", std::to_string(vec_size));
+      static auto cuda_template = at::jit::CodeTemplate(
+        jit_common_types + offset_calc_template + get_reduction_template());
+      const auto code = cuda_template.format(env);
+      return code;
+}
 
 // Acquires (possibly creating) the kernel cache directory
 c10::optional<std::string> get_cache_dir() {
@@ -946,9 +1018,7 @@
 NvrtcFunction jit_pwise_function(
     const std::string& code,
     const std::string& kernel_name) {
-
   initializeCudaContext();
-
   // Acquires CUDA and nvrtc versions and whether we're compiling to ptx or SASS
   const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
   int cuda_major = 0, cuda_minor = 0, nvrtc_major = 0, nvrtc_minor = 0;
@@ -1043,7 +1113,7 @@
     AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcGetProgramLog(program, log.data()));
     std::stringstream cu;
     cu << log.data();
-    throw std::runtime_error(cu.str() + code);
+    throw std::runtime_error(code + cu.str());
   }
 
   size_t ptx_size = 0;
@@ -1109,24 +1179,26 @@
 void launch_jitted_pwise_function(
     NvrtcFunction function,
     void* args[],
-    const int nBlocks,
-    const int kBlockSize) {
+    const dim3 nBlocks,
+    const dim3 kBlockSize,
+    const int smem) {
   initializeCudaContext();
   const auto& nvrtc = at::globalContext().getNVRTC();
   // Launches kernel on current stream
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_CUDA_DRIVER_CHECK(nvrtc.cuLaunchKernel(
     function.function,
-    nBlocks,
-    1,
-    1,
-    kBlockSize,
-    1,
-    1,
-    0,
+    nBlocks.x,
+    nBlocks.y,
+    nBlocks.z,
+    kBlockSize.x,
+    kBlockSize.y,
+    kBlockSize.z,
+    smem,
     stream,
     args,
     nullptr));
 }
 
+
 }}} // at::cuda::jit
diff --git a/aten/src/ATen/native/cuda/jit_utils.h b/aten/src/ATen/native/cuda/jit_utils.h
index 1f0f9c4..1ff6de7 100644
--- a/aten/src/ATen/native/cuda/jit_utils.h
+++ b/aten/src/ATen/native/cuda/jit_utils.h
@@ -32,6 +32,19 @@
     bool vectorized=false,
     int vec_size=0);
 
+std::string generate_reduction_code(
+    int nOutputs,
+    const std::string& func,
+    const std::string& name,
+    const int vt0,
+    const std::string& f_inputs_type,
+    const std::string& reduction_accum_type,
+    const std::string& result_type,
+    bool contiguous,
+    bool vectorized,
+    int vec_size,
+    int max_threads_codegen);
+
 NvrtcFunction jit_pwise_function(
     const std::string& code,
     const std::string& kernel_name);
@@ -39,8 +52,9 @@
 void launch_jitted_pwise_function(
     NvrtcFunction function,
     void* args[],
-    const int nBlocks,
-    const int kBlockSize);
+    const dim3 nBlocks,
+    const dim3 kBlockSize,
+    const int smem=0);
 
 template <typename T>
 struct delayed_false : std::false_type {
diff --git a/aten/src/ATen/native/cuda/reduction_template.cuh b/aten/src/ATen/native/cuda/reduction_template.cuh
new file mode 100644
index 0000000..4d9d559
--- /dev/null
+++ b/aten/src/ATen/native/cuda/reduction_template.cuh
@@ -0,0 +1,664 @@
+namespace at {
+namespace cuda {
+//windows doesn't like large string literals, so split in two
+const std::string reduction_template_0 = R"ESCAPE(
+  #define C10_HOST_DEVICE __host__ __device__
+  #define C10_DEVICE __device__
+
+  template <typename T>
+  __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+  {
+    return __shfl_down_sync(mask, value, delta, width);
+  }
+
+
+  #if ${complex}
+  template <typename T>
+  __device__ __forceinline__ std::complex<T> WARP_SHFL_DOWN(std::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+  {
+    return std::complex<T>(
+        __shfl_down_sync(mask, value.real(), delta, width),
+        __shfl_down_sync(mask, value.imag(), delta, width));
+  }
+  #endif
+
+  // aligned vector generates vectorized load/store on CUDA
+  template<typename scalar_t, int vec_size>
+  struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
+    scalar_t val[vec_size];
+  };
+
+
+  C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
+    // get GCD of num and denom using Euclid's algorithm.
+    // Can replace this with std::gcd if we ever support c++17.
+    size_t a = denominator;
+    size_t b = numerator;
+    while (b != 0) {
+        a %= b;
+        // swap(a,b)
+        size_t tmp = a;
+        a = b;
+        b = tmp;
+    }
+
+    // a is now the GCD
+    numerator /= a;
+    denominator /= a;
+  }
+
+
+
+
+  struct ReduceConfig {
+  //has to match host-side ReduceConfig in the eager code
+  static constexpr int BLOCK_X = 0;
+  static constexpr int BLOCK_Y = 1;
+  static constexpr int CTA = 2;
+
+  static constexpr int input_vec_size = 4;
+  int element_size_bytes;
+  int num_inputs;
+  int num_outputs;
+  int step_input = 1;
+  int step_output = 1;
+  int ctas_per_output = 1;
+  int input_mult[3] = {0, 0, 0};
+  int output_mult[2] = {0, 0};
+
+  int block_width;
+  int block_height;
+  int num_threads;
+
+  bool vectorize_input = false;
+  int output_vec_size = 1;
+
+  C10_HOST_DEVICE bool should_block_x_reduce() const {
+    return input_mult[BLOCK_X] != 0;
+  }
+
+  C10_HOST_DEVICE bool should_block_y_reduce() const {
+    return input_mult[BLOCK_Y] != 0;
+  }
+
+  C10_HOST_DEVICE bool should_global_reduce() const {
+    return input_mult[CTA] != 0;
+  }
+
+  C10_DEVICE bool should_store(int output_idx) const {
+    return output_idx < num_outputs &&
+      (!should_block_x_reduce() || threadIdx.x == 0) &&
+      (!should_block_y_reduce() || threadIdx.y == 0);
+  }
+
+  C10_DEVICE bool should_reduce_tail() const {
+    return (!should_block_y_reduce() || threadIdx.y == 0) &&
+      (!should_global_reduce() || blockIdx.y == 0);
+  }
+
+  C10_HOST_DEVICE int input_idx() const {
+    int lane = threadIdx.x;
+    int warp = threadIdx.y;
+    int cta2 = blockIdx.y;
+    return (lane * input_mult[BLOCK_X] +
+            warp * input_mult[BLOCK_Y] +
+            cta2 * input_mult[CTA]);
+  }
+
+  template <int output_vec_size>
+  C10_HOST_DEVICE int output_idx() const {
+    int lane = threadIdx.x;
+    int warp = threadIdx.y;
+    int cta1 = blockIdx.x;
+    return (lane * output_mult[BLOCK_X] +
+            warp * output_mult[BLOCK_Y] +
+            cta1 * step_output) * output_vec_size;
+  }
+
+  C10_DEVICE int shared_memory_offset(int offset) const {
+    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
+  }
+
+  C10_DEVICE int staging_memory_offset(int cta2) const {
+    int offset = cta2 + blockIdx.x * gridDim.y;
+    if (!should_block_x_reduce()) {
+      offset = threadIdx.x + offset * blockDim.x;
+    }
+    return offset;
+  }
+
+
+  };
+
+
+//TODO this will need to be different for more generic reduction functions
+namespace reducer {
+
+  using scalar_t = ${scalar_type};
+  using arg_t = ${reduction_accum_type};
+  using out_scalar_t = ${result_type};
+
+
+  inline __device__ ${functor}
+
+  inline __device__ out_scalar_t project(arg_t arg) {
+    return (out_scalar_t) arg;
+  }
+
+  inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
+    return WARP_SHFL_DOWN(arg, offset);
+  }
+
+  inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
+    return acc;
+  }
+
+  // wrap a normal reduction that ignores the index
+  inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) {
+     return combine(acc, val);
+  }
+}
+
+
+struct ReduceJitOp {
+  using scalar_t = ${scalar_type};
+  using arg_t = ${reduction_accum_type};
+  using out_scalar_t = ${result_type};
+
+  using InputCalculator = OffsetCalculator<1>;
+  using OutputCalculator = OffsetCalculator<2>;
+
+//   static constexpr bool can_accumulate_in_output =
+//     std::is_convertible<arg_t, out_scalar_t>::value
+//     && std::is_convertible<out_scalar_t, arg_t>::value;
+
+  static constexpr int input_vec_size = ReduceConfig::input_vec_size;
+
+  arg_t ident;
+  ReduceConfig config;
+  InputCalculator input_calc;
+  OutputCalculator output_calc;
+  const void* src;
+  const char* dst[2]; //it accepts at most two destinations
+  // acc_buf used for accumulation among sub Tensor Iterator when accumulation on
+  // output is not permissible
+  void* acc_buf;
+  // cta_buf used for accumulation between blocks during global reduction
+  void* cta_buf;
+  int* semaphores;
+  int64_t base_idx;
+  bool accumulate;
+  bool final_output;
+  int noutputs;
+
+
+  C10_DEVICE void run() const {
+    extern __shared__ char shared_memory[];
+    uint32_t output_idx = config.output_idx<${output_vec_size}>();
+    uint32_t input_idx = config.input_idx();
+    auto base_offsets1 = output_calc.get(output_idx)[1];
+
+    using arg_vec_t = Array<arg_t, ${output_vec_size}>;
+    arg_vec_t value;
+
+    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
+      const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
+
+      value = thread_reduce<${output_vec_size}>(input_slice);
+    }
+
+    if (config.should_block_y_reduce()) {
+      value = block_y_reduce<${output_vec_size}>(value, shared_memory);
+    }
+    if (config.should_block_x_reduce()) {
+      value = block_x_reduce<${output_vec_size}>(value, shared_memory);
+    }
+
+    using out_ptr_vec_t = Array<out_scalar_t*, ${output_vec_size}>;
+    using offset_vec_t = Array<uint32_t, ${output_vec_size}>;
+    offset_vec_t base_offsets;
+    out_ptr_vec_t out;
+
+    #pragma unroll
+    for (int i = 0; i < ${output_vec_size}; i++) {
+      base_offsets[i] = output_calc.get(output_idx + i)[0];
+      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
+    }
+
+    arg_vec_t* acc = nullptr;
+    if (acc_buf != nullptr) {
+      size_t numerator = sizeof(arg_t);
+      size_t denominator = sizeof(out_scalar_t);
+      reduce_fraction(numerator, denominator);
+      acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
+    }
+
+    if (config.should_global_reduce()) {
+      value = global_reduce<${output_vec_size}>(value, acc, shared_memory);
+    } else if (config.should_store(output_idx)) {
+      if (accumulate) {
+        #pragma unroll
+        for (int i = 0; i < ${output_vec_size}; i++) {
+          value[i] = reducer::translate_idx(value[i], base_idx);
+        }
+      }
+
+      if (acc == nullptr) {
+        if (accumulate) {
+          value = accumulate_in_output<${output_vec_size}>(out, value);
+        }
+        if (final_output) {
+          set_results_to_output<${output_vec_size}>(value, base_offsets);
+        } else {
+          #pragma unroll
+          for (int i = 0; i < ${output_vec_size}; i++) {
+            *(out[i]) = get_accumulated_output(out[i], value[i]);
+          }
+        }
+      } else {
+        if (accumulate) {
+          #pragma unroll
+          for (int i = 0; i < ${output_vec_size}; i++) {
+            value[i] = reducer::combine((*acc)[i], value[i]);
+          }
+        }
+        if (final_output) {
+          set_results_to_output<${output_vec_size}>(value, base_offsets);
+        } else {
+          *acc = value;
+        }
+      }
+    }
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
+    if (config.vectorize_input) {
+      assert(output_vec_size == 1);
+      // reduce at the header of input_slice where memory is not aligned,
+      // so that thread_reduce will have an aligned memory to work on.
+      return {input_vectorized_thread_reduce_impl(data)};
+    } else {
+      uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
+      bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
+      if (is_contiguous) {
+        return thread_reduce_impl<output_vec_size>(data, [](uint32_t idx) { return idx; });
+      } else if (input_calc.dims == 1) {
+        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return idx * element_stride; });
+      } else {
+        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
+      }
+    }
+  }
+
+  C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
+    uint32_t end = config.num_inputs;
+
+    // Handle the head of input slice where data is not aligned
+    arg_t value = ident;
+    constexpr int align_bytes = alignof(aligned_vector<scalar_t, input_vec_size>);
+    constexpr int align_elements = align_bytes / sizeof(scalar_t);
+    int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t);
+    if (shift > 0) {
+      data -= shift;
+      end += shift;
+      if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){
+        value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift);
+      }
+      end -= align_elements;
+      data += align_elements;
+      shift = align_elements - shift;
+    }
+
+    // Do the vectorized reduction
+    using load_t = aligned_vector<scalar_t, input_vec_size>;
+
+    uint32_t idx = config.input_idx();
+    const uint32_t stride = config.step_input;
+
+    // Multiple accumulators to remove dependency between unrolled loops.
+    arg_t value_list[input_vec_size];
+    value_list[0] = value;
+
+    #pragma unroll
+    for (int i = 1; i < input_vec_size; i++) {
+      value_list[i] = ident;
+    }
+
+    scalar_t values[input_vec_size];
+
+    load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);
+
+    while (idx * input_vec_size + input_vec_size - 1 < end) {
+      *values_vector = reinterpret_cast<const load_t*>(data)[idx];
+      #pragma unroll
+      for (uint32_t i = 0; i < input_vec_size; i++) {
+        value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
+      }
+      idx += stride;
+    }
+
+    // tail
+    uint32_t tail_start = end - end % input_vec_size;
+    if (config.should_reduce_tail()) {
+      int idx = tail_start + threadIdx.x;
+      if (idx < end) {
+        value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift);
+      }
+    }
+
+    // combine accumulators
+    #pragma unroll
+    for (int i = 1; i < input_vec_size; i++) {
+      value_list[0] = reducer::combine(value_list[0], value_list[i]);
+    }
+    return value_list[0];
+  }
+
+  template <int output_vec_size, typename offset_calc_t>
+  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
+    uint32_t idx = config.input_idx();
+    const uint32_t end = config.num_inputs;
+    const uint32_t stride = config.step_input;
+    const int vt0=${vt0};
+
+    using arg_vec_t = Array<arg_t, output_vec_size>;
+    using load_t = aligned_vector<scalar_t, output_vec_size>;
+    const load_t* data = reinterpret_cast<const load_t*>(data_);
+
+    // Multiple accumulators to remove dependency between unrolled loops.
+    arg_vec_t value_list[vt0];
+
+    #pragma unroll
+    for (int i = 0; i < vt0; i++) {
+      #pragma unroll
+      for (int j = 0; j < output_vec_size; j++) {
+        value_list[i][j] = ident;
+      }
+    }
+
+    load_t values[vt0];
+
+    while (idx + (vt0 - 1) * stride < end) {
+      #pragma unroll
+      for (uint32_t i = 0; i < vt0; i++) {
+        values[i] = data[calc(idx + i * stride) / output_vec_size];
+      }
+      #pragma unroll
+      for (uint32_t i = 0; i < vt0; i++) {
+        #pragma unroll
+        for (uint32_t j = 0; j < output_vec_size; j++) {
+          value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride);
+        }
+      }
+      idx += stride * vt0;
+    }
+
+    // tail
+    int idx_ = idx;
+    #pragma unroll
+    for (uint32_t i = 0; i < vt0; i++) {
+      if (idx >= end) {
+        break;
+      }
+      values[i] = data[calc(idx) / output_vec_size];
+      idx += stride;
+    }
+    idx = idx_;
+    #pragma unroll
+    for (uint32_t i = 0; i < vt0; i++) {
+      if (idx >= end) {
+        break;
+      }
+      #pragma unroll
+      for (uint32_t j = 0; j < output_vec_size; j++) {
+        value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx);
+      }
+      idx += stride;
+    }
+
+    // combine accumulators
+    #pragma unroll
+    for (int i = 1; i < vt0; i++) {
+      #pragma unroll
+      for (uint32_t j = 0; j < output_vec_size; j++) {
+        value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]);
+      }
+    }
+    return value_list[0];
+  }
+  template <int output_vec_size>
+  C10_DEVICE Array<arg_t, output_vec_size> block_x_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
+    using args_vec_t = Array<arg_t, output_vec_size>;
+    int dim_x = blockDim.x;
+    args_vec_t* shared = (args_vec_t*)shared_memory;
+    if (dim_x > warpSize) {
+      int address_base = threadIdx.x + threadIdx.y*blockDim.x;
+      shared[address_base] = value;
+      for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
+        __syncthreads();
+        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
+          args_vec_t other = shared[address_base + offset];
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = reducer::combine(value[i], other[i]);
+          }
+          shared[address_base] = value;
+        }
+      }
+      dim_x = warpSize;
+    }
+
+    __syncthreads();
+
+    for (int offset = 1; offset < dim_x; offset <<= 1) {
+      #pragma unroll
+      for (int i = 0; i < output_vec_size; i++) {
+        arg_t other = reducer::warp_shfl_down(value[i], offset);
+        value[i] = reducer::combine(value[i], other);
+      }
+    }
+    return value;
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE Array<arg_t, output_vec_size> block_y_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
+    using args_vec_t = Array<arg_t, output_vec_size>;
+    args_vec_t* shared = (args_vec_t*)shared_memory;
+    shared[config.shared_memory_offset(0)] = value;
+    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
+      __syncthreads();
+      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
+        args_vec_t other = shared[config.shared_memory_offset(offset)];
+        #pragma unroll
+        for (int i = 0; i < output_vec_size; i++) {
+          value[i] = reducer::combine(value[i], other[i]);
+        }
+        shared[config.shared_memory_offset(0)] = value;
+      }
+    }
+    return value;
+  }
+  )ESCAPE";
+
+  const std::string reduction_template_1 = R"ESCAPE(
+
+  C10_DEVICE bool mark_block_finished() const {
+    __shared__ bool is_last_block_done_shared;
+
+    __syncthreads();
+    if (threadIdx.x == 0 && threadIdx.y == 0) {
+      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
+      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
+    }
+
+    __syncthreads();
+
+    return is_last_block_done_shared;
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE Array<arg_t, output_vec_size> accumulate_in_output(
+    Array<out_scalar_t*, output_vec_size> out,
+    Array<arg_t, output_vec_size> value
+  ) const {
+    Array<arg_t, output_vec_size> ret;
+    #pragma unroll
+    for (int i = 0; i < output_vec_size; i++) {
+      ret[i] = reducer::combine(*(out[i]), value[i]);
+    }
+    return ret;
+  }
+
+
+  C10_DEVICE out_scalar_t get_accumulated_output(
+    out_scalar_t* out, arg_t value
+  ) const {
+    assert(!final_output);
+    return (out_scalar_t)value;
+  }
+
+  template<class T>
+  C10_DEVICE void set_results(const T x, const uint32_t base_offset) const {
+    assert(noutputs == 1);
+    auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
+    *res = x;
+  }
+
+//TODO - multi-output reduction - we won't be able to use thrust::pair
+//just explicitly specify typed output reads/writes
+//Currently implemented for max of two outputs
+//   template<class T1, class T2>
+//   C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
+//     if (noutputs >= 1) {
+//       auto res0 = (T1*)((char*)dst[0] + base_offset);
+//       *res0 = x.first;
+//     }
+//     if (noutputs >= 2) {
+//       // base offset is computed assuming element size being sizeof(T1), so we need to make a
+//       // correction to obtain the correct base offset
+//       auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
+//       *res1 = x.second;
+//     }
+//   }
+
+  template <int output_vec_size>
+  C10_DEVICE void set_results_to_output(Array<arg_t, output_vec_size> value, Array<uint32_t, output_vec_size> base_offset) const {
+    assert(final_output);
+    #pragma unroll
+    for (int i = 0; i < output_vec_size; i++) {
+      set_results(reducer::project(value[i]), base_offset[i]);
+    }
+  }
+
+  template <int output_vec_size>
+  C10_DEVICE Array<arg_t, output_vec_size> global_reduce(Array<arg_t, output_vec_size> value, Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
+    using arg_vec_t = Array<arg_t, output_vec_size>;
+    using out_ptr_vec_t = Array<out_scalar_t*, output_vec_size>;
+    using offset_vec_t = Array<uint32_t, output_vec_size>;
+
+    arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
+    uint32_t output_idx = config.output_idx<output_vec_size>();
+    offset_vec_t base_offsets;
+    out_ptr_vec_t out;
+
+    #pragma unroll
+    for (int i = 0; i < output_vec_size; i++) {
+      base_offsets[i] = output_calc.get(output_idx + i)[0];
+      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
+    }
+
+    bool should_store = config.should_store(output_idx);
+    if (should_store) {
+      uint32_t offset = config.staging_memory_offset(blockIdx.y);
+      reduce_buffer[offset] = value;
+    }
+
+    __threadfence(); // make sure writes are globally visible
+    __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
+    bool is_last_block_done = mark_block_finished();
+
+    if (is_last_block_done) {
+      value = ident;
+      if (config.should_block_x_reduce()) {
+        uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
+        uint32_t step = blockDim.x * blockDim.y;
+        for (; input_offset < config.ctas_per_output; input_offset += step) {
+          uint32_t idx = config.staging_memory_offset(input_offset);
+          arg_vec_t next = reduce_buffer[idx];
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = reducer::combine(value[i], next[i]);
+          }
+        }
+      } else {
+        uint32_t input_offset = threadIdx.y;
+        uint32_t step = blockDim.y;
+        for (; input_offset < config.ctas_per_output; input_offset += step) {
+          uint32_t idx = config.staging_memory_offset(input_offset);
+          arg_vec_t next = reduce_buffer[idx];
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = reducer::combine(value[i], next[i]);
+          }
+        }
+      }
+      value = block_y_reduce(value, shared_memory);
+      if (config.should_block_x_reduce()) {
+        value = block_x_reduce<output_vec_size>(value, shared_memory);
+      }
+      if (should_store) {
+        if (accumulate) {
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = reducer::translate_idx(value[i], base_idx);
+          }
+        }
+
+        if (acc == nullptr) {
+          if (accumulate) {
+            value = accumulate_in_output<output_vec_size>(out, value);
+          }
+          if (final_output) {
+            set_results_to_output<output_vec_size>(value, base_offsets);
+          } else {
+            #pragma unroll
+            for (int i = 0; i < output_vec_size; i++) {
+              *(out[i]) = get_accumulated_output(out[i], value[i]);
+            }
+          }
+        } else {
+          if (accumulate) {
+            #pragma unroll
+            for (int i = 0; i < output_vec_size; i++) {
+              value[i] = reducer::combine((*acc)[i], value[i]);
+            }
+          }
+          if (final_output) {
+            set_results_to_output<output_vec_size>(value, base_offsets);
+          } else {
+            *acc = value;
+          }
+        }
+      }
+    }
+
+    return value;
+  }
+};
+
+extern "C"
+__launch_bounds__(${max_threads_lb}, 4)
+__global__ void reduction_${name}_kernel(ReduceJitOp r){
+  r.run();
+}
+)ESCAPE";
+
+const std::string reduction_template = reduction_template_0 + reduction_template_1;
+
+
+const std::string &get_reduction_template() {
+  return reduction_template;
+}
+
+}}