Add opmath_gpu_kernel_with_scalars and port add to use it (#63884)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63884
See https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302
for an explanation of what's going on here.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D30545296
Pulled By: ezyang
fbshipit-source-id: f0da52153ae63599fe1d57e90e73f50ca2116939
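
For reference, a minimal sketch of the new calling pattern (mirroring the add
port in the diff below; MulFunctor and mul_kernel_cuda are hypothetical
stand-ins used for illustration, not part of this change):

    // The functor is written entirely in terms of opmath_t (e.g. float when
    // the inputs are half); opmath_gpu_kernel_with_scalars<scalar_t> still
    // loads from memory at scalar_t and inserts the casts around the call.
    template <typename T>
    struct MulFunctor {
      __device__ __forceinline__ T operator()(T a, T b) const {
        return a * b;
      }
    };

    void mul_kernel_cuda(TensorIteratorBase& iter) {
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
          kHalf, kBool, kBFloat16, iter.common_dtype(), "mul_cuda", [&]() {
            using opmath_t = at::opmath_type<scalar_t>;
            opmath_gpu_kernel_with_scalars<scalar_t>(iter, MulFunctor<opmath_t>());
          });
    }

Compared to the removed AddScalarFunctor path, the cpu-scalar case is now
handled uniformly inside the helper, which reads the scalar at opmath
precision.
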
diff --git a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu
index a07fd66..b1c76e1 100644
--- a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu
+++ b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu
@@ -10,53 +10,20 @@
namespace at { namespace native {
-template<typename scalar_t, typename accscalar_t>
+template <typename T>
struct AddFunctor {
- AddFunctor(accscalar_t a): alpha(a) {}
- __device__ __forceinline__ scalar_t operator() (const scalar_t a, const scalar_t b) const {
- return a + alpha * b;
+ AddFunctor(T alpha) : alpha_(alpha) {}
+ T alpha_;
+ __device__ __forceinline__ T operator()(T a, T b) const __ubsan_ignore_undefined__ {
+ return a + b * alpha_;
}
- private:
- accscalar_t alpha;
-};
-
-template<typename scalar_t, typename accscalar_t, int SCALAR_ARG>
-struct AddScalarFunctor {
- static_assert(SCALAR_ARG == 1 || SCALAR_ARG == 2, "SCALAR_ARG must be either 1 or 2");
- AddScalarFunctor(accscalar_t alpha, accscalar_t b): alpha(alpha), b(b) {}
- __device__ __forceinline__ scalar_t operator() (const scalar_t a) const {
- return static_cast<scalar_t>(SCALAR_ARG == 1 ? b + alpha * a : a + alpha * b);
- }
- private:
- accscalar_t alpha;
- accscalar_t b;
};
void add_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) {
- if (!isIntegralType(iter.common_dtype(), /* includeBool */ true) && (iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) {
- // if common dtype is half the scalar constant can overflow in half precision, and yet the result can
- // still be representable in the half dtype. Cast scalar to acc_type to have better accuracy.
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
- using accscalar_t = at::acc_type<scalar_t, true>;
- int scalar_arg = iter.is_cpu_scalar(1) ? 1 : 2;
- auto b = iter.scalar_value<accscalar_t>(scalar_arg);
- iter.remove_operand(scalar_arg);
- const cuda::OptionalCUDAGuard device_guard(device_of(iter.tensor(1)));
- if (scalar_arg == 1) {
- AddScalarFunctor<scalar_t, decltype(b), 1> f(alpha_scalar.to<accscalar_t>(), b);
- gpu_kernel(iter, f);
- } else {
- AddScalarFunctor<scalar_t, decltype(b), 2> f(alpha_scalar.to<accscalar_t>(), b);
- gpu_kernel(iter, f);
- }
- });
- } else {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
- using accscalar_t = at::acc_type<scalar_t, true>;
- AddFunctor<scalar_t, accscalar_t> f(alpha_scalar.to<accscalar_t>());
- gpu_kernel_with_scalars(iter, f);
- });
- }
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
+ using opmath_t = at::opmath_type<scalar_t>;
+ opmath_gpu_kernel_with_scalars<scalar_t>(iter, AddFunctor<opmath_t>(alpha_scalar.to<opmath_t>()));
+ });
}
static void sub_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) {
diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index fde8e86..8849293 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -5,6 +5,7 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
+#include <ATen/OpMathType.h>
#include <thrust/tuple.h>
@@ -111,49 +112,64 @@
gpu_kernel_impl(iter, f);
}
-template<typename func_t>
+template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct AUnaryFunctor {
using traits = function_traits<func_t>;
- using arg1_t = typename traits::template arg<0>::type;
- using arg2_t = typename traits::template arg<1>::type;
- using return_t = typename traits::result_type;
+ using opmath_arg1_t = typename traits::template arg<0>::type;
__device__ return_t operator()(arg2_t b) const {
return f(a, b);
}
- AUnaryFunctor(func_t f_, arg1_t a_): f(f_), a(a_) {}
+ // NB: scalar is stored in higher precision!
+ AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {}
private:
func_t f;
- arg1_t a;
+ opmath_arg1_t a;
};
-template<typename func_t>
+template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct BUnaryFunctor {
using traits = function_traits<func_t>;
- using arg1_t = typename traits::template arg<0>::type;
- using arg2_t = typename traits::template arg<1>::type;
- using return_t = typename traits::result_type;
+ using opmath_arg2_t = typename traits::template arg<1>::type;
__device__ return_t operator()(arg1_t a) const {
return f(a, b);
}
- BUnaryFunctor(func_t f_, arg2_t b_): f(f_), b(b_) {}
+ // NB: scalar is stored in higher precision!
+ BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {}
private:
func_t f;
- arg2_t b;
+ opmath_arg2_t b;
};
-template <typename func_t>
-void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
+// Though seemingly a no-op, this inserts casts from arg1_t/arg2_t to func_t's
+// argument types (which may be higher precision), as well as a cast to return_t
+template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
+struct BinaryFunctor {
+ __device__ return_t operator()(arg1_t a, arg2_t b) const {
+ return f(a, b);
+ }
+ BinaryFunctor(func_t f_): f(f_) {}
+ private:
+ func_t f;
+};
+
+// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which
+// accepts inputs at higher precision (typically opmath_t), while still
+// ensuring that we load from memory at the original precision (scalar_t)
+// to avoid expensive loads. For the whole sordid story see
+// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302
+template <typename arg1_t, typename arg2_t = arg1_t, typename return_t = arg1_t, typename func_t>
+void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
using traits = function_traits<func_t>;
+ using opmath_arg1_t = typename traits::template arg<0>::type;
+ using opmath_arg2_t = typename traits::template arg<1>::type;
static_assert(
traits::arity == 2,
"gpu_kernel_with_scalars only supports two input arguments");
- using arg1_t = typename traits::template arg<0>::type;
- using arg2_t = typename traits::template arg<1>::type;
if (iter.is_cpu_scalar(1)) {
- AUnaryFunctor<func_t> af(f, iter.scalar_value<arg1_t>(1));
+ AUnaryFunctor<arg1_t, arg2_t, return_t, func_t> af(f, iter.scalar_value<opmath_arg1_t>(1));
iter.remove_operand(1);
// TODO: When all kernels that use gpu_kernel_with_scalars are
// ported to structured, this device guard can be deleted. This
@@ -163,14 +179,28 @@
const OptionalDeviceGuard device_guard(device_of(iter.tensor(1)));
gpu_kernel(iter, af);
} else if (iter.is_cpu_scalar(2)) {
- BUnaryFunctor<func_t> bf(f, iter.scalar_value<arg2_t>(2));
+ BUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bf(f, iter.scalar_value<opmath_arg2_t>(2));
iter.remove_operand(2);
gpu_kernel(iter, bf);
} else {
- gpu_kernel(iter, f);
+ gpu_kernel(iter, BinaryFunctor<arg1_t, arg2_t, return_t, func_t>(f));
}
}
+// Legacy variant that assumes that func_t has the correct types
+// that we expect to load from memory
+template <typename func_t>
+void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
+ using traits = function_traits<func_t>;
+ static_assert(
+ traits::arity == 2,
+ "gpu_kernel_with_scalars only supports two input arguments");
+ using arg1_t = typename traits::template arg<0>::type;
+ using arg2_t = typename traits::template arg<1>::type;
+ using return_t = typename traits::result_type;
+ opmath_gpu_kernel_with_scalars<arg1_t, arg2_t, return_t, func_t>(iter, f);
+}
+
namespace { // functions for `gpu_kernel_multiple_outputs`.
// check the return type is `thrust::tuple`, not `std::tuple`.