WIP Jiterator reduction
This PR enables jit-compiled reductions and moves `prod` to be jit-compiled.
Currently, only reductions that can use `func_wrapper` for automatic implementation of `reduce/project/translate_idx` opes are supported, there are a few TODOs for support of more complex reductions such as norms and max, that typically require full-fledged ReduceOps functor. Similarly, only reductions with a single input are supported.
Number of inputs is hardcoded to 1, which is true for our current reductions, but can be relaxed in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74446
Approved by: https://github.com/mruberry
diff --git a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
index b5b1cd5..61417b9 100644
--- a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
@@ -71,7 +71,8 @@
std::tuple<Args...> extra_args) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
- const int64_t grid = (N + block_work_size() - 1) / block_work_size();
+ //casting result to int is always safe, intermediate is int64 and won't overflow
+ const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
static std::mutex _jiterator_mutex;
static std::vector<at::cuda::jit::NvrtcFunction> fns(c10::cuda::device_count());
@@ -115,8 +116,8 @@
args[i + 7] = extra_args_array[i];
}
- at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
- C10_CUDA_KERNEL_LAUNCH_CHECK();
+ at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1, 1},
+ {num_threads(), 1, 1});
}
template<
@@ -129,7 +130,8 @@
static inline void launch_jitted_vectorized_kernel(DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
at::opmath_type<f_inputs_type> scalar_val, std::tuple<Args...> extra_args) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
- const int64_t grid = (N + block_work_size() - 1) / block_work_size();
+ // N is still int64_t for the computation, but it's always safe to cast result to int
+ const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
const int vec_size = memory::jitted_can_vectorize_up_to<result_type, f_inputs_type, arity>(data);
// Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
@@ -196,8 +198,7 @@
args[i + 3] = extra_args_array[i];
}
- at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
- C10_CUDA_KERNEL_LAUNCH_CHECK();
+ at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1, 1}, {num_threads(), 1, 1});
} else {
auto ic = TrivialOffsetCalculator<arity>();
auto oc = TrivialOffsetCalculator<1>();
@@ -219,7 +220,7 @@
// since 7 slots are already filled in `args`
args[i + 7] = extra_args_array[i];
}
- at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
+ at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1, 1}, {num_threads(), 1, 1});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index 5ee3757..57fa55f 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -9,6 +9,7 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/thread_constants.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
+#include <ATen/OpMathType.h>
#include <c10/macros/Macros.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <functional>
@@ -17,6 +18,9 @@
#include <utility>
#include <thrust/pair.h>
+#include <ATen/native/cuda/jit_utils.h>
+#include <iostream>
+
namespace at { namespace native {
using at::detail::Array;
@@ -272,6 +276,65 @@
return func_wrapper_t<scalar_t, func_t> { op };
}
+template <typename scalar_t, typename out_scalar_t=scalar_t>
+struct ReduceJitOp {
+//ReduceJitOp is almost like ReduceOp, but it doesn't have ops functor that specifies reduction operations
+//Maybe we can find a way to unify ReduceOp and ReduceJitOp
+ using InputCalculator = OffsetCalculator<1, uint32_t>;
+ using OutputCalculator = OffsetCalculator<2, uint32_t>;
+ //TODO for now arg_t is always opmath_t of the input, later we'll need to change it
+ using arg_t = at::opmath_type<scalar_t>;
+
+ static constexpr int input_vec_size = ReduceConfig::input_vec_size;
+ //TODO - ReduceJitOp will probably need to be changed for reductions that need full functor,
+ //not just wrapper
+ arg_t ident;
+ ReduceConfig config;
+ InputCalculator input_calc;
+ OutputCalculator output_calc;
+ const void* src;
+ const char* dst[2]; //it accepts at most two destinations
+ // acc_buf used for accumulation among sub Tensor Iterator when accumulation on
+ // output is not permissible
+ void* acc_buf;
+ // cta_buf used for accumulation between blocks during global reduction
+ void* cta_buf;
+ int* semaphores;
+ int64_t base_idx;
+ bool accumulate;
+ bool final_output;
+ int noutputs;
+
+ ReduceJitOp(
+ ReduceConfig config,
+ InputCalculator input_calc,
+ OutputCalculator output_calc,
+ const void* src,
+ char* dst0,
+ optional<char*> dst1,
+ void* acc_buf,
+ void* cta_buf,
+ int* semaphores,
+ arg_t ident,
+ int noutputs,
+ int64_t base_idx)
+ : ident(ident),
+ config(config),
+ input_calc(input_calc),
+ output_calc(output_calc),
+ src(src),
+ acc_buf(acc_buf),
+ cta_buf(cta_buf),
+ semaphores(semaphores),
+ base_idx(base_idx),
+ noutputs(noutputs) {
+ dst[0] = dst0;
+ if (dst1.has_value()) {
+ dst[1] = dst1.value();
+ }
+ }
+};
+
template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t, int vt0=4>
struct ReduceOp {
using traits = function_traits<decltype(&ops_t::reduce)>;
@@ -284,8 +347,6 @@
std::is_convertible<arg_t, out_scalar_t>::value
&& std::is_convertible<out_scalar_t, arg_t>::value;
- static constexpr float acc_buffer_multiplier = (float)sizeof(arg_t) / sizeof(out_scalar_t);
-
static constexpr int input_vec_size = ReduceConfig::input_vec_size;
ops_t ops;
@@ -837,6 +898,47 @@
}
}
+template<char const *name, typename scalar_t, typename out_scalar_t,
+int vt0, typename R>
+static void launch_jitted_reduce_kernel(DeviceIndex idx, const ReduceConfig& config,
+R& reduction, const std::string& func) {
+ constexpr int max_threads = mnt_wrapper<scalar_t>::MAX_NUM_THREADS;
+ dim3 block = config.block();
+ dim3 grid = config.grid();
+
+ static std::mutex _jiterator_mutex;
+ static std::vector<std::array<at::cuda::jit::NvrtcFunction, 3>> fns(c10::cuda::device_count());
+ int shared_memory = config.shared_memory_size();
+ at::cuda::jit::NvrtcFunction* fn_ptr;
+ switch(config.output_vec_size) {
+ case 4:
+ fn_ptr = &fns[idx][0];
+ break;
+ case 2:
+ fn_ptr = &fns[idx][1];
+ break;
+ default:
+ fn_ptr = &fns[idx][2];
+ }
+ if (!fn_ptr->function) {
+ std::string f_inputs_type_str = at::cuda::jit::typeName<scalar_t>();
+ std::string accum_type_str = at::cuda::jit::typeName<at::opmath_type<scalar_t>>();
+ std::string result_type_str = at::cuda::jit::typeName<out_scalar_t>();
+ int max_threads_codegen = max_threads/config.output_vec_size;
+ auto code = at::cuda::jit::generate_reduction_code(1, func, name, vt0,
+ f_inputs_type_str, accum_type_str, result_type_str,
+ true, false, config.output_vec_size, max_threads_codegen);
+
+ *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_"+std::string(name));
+
+ }
+ constexpr int kernel_args = 1;
+ void* args[kernel_args];
+ args[0] = static_cast<void*>(&reduction);
+ at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory);
+}
+
+
class AccumulationBuffer {
public:
AccumulationBuffer() {}
@@ -874,7 +976,7 @@
};
template <typename scalar_t>
-int get_output_vec_size(TensorIterator &iter) {
+int get_output_vec_size(const TensorIterator &iter) {
int vec_size = 4;
auto update_vec_size = [&vec_size](uint64_t n) {
while(n % vec_size != 0) {
@@ -898,61 +1000,8 @@
return vec_size;
}
-template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
-inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
- AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
- AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
-
- using traits = function_traits<decltype(&ops_t::reduce)>;
- using arg_t = typename traits::template arg<0>::type;
- static constexpr bool can_accumulate_in_output =
- std::is_convertible<arg_t, out_scalar_t>::value;
-
- bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
- std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
-
- // The acc_buf_ptr is a shared pointer. It is create at the first entrance and
- // reused by all recursive function calls.
- if (acc_buf_ptr == NULL) {
- // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter
- // when accumulation in output is not possible.
- if (!can_accumulate_in_output && !can_use_32bit_indexing) {
- int64_t output_memory_size = iter.element_size(0);
- for (int dim = 0; dim < iter.ndim(); dim++) {
- output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
- }
- output_memory_size /= iter.element_size(0); //iter.strides is in bytes
- owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t),
- sizeof(out_scalar_t),
- (char*) iter.data_ptr(0),
- output_memory_size * sizeof(arg_t)));
- } else {
- owned_buf_ptr.reset(new AccumulationBuffer());
- }
- acc_buf_ptr = owned_buf_ptr.get();
- }
-
- if (!can_use_32bit_indexing) {
- for (auto& sub_iter : iter.with_32bit_indexing()) {
- int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
-
- gpu_reduce_kernel<scalar_t, out_scalar_t, vt0>(sub_iter, ops, ident,
- acc_buf_ptr, sub_iter_base_idx);
- }
- return;
- }
-
- const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
- char* out_data = (char*)iter.data_ptr(0);
- const auto noutputs = iter.noutputs();
- optional<char*> out_data_extra;
- if (noutputs > 1) {
- out_data_extra = (char*)iter.data_ptr(1);
- } else {
- out_data_extra = nullopt;
- }
- char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
-
+template<typename arg_t, typename scalar_t, int vt0>
+ReduceConfig setReduceConfig(const TensorIterator& iter){
// Start by assuming that each thread handles a single output and all
// the inputs for that output.
int64_t num_outputs = iter.num_output_elements();
@@ -1080,7 +1129,64 @@
config.input_mult[2] = config.split_input(config.ctas_per_output);
}
}
+ return config;
+};
+template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
+inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
+ AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
+ AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
+
+ using traits = function_traits<decltype(&ops_t::reduce)>;
+ using arg_t = typename traits::template arg<0>::type;
+ static constexpr bool can_accumulate_in_output =
+ std::is_convertible<arg_t, out_scalar_t>::value;
+
+ bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
+ std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
+ // The acc_buf_ptr is a shared pointer. It is create at the first entrance and
+ // reused by all recursive function calls.
+ if (acc_buf_ptr == NULL) {
+ // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter
+ // when accumulation in output is not possible.
+ if (!can_accumulate_in_output && !can_use_32bit_indexing) {
+ int64_t output_memory_size = iter.element_size(0);
+ for (int dim = 0; dim < iter.ndim(); dim++) {
+ output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
+ }
+ output_memory_size /= iter.element_size(0); //iter.strides is in bytes
+ owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t),
+ sizeof(out_scalar_t),
+ (char*) iter.data_ptr(0),
+ output_memory_size * sizeof(arg_t)));
+ } else {
+ owned_buf_ptr.reset(new AccumulationBuffer());
+ }
+ acc_buf_ptr = owned_buf_ptr.get();
+ }
+
+ if (!can_use_32bit_indexing) {
+ for (auto& sub_iter : iter.with_32bit_indexing()) {
+ int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
+
+ gpu_reduce_kernel<scalar_t, out_scalar_t, vt0>(sub_iter, ops, ident,
+ acc_buf_ptr, sub_iter_base_idx);
+ }
+ return;
+ }
+
+ const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
+ char* out_data = (char*)iter.data_ptr(0);
+ const auto noutputs = iter.noutputs();
+ optional<char*> out_data_extra;
+ if (noutputs > 1) {
+ out_data_extra = (char*)iter.data_ptr(1);
+ } else {
+ out_data_extra = nullopt;
+ }
+ char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
+
+ ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
at::DataPtr buffer;
at::DataPtr semaphores;
if (config.should_global_reduce()) {
@@ -1115,4 +1221,101 @@
launch_reduce_kernel<mnt_wrapper<scalar_t>::MAX_NUM_THREADS>(config, reduce);
}
+//TODO this is 100 lines of almost-copy-paste, because we have to have different template args for this function
+//try unifying with gpu_reduce_kernel
+template <char const* name, typename scalar_t, typename out_scalar_t, int vt0=4, typename ident_t=double>
+inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& func, ident_t ident=0,
+ AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
+ AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);
+
+ //TODO - this will be different for more complicated reductions, but for now reductions using
+ //func_wrapper all have arg_t = opmath
+ using arg_t = at::opmath_type<scalar_t>;
+ static constexpr bool can_accumulate_in_output =
+ std::is_convertible<arg_t, out_scalar_t>::value;
+ static_assert(can_accumulate_in_output == true, "unsupported arg_t for jitted reduction");
+
+ bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
+ std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
+
+ // The acc_buf_ptr is a shared pointer. It is create at the first entrance and
+ // reused by all recursive function calls.
+ if (acc_buf_ptr == NULL) {
+ // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter
+ // when accumulation in output is not possible.
+ if (!can_accumulate_in_output && !can_use_32bit_indexing) {
+ int64_t output_memory_size = iter.element_size(0);
+ for (int dim = 0; dim < iter.ndim(); dim++) {
+ output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
+ }
+ output_memory_size /= iter.element_size(0); //iter.strides is in bytes
+ owned_buf_ptr.reset(new AccumulationBuffer(sizeof(out_scalar_t), //TODO
+ sizeof(out_scalar_t),
+ (char*) iter.data_ptr(0),
+ output_memory_size * sizeof(out_scalar_t))); //TODO
+ } else {
+ owned_buf_ptr.reset(new AccumulationBuffer());
+ }
+ acc_buf_ptr = owned_buf_ptr.get();
+ }
+
+ if (!can_use_32bit_indexing) {
+ for (auto& sub_iter : iter.with_32bit_indexing()) {
+ int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];
+
+ jitted_gpu_reduce_kernel<name, scalar_t, out_scalar_t, vt0>(sub_iter, func, ident,
+ acc_buf_ptr, sub_iter_base_idx);
+ }
+ return;
+ }
+
+ //TODO - for now we support a single input, we may be able to relax this constraint
+ const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
+ char* out_data = (char*)iter.data_ptr(0);
+ const auto noutputs = iter.noutputs();
+ optional<char*> out_data_extra;
+ if (noutputs > 1) {
+ out_data_extra = (char*)iter.data_ptr(1);
+ } else {
+ out_data_extra = nullopt;
+ }
+ char* acc_data = acc_buf_ptr->get_acc_slice(out_data);
+
+ ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
+
+ at::DataPtr buffer;
+ at::DataPtr semaphores;
+ if (config.should_global_reduce()) {
+ auto& allocator = *c10::cuda::CUDACachingAllocator::get();
+ buffer = allocator.allocate(config.global_memory_size());
+ semaphores = allocator.allocate(config.semaphore_size());
+
+ auto stream = at::cuda::getCurrentCUDAStream();
+ AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
+ }
+
+ AT_ASSERT(can_use_32bit_indexing);
+ auto output_calc = make_output_calculator<uint32_t>(iter);
+ auto input_calc = make_input_calculator<uint32_t>(iter);
+ auto reduce = ReduceJitOp<scalar_t, out_scalar_t>(
+ config,
+ input_calc,
+ output_calc,
+ in_data,
+ out_data,
+ out_data_extra,
+ acc_data,
+ buffer.get(),
+ (int*)semaphores.get(),
+ ident,
+ noutputs,
+ base_idx);
+ reduce.accumulate = iter.should_accumulate();
+ reduce.final_output = iter.is_final_output();
+
+ launch_jitted_reduce_kernel<name, scalar_t,
+ out_scalar_t, vt0>(iter.device().index(),
+ config, reduce, func);
+}
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
index bf81ed5..9faeae9 100644
--- a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
+++ b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
@@ -5,6 +5,7 @@
#include <ATen/native/SharedReduceOps.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ReduceOps.h>
+#include <ATen/jit_macros.h>
namespace at { namespace native {
@@ -26,14 +27,28 @@
}
};
+const char op_name[] = "prod";
+
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct prod_functor {
+ #if AT_USE_JITERATOR()
+ void operator()(TensorIterator& iter) {
+ std::string func = jiterator_stringify(
+ arg_t combine(arg_t a, arg_t b) {
+ return a * b;
+ }
+ );
+ jitted_gpu_reduce_kernel<op_name, scalar_t, out_t>(
+ iter, func, 1.);
+ }
+ #else
void operator()(TensorIterator& iter) {
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a * b;
- }), 1);
+ }), 1.);
}
+ #endif
};
// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp
index d88a39e..14d1a92 100644
--- a/aten/src/ATen/native/cuda/jit_utils.cpp
+++ b/aten/src/ATen/native/cuda/jit_utils.cpp
@@ -10,6 +10,7 @@
#include <ATen/code_template.h>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/cuda/llvm_jit_strings.h>
+#include <ATen/native/cuda/reduction_template.cuh>
#include <sstream>
#include <fstream>
@@ -118,6 +119,11 @@
Array() = default;
Array(const Array&) = default;
Array& operator=(const Array&) = default;
+ __device__ Array(T x) {
+ for (int i = 0; i < size; i++) {
+ data[i] = x;
+ }
+ }
};
${half_string}
@@ -322,10 +328,7 @@
)ESCAPE";
-const std::string jit_code_template = R"ESCAPE(
-
- ${dynamic_casting_string}
-
+const std::string offset_calc_template = R"ESCAPE(
template <typename T>
struct DivMod {
T div;
@@ -409,6 +412,14 @@
${index_type} strides_[25][NARGS];
};
+
+)ESCAPE";
+
+const std::string jit_code_template = R"ESCAPE(
+
+ ${dynamic_casting_string}
+
+
${functor}
// TODO: setup grid-stride loop
@@ -769,7 +780,7 @@
<< ">(out[j], data[0], output_offsets[0]);\n";
env.s("store_outputs", store_outputs.str());
- static auto cuda_template = at::jit::CodeTemplate(jit_common_types + jit_code_template);
+ static auto cuda_template = at::jit::CodeTemplate(jit_common_types + offset_calc_template + jit_code_template);
const auto code = cuda_template.format(env);
return code;
}
@@ -865,8 +876,69 @@
}
return _r_mkdir(base+dir);
+
}
+std::string load_code_template(const std::string& path) {
+ std::ifstream ifs{path};
+ std::string s{
+ std::istreambuf_iterator<char>(ifs),
+ std::istreambuf_iterator<char>()};
+ return s;
+}
+
+std::string generate_reduction_code(
+ int nOutputs,
+ const std::string& func,
+ const std::string& name,
+ const int vt0,
+ const std::string& f_inputs_type,
+ const std::string& reduction_accum_type,
+ const std::string& result_type,
+ bool contiguous,
+ bool vectorized,
+ int vec_size,
+ int max_threads_codegen) {
+ at::jit::TemplateEnv env;
+ env.s("index_type", "unsigned int");
+ env.s("scalar_type", f_inputs_type);
+ env.s("result_type", result_type);
+ env.s("reduction_accum_type", reduction_accum_type);
+ env.s("vt0", std::to_string(vt0));
+ env.s("name", name);
+ env.s("max_threads_lb", std::to_string(max_threads_codegen));
+ // reductions don't support dynamic casting, so the only way to get nonstandard types
+ // is through input
+ if (f_inputs_type == "at::Half") {
+ env.s("half_string", jiterator_half_support_literal);
+ } else {
+ env.s("half_string", "");
+ }
+ if (f_inputs_type == "at::BFloat16") {
+ env.s("bfloat16_string", jiterator_bfloat16_support_literal);
+ } else {
+ env.s("bfloat16_string", "");
+ }
+ if (f_inputs_type == "std::complex<float>" ||
+ f_inputs_type == "std::complex<double>" ) {
+ env.s("traits_string", get_traits_string());
+ env.s("complex_body_string", get_complex_body_string());
+ env.s("complex_math_string", get_complex_math_string());
+ env.s("complex", std::to_string(1));
+ } else {
+ env.s("traits_string", "");
+ env.s("complex_body_string", "");
+ env.s("complex_math_string", "");
+ env.s("complex", std::to_string(0));
+ }
+ env.s("cmath_string", get_cmath_string());
+ env.s("functor", func);
+ env.s("output_vec_size", std::to_string(vec_size));
+ static auto cuda_template = at::jit::CodeTemplate(
+ jit_common_types + offset_calc_template + get_reduction_template());
+ const auto code = cuda_template.format(env);
+ return code;
+}
// Acquires (possibly creating) the kernel cache directory
c10::optional<std::string> get_cache_dir() {
@@ -946,9 +1018,7 @@
NvrtcFunction jit_pwise_function(
const std::string& code,
const std::string& kernel_name) {
-
initializeCudaContext();
-
// Acquires CUDA and nvrtc versions and whether we're compiling to ptx or SASS
const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
int cuda_major = 0, cuda_minor = 0, nvrtc_major = 0, nvrtc_minor = 0;
@@ -1043,7 +1113,7 @@
AT_CUDA_NVRTC_CHECK(nvrtc.nvrtcGetProgramLog(program, log.data()));
std::stringstream cu;
cu << log.data();
- throw std::runtime_error(cu.str() + code);
+ throw std::runtime_error(code + cu.str());
}
size_t ptx_size = 0;
@@ -1109,24 +1179,26 @@
void launch_jitted_pwise_function(
NvrtcFunction function,
void* args[],
- const int nBlocks,
- const int kBlockSize) {
+ const dim3 nBlocks,
+ const dim3 kBlockSize,
+ const int smem) {
initializeCudaContext();
const auto& nvrtc = at::globalContext().getNVRTC();
// Launches kernel on current stream
auto stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_DRIVER_CHECK(nvrtc.cuLaunchKernel(
function.function,
- nBlocks,
- 1,
- 1,
- kBlockSize,
- 1,
- 1,
- 0,
+ nBlocks.x,
+ nBlocks.y,
+ nBlocks.z,
+ kBlockSize.x,
+ kBlockSize.y,
+ kBlockSize.z,
+ smem,
stream,
args,
nullptr));
}
+
}}} // at::cuda::jit
diff --git a/aten/src/ATen/native/cuda/jit_utils.h b/aten/src/ATen/native/cuda/jit_utils.h
index 1f0f9c4..1ff6de7 100644
--- a/aten/src/ATen/native/cuda/jit_utils.h
+++ b/aten/src/ATen/native/cuda/jit_utils.h
@@ -32,6 +32,19 @@
bool vectorized=false,
int vec_size=0);
+std::string generate_reduction_code(
+ int nOutputs,
+ const std::string& func,
+ const std::string& name,
+ const int vt0,
+ const std::string& f_inputs_type,
+ const std::string& reduction_accum_type,
+ const std::string& result_type,
+ bool contiguous,
+ bool vectorized,
+ int vec_size,
+ int max_threads_codegen);
+
NvrtcFunction jit_pwise_function(
const std::string& code,
const std::string& kernel_name);
@@ -39,8 +52,9 @@
void launch_jitted_pwise_function(
NvrtcFunction function,
void* args[],
- const int nBlocks,
- const int kBlockSize);
+ const dim3 nBlocks,
+ const dim3 kBlockSize,
+ const int smem=0);
template <typename T>
struct delayed_false : std::false_type {
diff --git a/aten/src/ATen/native/cuda/reduction_template.cuh b/aten/src/ATen/native/cuda/reduction_template.cuh
new file mode 100644
index 0000000..4d9d559
--- /dev/null
+++ b/aten/src/ATen/native/cuda/reduction_template.cuh
@@ -0,0 +1,664 @@
+namespace at {
+namespace cuda {
+//windows doesn't like large string literals, so split in two
+const std::string reduction_template_0 = R"ESCAPE(
+ #define C10_HOST_DEVICE __host__ __device__
+ #define C10_DEVICE __device__
+
+ template <typename T>
+ __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+ {
+ return __shfl_down_sync(mask, value, delta, width);
+ }
+
+
+ #if ${complex}
+ template <typename T>
+ __device__ __forceinline__ std::complex<T> WARP_SHFL_DOWN(std::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
+ {
+ return std::complex<T>(
+ __shfl_down_sync(mask, value.real(), delta, width),
+ __shfl_down_sync(mask, value.imag(), delta, width));
+ }
+ #endif
+
+ // aligned vector generates vectorized load/store on CUDA
+ template<typename scalar_t, int vec_size>
+ struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
+ scalar_t val[vec_size];
+ };
+
+
+ C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
+ // get GCD of num and denom using Euclid's algorithm.
+ // Can replace this with std::gcd if we ever support c++17.
+ size_t a = denominator;
+ size_t b = numerator;
+ while (b != 0) {
+ a %= b;
+ // swap(a,b)
+ size_t tmp = a;
+ a = b;
+ b = tmp;
+ }
+
+ // a is now the GCD
+ numerator /= a;
+ denominator /= a;
+ }
+
+
+
+
+ struct ReduceConfig {
+ //has to match host-side ReduceConfig in the eager code
+ static constexpr int BLOCK_X = 0;
+ static constexpr int BLOCK_Y = 1;
+ static constexpr int CTA = 2;
+
+ static constexpr int input_vec_size = 4;
+ int element_size_bytes;
+ int num_inputs;
+ int num_outputs;
+ int step_input = 1;
+ int step_output = 1;
+ int ctas_per_output = 1;
+ int input_mult[3] = {0, 0, 0};
+ int output_mult[2] = {0, 0};
+
+ int block_width;
+ int block_height;
+ int num_threads;
+
+ bool vectorize_input = false;
+ int output_vec_size = 1;
+
+ C10_HOST_DEVICE bool should_block_x_reduce() const {
+ return input_mult[BLOCK_X] != 0;
+ }
+
+ C10_HOST_DEVICE bool should_block_y_reduce() const {
+ return input_mult[BLOCK_Y] != 0;
+ }
+
+ C10_HOST_DEVICE bool should_global_reduce() const {
+ return input_mult[CTA] != 0;
+ }
+
+ C10_DEVICE bool should_store(int output_idx) const {
+ return output_idx < num_outputs &&
+ (!should_block_x_reduce() || threadIdx.x == 0) &&
+ (!should_block_y_reduce() || threadIdx.y == 0);
+ }
+
+ C10_DEVICE bool should_reduce_tail() const {
+ return (!should_block_y_reduce() || threadIdx.y == 0) &&
+ (!should_global_reduce() || blockIdx.y == 0);
+ }
+
+ C10_HOST_DEVICE int input_idx() const {
+ int lane = threadIdx.x;
+ int warp = threadIdx.y;
+ int cta2 = blockIdx.y;
+ return (lane * input_mult[BLOCK_X] +
+ warp * input_mult[BLOCK_Y] +
+ cta2 * input_mult[CTA]);
+ }
+
+ template <int output_vec_size>
+ C10_HOST_DEVICE int output_idx() const {
+ int lane = threadIdx.x;
+ int warp = threadIdx.y;
+ int cta1 = blockIdx.x;
+ return (lane * output_mult[BLOCK_X] +
+ warp * output_mult[BLOCK_Y] +
+ cta1 * step_output) * output_vec_size;
+ }
+
+ C10_DEVICE int shared_memory_offset(int offset) const {
+ return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
+ }
+
+ C10_DEVICE int staging_memory_offset(int cta2) const {
+ int offset = cta2 + blockIdx.x * gridDim.y;
+ if (!should_block_x_reduce()) {
+ offset = threadIdx.x + offset * blockDim.x;
+ }
+ return offset;
+ }
+
+
+ };
+
+
+//TODO this will need to be different for more generic reduction functions
+namespace reducer {
+
+ using scalar_t = ${scalar_type};
+ using arg_t = ${reduction_accum_type};
+ using out_scalar_t = ${result_type};
+
+
+ inline __device__ ${functor}
+
+ inline __device__ out_scalar_t project(arg_t arg) {
+ return (out_scalar_t) arg;
+ }
+
+ inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
+ return WARP_SHFL_DOWN(arg, offset);
+ }
+
+ inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
+ return acc;
+ }
+
+ // wrap a normal reduction that ignores the index
+ inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) {
+ return combine(acc, val);
+ }
+}
+
+
+struct ReduceJitOp {
+ using scalar_t = ${scalar_type};
+ using arg_t = ${reduction_accum_type};
+ using out_scalar_t = ${result_type};
+
+ using InputCalculator = OffsetCalculator<1>;
+ using OutputCalculator = OffsetCalculator<2>;
+
+// static constexpr bool can_accumulate_in_output =
+// std::is_convertible<arg_t, out_scalar_t>::value
+// && std::is_convertible<out_scalar_t, arg_t>::value;
+
+ static constexpr int input_vec_size = ReduceConfig::input_vec_size;
+
+ arg_t ident;
+ ReduceConfig config;
+ InputCalculator input_calc;
+ OutputCalculator output_calc;
+ const void* src;
+ const char* dst[2]; //it accepts at most two destinations
+ // acc_buf used for accumulation among sub Tensor Iterator when accumulation on
+ // output is not permissible
+ void* acc_buf;
+ // cta_buf used for accumulation between blocks during global reduction
+ void* cta_buf;
+ int* semaphores;
+ int64_t base_idx;
+ bool accumulate;
+ bool final_output;
+ int noutputs;
+
+
+ C10_DEVICE void run() const {
+ extern __shared__ char shared_memory[];
+ uint32_t output_idx = config.output_idx<${output_vec_size}>();
+ uint32_t input_idx = config.input_idx();
+ auto base_offsets1 = output_calc.get(output_idx)[1];
+
+ using arg_vec_t = Array<arg_t, ${output_vec_size}>;
+ arg_vec_t value;
+
+ if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
+ const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
+
+ value = thread_reduce<${output_vec_size}>(input_slice);
+ }
+
+ if (config.should_block_y_reduce()) {
+ value = block_y_reduce<${output_vec_size}>(value, shared_memory);
+ }
+ if (config.should_block_x_reduce()) {
+ value = block_x_reduce<${output_vec_size}>(value, shared_memory);
+ }
+
+ using out_ptr_vec_t = Array<out_scalar_t*, ${output_vec_size}>;
+ using offset_vec_t = Array<uint32_t, ${output_vec_size}>;
+ offset_vec_t base_offsets;
+ out_ptr_vec_t out;
+
+ #pragma unroll
+ for (int i = 0; i < ${output_vec_size}; i++) {
+ base_offsets[i] = output_calc.get(output_idx + i)[0];
+ out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
+ }
+
+ arg_vec_t* acc = nullptr;
+ if (acc_buf != nullptr) {
+ size_t numerator = sizeof(arg_t);
+ size_t denominator = sizeof(out_scalar_t);
+ reduce_fraction(numerator, denominator);
+ acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
+ }
+
+ if (config.should_global_reduce()) {
+ value = global_reduce<${output_vec_size}>(value, acc, shared_memory);
+ } else if (config.should_store(output_idx)) {
+ if (accumulate) {
+ #pragma unroll
+ for (int i = 0; i < ${output_vec_size}; i++) {
+ value[i] = reducer::translate_idx(value[i], base_idx);
+ }
+ }
+
+ if (acc == nullptr) {
+ if (accumulate) {
+ value = accumulate_in_output<${output_vec_size}>(out, value);
+ }
+ if (final_output) {
+ set_results_to_output<${output_vec_size}>(value, base_offsets);
+ } else {
+ #pragma unroll
+ for (int i = 0; i < ${output_vec_size}; i++) {
+ *(out[i]) = get_accumulated_output(out[i], value[i]);
+ }
+ }
+ } else {
+ if (accumulate) {
+ #pragma unroll
+ for (int i = 0; i < ${output_vec_size}; i++) {
+ value[i] = reducer::combine((*acc)[i], value[i]);
+ }
+ }
+ if (final_output) {
+ set_results_to_output<${output_vec_size}>(value, base_offsets);
+ } else {
+ *acc = value;
+ }
+ }
+ }
+ }
+
+ template <int output_vec_size>
+ C10_DEVICE Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
+ if (config.vectorize_input) {
+ assert(output_vec_size == 1);
+ // reduce at the header of input_slice where memory is not aligned,
+ // so that thread_reduce will have an aligned memory to work on.
+ return {input_vectorized_thread_reduce_impl(data)};
+ } else {
+ uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
+ bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
+ if (is_contiguous) {
+ return thread_reduce_impl<output_vec_size>(data, [](uint32_t idx) { return idx; });
+ } else if (input_calc.dims == 1) {
+ return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return idx * element_stride; });
+ } else {
+ return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
+ }
+ }
+ }
+
+ C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
+ uint32_t end = config.num_inputs;
+
+ // Handle the head of input slice where data is not aligned
+ arg_t value = ident;
+ constexpr int align_bytes = alignof(aligned_vector<scalar_t, input_vec_size>);
+ constexpr int align_elements = align_bytes / sizeof(scalar_t);
+ int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t);
+ if (shift > 0) {
+ data -= shift;
+ end += shift;
+ if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){
+ value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift);
+ }
+ end -= align_elements;
+ data += align_elements;
+ shift = align_elements - shift;
+ }
+
+ // Do the vectorized reduction
+ using load_t = aligned_vector<scalar_t, input_vec_size>;
+
+ uint32_t idx = config.input_idx();
+ const uint32_t stride = config.step_input;
+
+ // Multiple accumulators to remove dependency between unrolled loops.
+ arg_t value_list[input_vec_size];
+ value_list[0] = value;
+
+ #pragma unroll
+ for (int i = 1; i < input_vec_size; i++) {
+ value_list[i] = ident;
+ }
+
+ scalar_t values[input_vec_size];
+
+ load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);
+
+ while (idx * input_vec_size + input_vec_size - 1 < end) {
+ *values_vector = reinterpret_cast<const load_t*>(data)[idx];
+ #pragma unroll
+ for (uint32_t i = 0; i < input_vec_size; i++) {
+ value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
+ }
+ idx += stride;
+ }
+
+ // tail
+ uint32_t tail_start = end - end % input_vec_size;
+ if (config.should_reduce_tail()) {
+ int idx = tail_start + threadIdx.x;
+ if (idx < end) {
+ value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift);
+ }
+ }
+
+ // combine accumulators
+ #pragma unroll
+ for (int i = 1; i < input_vec_size; i++) {
+ value_list[0] = reducer::combine(value_list[0], value_list[i]);
+ }
+ return value_list[0];
+ }
+
+ template <int output_vec_size, typename offset_calc_t>
+ C10_DEVICE Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
+ uint32_t idx = config.input_idx();
+ const uint32_t end = config.num_inputs;
+ const uint32_t stride = config.step_input;
+ const int vt0=${vt0};
+
+ using arg_vec_t = Array<arg_t, output_vec_size>;
+ using load_t = aligned_vector<scalar_t, output_vec_size>;
+ const load_t* data = reinterpret_cast<const load_t*>(data_);
+
+ // Multiple accumulators to remove dependency between unrolled loops.
+ arg_vec_t value_list[vt0];
+
+ #pragma unroll
+ for (int i = 0; i < vt0; i++) {
+ #pragma unroll
+ for (int j = 0; j < output_vec_size; j++) {
+ value_list[i][j] = ident;
+ }
+ }
+
+ load_t values[vt0];
+
+ while (idx + (vt0 - 1) * stride < end) {
+ #pragma unroll
+ for (uint32_t i = 0; i < vt0; i++) {
+ values[i] = data[calc(idx + i * stride) / output_vec_size];
+ }
+ #pragma unroll
+ for (uint32_t i = 0; i < vt0; i++) {
+ #pragma unroll
+ for (uint32_t j = 0; j < output_vec_size; j++) {
+ value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride);
+ }
+ }
+ idx += stride * vt0;
+ }
+
+ // tail
+ int idx_ = idx;
+ #pragma unroll
+ for (uint32_t i = 0; i < vt0; i++) {
+ if (idx >= end) {
+ break;
+ }
+ values[i] = data[calc(idx) / output_vec_size];
+ idx += stride;
+ }
+ idx = idx_;
+ #pragma unroll
+ for (uint32_t i = 0; i < vt0; i++) {
+ if (idx >= end) {
+ break;
+ }
+ #pragma unroll
+ for (uint32_t j = 0; j < output_vec_size; j++) {
+ value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx);
+ }
+ idx += stride;
+ }
+
+ // combine accumulators
+ #pragma unroll
+ for (int i = 1; i < vt0; i++) {
+ #pragma unroll
+ for (uint32_t j = 0; j < output_vec_size; j++) {
+ value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]);
+ }
+ }
+ return value_list[0];
+ }
+ template <int output_vec_size>
+ C10_DEVICE Array<arg_t, output_vec_size> block_x_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
+ using args_vec_t = Array<arg_t, output_vec_size>;
+ int dim_x = blockDim.x;
+ args_vec_t* shared = (args_vec_t*)shared_memory;
+ if (dim_x > warpSize) {
+ int address_base = threadIdx.x + threadIdx.y*blockDim.x;
+ shared[address_base] = value;
+ for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
+ __syncthreads();
+ if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
+ args_vec_t other = shared[address_base + offset];
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ value[i] = reducer::combine(value[i], other[i]);
+ }
+ shared[address_base] = value;
+ }
+ }
+ dim_x = warpSize;
+ }
+
+ __syncthreads();
+
+ for (int offset = 1; offset < dim_x; offset <<= 1) {
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ arg_t other = reducer::warp_shfl_down(value[i], offset);
+ value[i] = reducer::combine(value[i], other);
+ }
+ }
+ return value;
+ }
+
+ template <int output_vec_size>
+ C10_DEVICE Array<arg_t, output_vec_size> block_y_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
+ using args_vec_t = Array<arg_t, output_vec_size>;
+ args_vec_t* shared = (args_vec_t*)shared_memory;
+ shared[config.shared_memory_offset(0)] = value;
+ for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
+ __syncthreads();
+ if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
+ args_vec_t other = shared[config.shared_memory_offset(offset)];
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ value[i] = reducer::combine(value[i], other[i]);
+ }
+ shared[config.shared_memory_offset(0)] = value;
+ }
+ }
+ return value;
+ }
+ )ESCAPE";
+
+ const std::string reduction_template_1 = R"ESCAPE(
+
+ C10_DEVICE bool mark_block_finished() const {
+ __shared__ bool is_last_block_done_shared;
+
+ __syncthreads();
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
+ int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
+ is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
+ }
+
+ __syncthreads();
+
+ return is_last_block_done_shared;
+ }
+
+ template <int output_vec_size>
+ C10_DEVICE Array<arg_t, output_vec_size> accumulate_in_output(
+ Array<out_scalar_t*, output_vec_size> out,
+ Array<arg_t, output_vec_size> value
+ ) const {
+ Array<arg_t, output_vec_size> ret;
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ ret[i] = reducer::combine(*(out[i]), value[i]);
+ }
+ return ret;
+ }
+
+
+ C10_DEVICE out_scalar_t get_accumulated_output(
+ out_scalar_t* out, arg_t value
+ ) const {
+ assert(!final_output);
+ return (out_scalar_t)value;
+ }
+
+ template<class T>
+ C10_DEVICE void set_results(const T x, const uint32_t base_offset) const {
+ assert(noutputs == 1);
+ auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
+ *res = x;
+ }
+
+//TODO - multi-output reduction - we won't be able to use thrust::pair
+//just explicitly specify typed output reads/writes
+//Currently implemented for max of two outputs
+// template<class T1, class T2>
+// C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
+// if (noutputs >= 1) {
+// auto res0 = (T1*)((char*)dst[0] + base_offset);
+// *res0 = x.first;
+// }
+// if (noutputs >= 2) {
+// // base offset is computed assuming element size being sizeof(T1), so we need to make a
+// // correction to obtain the correct base offset
+// auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
+// *res1 = x.second;
+// }
+// }
+
+ template <int output_vec_size>
+ C10_DEVICE void set_results_to_output(Array<arg_t, output_vec_size> value, Array<uint32_t, output_vec_size> base_offset) const {
+ assert(final_output);
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ set_results(reducer::project(value[i]), base_offset[i]);
+ }
+ }
+
+ template <int output_vec_size>
+ C10_DEVICE Array<arg_t, output_vec_size> global_reduce(Array<arg_t, output_vec_size> value, Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
+ using arg_vec_t = Array<arg_t, output_vec_size>;
+ using out_ptr_vec_t = Array<out_scalar_t*, output_vec_size>;
+ using offset_vec_t = Array<uint32_t, output_vec_size>;
+
+ arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
+ uint32_t output_idx = config.output_idx<output_vec_size>();
+ offset_vec_t base_offsets;
+ out_ptr_vec_t out;
+
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ base_offsets[i] = output_calc.get(output_idx + i)[0];
+ out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
+ }
+
+ bool should_store = config.should_store(output_idx);
+ if (should_store) {
+ uint32_t offset = config.staging_memory_offset(blockIdx.y);
+ reduce_buffer[offset] = value;
+ }
+
+ __threadfence(); // make sure writes are globally visible
+ __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
+ bool is_last_block_done = mark_block_finished();
+
+ if (is_last_block_done) {
+ value = ident;
+ if (config.should_block_x_reduce()) {
+ uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
+ uint32_t step = blockDim.x * blockDim.y;
+ for (; input_offset < config.ctas_per_output; input_offset += step) {
+ uint32_t idx = config.staging_memory_offset(input_offset);
+ arg_vec_t next = reduce_buffer[idx];
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ value[i] = reducer::combine(value[i], next[i]);
+ }
+ }
+ } else {
+ uint32_t input_offset = threadIdx.y;
+ uint32_t step = blockDim.y;
+ for (; input_offset < config.ctas_per_output; input_offset += step) {
+ uint32_t idx = config.staging_memory_offset(input_offset);
+ arg_vec_t next = reduce_buffer[idx];
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ value[i] = reducer::combine(value[i], next[i]);
+ }
+ }
+ }
+ value = block_y_reduce(value, shared_memory);
+ if (config.should_block_x_reduce()) {
+ value = block_x_reduce<output_vec_size>(value, shared_memory);
+ }
+ if (should_store) {
+ if (accumulate) {
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ value[i] = reducer::translate_idx(value[i], base_idx);
+ }
+ }
+
+ if (acc == nullptr) {
+ if (accumulate) {
+ value = accumulate_in_output<output_vec_size>(out, value);
+ }
+ if (final_output) {
+ set_results_to_output<output_vec_size>(value, base_offsets);
+ } else {
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ *(out[i]) = get_accumulated_output(out[i], value[i]);
+ }
+ }
+ } else {
+ if (accumulate) {
+ #pragma unroll
+ for (int i = 0; i < output_vec_size; i++) {
+ value[i] = reducer::combine((*acc)[i], value[i]);
+ }
+ }
+ if (final_output) {
+ set_results_to_output<output_vec_size>(value, base_offsets);
+ } else {
+ *acc = value;
+ }
+ }
+ }
+ }
+
+ return value;
+ }
+};
+
+extern "C"
+__launch_bounds__(${max_threads_lb}, 4)
+__global__ void reduction_${name}_kernel(ReduceJitOp r){
+ r.run();
+}
+)ESCAPE";
+
+const std::string reduction_template = reduction_template_0 + reduction_template_1;
+
+
+const std::string &get_reduction_template() {
+ return reduction_template;
+}
+
+}}