Revert D19825127: [pytorch][PR] Move where cuda implementation to TensorIterator
Test Plan: revert-hammer
Differential Revision:
D19825127
Original commit changeset: bbf4682349d9
fbshipit-source-id: 0c439b8c9a00a5aa46fd196396cf7cc83cddb1b4
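
For context (not part of the reverted patch itself): this revert restores the CUDA_tensor_apply4-based `_s_where_cuda` path for `at::where`. A minimal usage sketch of the affected operator follows; the shapes and values are illustrative assumptions only, not taken from the diff.

    #include <ATen/ATen.h>

    int main() {
      // at::where selects from `a` where `cond` is true, otherwise from `b`.
      // On CUDA this dispatches to _s_where_cuda (see native_functions.yaml below).
      auto a = at::rand({4, 4}, at::kCUDA);
      auto b = at::zeros({4, 4}, at::kCUDA);
      auto cond = a.gt(0.5);              // bool condition tensor
      auto out = at::where(cond, a, b);   // same dtype as a/b
      return 0;
    }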
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 595f8ab..6c2517f 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -6,12 +6,42 @@
#include <ATen/native/ReduceOpsUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cpu/TensorCompareKernel.h>
+#include <ATen/native/cpu/Loops.h>
#include <ATen/NamedTensorUtils.h>
+namespace {
+template <typename scalar_t>
+void where_cpu(
+ at::Tensor& ret,
+ const at::Tensor& condition,
+ const at::Tensor& self,
+ const at::Tensor& other) {
+ auto iter = at::TensorIterator();
+ iter.set_check_mem_overlap(true);
+ iter.add_output(ret);
+ iter.add_input(condition);
+ iter.add_input(self);
+ iter.add_input(other);
+ iter.dont_compute_common_dtype();
+ iter.build();
+ if (condition.scalar_type() == at::ScalarType::Byte) {
+ at::native::cpu_kernel(
+ iter,
+ [=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
+ return cond_val ? self_val : other_val;
+ });
+ } else {
+ at::native::cpu_kernel(
+ iter,
+ [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
+ return cond_val ? self_val : other_val;
+ });
+ }
+}
+} // namespace
namespace at { namespace native {
-DEFINE_DISPATCH(where_kernel);
DEFINE_DISPATCH(max_kernel);
DEFINE_DISPATCH(min_kernel);
@@ -118,18 +148,12 @@
return condition.nonzero_numpy();
}
-Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
+Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
Tensor ret = at::empty(self.sizes(), self.options());
- auto iter = at::TensorIterator();
- iter.set_check_mem_overlap(true);
- iter.add_output(ret);
- iter.add_input(condition);
- iter.add_input(self);
- iter.add_input(other);
- iter.dont_compute_common_dtype();
- iter.build();
- where_kernel(iter.device_type(), iter, condition.scalar_type());
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX(ret.scalar_type(), "where_cpu", [&] {
+ where_cpu<scalar_t>(ret, condition, self, other);
+ });
return ret;
}
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index 641195a..fb7d93c 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -1,5 +1,4 @@
#include <ATen/native/cpu/TensorCompareKernel.h>
-#include <ATen/native/cpu/Loops.h>
#include <numeric>
#include <iterator>
@@ -102,28 +101,9 @@
});
}
-static void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "where_cpu", [&] {
- if (condition_type == at::ScalarType::Byte) {
- at::native::cpu_kernel(
- iter,
- [=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
- return cond_val ? self_val : other_val;
- });
- } else {
- at::native::cpu_kernel(
- iter,
- [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
- return cond_val ? self_val : other_val;
- });
- }
- });
-}
-
} // anonymous namespace
REGISTER_DISPATCH(max_kernel, &max_kernel_impl);
REGISTER_DISPATCH(min_kernel, &min_kernel_impl);
-REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
}} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.h b/aten/src/ATen/native/cpu/TensorCompareKernel.h
index d1cb033..a23792d 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.h
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.h
@@ -3,7 +3,6 @@
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/Optional.h>
-#include <ATen/native/TensorIterator.h>
namespace at { namespace native {
@@ -13,7 +12,4 @@
DECLARE_DISPATCH(reduce_fn, max_kernel);
DECLARE_DISPATCH(reduce_fn, min_kernel);
-using where_fn = void (*)(TensorIterator &, ScalarType);
-DECLARE_DISPATCH(where_fn, where_kernel);
-
}} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index aa396b2..0c1c4f4 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -255,7 +255,7 @@
}
// TODO (@zasdfgbnm): this function assume trivial 1d and no dynamic casting
-template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
+template<int nt, int vt, typename func_t, typename array_t>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {
TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
if (N == 0) {
@@ -281,9 +281,191 @@
AT_CUDA_CHECK(cudaGetLastError());
}
-template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
-static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
-
} // namespace modern
+template<typename func_t, int nargs=function_traits<func_t>::arity>
+struct needs_dynamic_casting {
+ static bool check(TensorIterator& iter) {
+ using traits = function_traits<func_t>;
+ if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
+ return true;
+ }
+ return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
+ }
+};
+
+template<typename func_t>
+struct needs_dynamic_casting<func_t, 0> {
+ static bool check(TensorIterator& iter) {
+ using traits = function_traits<func_t>;
+ return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
+ }
+};
+
+template <typename func_t>
+void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
+ using traits = function_traits<func_t>;
+ using arg0_t = typename traits::result_type;
+ constexpr int ntensors = traits::arity + 1;
+
+ TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+ TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
+
+ at::detail::Array<char*, ntensors> data;
+ for (int i = 0; i < ntensors; i++) {
+ data[i] = (char*)iter.data_ptr(i);
+ }
+
+ at::detail::Array<ScalarType, ntensors> dtypes;
+ for (int i = 0; i < ntensors; i++) {
+ dtypes[i] = iter.tensor(i).scalar_type();
+ }
+
+ int64_t numel = iter.numel();
+ if (iter.is_trivial_1d()) {
+ auto inner_strides = iter.get_inner_strides();
+ at::detail::Array<int, ntensors> strides;
+ for (int i = 0; i < ntensors; i++) {
+ strides[i] = inner_strides[i];
+ }
+
+ if (needs_dynamic_casting<func_t>::check(iter)) {
+ legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+ void* out = data[0] + strides[0] * idx;
+ arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
+ c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+ });
+ } else if (iter.has_contiguous_first_dim()) {
+ modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
+ } else {
+ legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+ arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
+ *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
+ });
+ }
+ } else {
+ auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
+ if (needs_dynamic_casting<func_t>::check(iter)) {
+ legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+ auto offsets = offset_calc.get(idx);
+ void* out = data[0] + offsets[0];
+ arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
+ c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+ });
+ } else {
+ legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+ auto offsets = offset_calc.get(idx);
+ arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+ *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
+ });
+ }
+ }
+}
+
+template <typename func_t>
+void gpu_kernel(TensorIterator& iter, const func_t& f) {
+ ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+ for (int arg = 0; arg < iter.ntensors(); arg++) {
+ TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
+ }
+
+ if (iter.numel() == 0) {
+ return;
+ }
+
+ if (!iter.can_use_32bit_indexing()) {
+ for (auto& sub_iter : iter.with_32bit_indexing()) {
+ gpu_kernel(sub_iter, f);
+ }
+ return;
+ }
+
+ gpu_kernel_impl(iter, f);
+}
+
+template <typename func_t>
+void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
+ ASSERT_HOST_DEVICE_LAMBDA(func_t);
+ TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
+
+ using traits = function_traits<func_t>;
+ static_assert(
+ traits::arity == 2,
+ "gpu_kernel_with_scalars only supports two input arguments");
+
+ if (iter.is_cpu_scalar(1)) {
+ using arg1_t = typename traits::template arg<0>::type;
+ using arg2_t = typename traits::template arg<1>::type;
+ auto a = iter.scalar_value<arg1_t>(1);
+ iter.remove_operand(1);
+ gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
+ return f(a, b);
+ });
+ } else if (iter.is_cpu_scalar(2)) {
+ using arg1_t = typename traits::template arg<0>::type;
+ using arg2_t = typename traits::template arg<1>::type;
+ auto b = iter.scalar_value<arg2_t>(2);
+ iter.remove_operand(2);
+ gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
+ return f(a, b);
+ });
+ } else {
+ gpu_kernel(iter, f);
+ }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
+ using traits = function_traits<func_t>;
+ using arg0_t = typename traits::result_type;
+
+
+ // Note:
+ // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
+ // for having an arbitrary number of tensors as arguments. This support was removed
+ // during the process of refactoring Loops.cuh to support vectorized memory access
+ // in PR #32777 (see also issue #31975). The removal of this support is solely because,
+ // at that time, there was no operator using that functionality. If you need this
+ // functionality, feel free to add it back.
+ static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
+
+ TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
+
+ char* data = (char*)iter.data_ptr(0);
+
+ int64_t numel = iter.numel();
+ if (iter.is_trivial_1d()) {
+ int stride = iter.get_inner_strides()[0];
+ legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+ arg0_t* out = (arg0_t*)(data + stride * idx);
+ *out = f(idx);
+ });
+ } else {
+ auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
+ legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+ auto offsets = offset_calc.get(idx);
+ arg0_t* out = (arg0_t*)(data + offsets[0]);
+ *out = f(idx);
+ });
+ }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
+ ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+ TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only supports CUDA tensors.");
+
+ if (iter.numel() == 0) {
+ return;
+ }
+
+ // Splitting would change the index, so it is not supported.
+ // The caller should handle the split and pass in a different functor.
+ TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only supports 32-bit indexing.");
+
+ gpu_kernel_with_index_impl(iter, f);
+}
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index 856487a..e2271f5 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -1,31 +1,3 @@
-
-#pragma once
-
-#include <ATen/detail/FunctionTraits.h>
-
-namespace at { namespace native { namespace modern { namespace detail {
-
-template<typename func_t, int remaining=function_traits<func_t>::arity-1>
-struct has_same_arg_types {
- using traits = function_traits<func_t>;
- static constexpr bool value = std::is_same<
- typename traits::template arg<remaining>::type,
- typename traits::template arg<remaining-1>::type
- >::value && has_same_arg_types<func_t, remaining-1>::value;
-};
-
-template<typename func_t>
-struct has_same_arg_types<func_t, 0> {
- static constexpr bool value = true;
-};
-
-template<typename func_t>
-struct has_same_arg_types<func_t, -1> {
- static constexpr bool value = true;
-};
-
-}}}} // namespace at::native::modern::detail
-
// Note:
// CUDA and ROCm get diverged in this PR:
// https://github.com/pytorch/pytorch/pull/32383
@@ -37,195 +9,3 @@
#else
#include <ATen/native/cuda/ROCmLoops.cuh>
#endif
-
-namespace at { namespace native {
-
-// `needs_dynamic_casting` compares the types expected by iterator
-// (i.e. dtypes of the operands) with the actual type of the arguments
-// of func_t
-template<typename func_t, int nargs=function_traits<func_t>::arity>
-struct needs_dynamic_casting {
- static bool check(TensorIterator& iter) {
- using traits = function_traits<func_t>;
- if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
- return true;
- }
- return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
- }
-};
-
-template<typename func_t>
-struct needs_dynamic_casting<func_t, 0> {
- static bool check(TensorIterator& iter) {
- using traits = function_traits<func_t>;
- return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
- }
-};
-
-template <typename func_t>
-void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
- using traits = function_traits<func_t>;
- using arg0_t = typename traits::result_type;
- constexpr int ntensors = traits::arity + 1;
-
- TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
- TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
-
- at::detail::Array<char*, ntensors> data;
- for (int i = 0; i < ntensors; i++) {
- data[i] = (char*)iter.data_ptr(i);
- }
-
- at::detail::Array<ScalarType, ntensors> dtypes;
- for (int i = 0; i < ntensors; i++) {
- dtypes[i] = iter.tensor(i).scalar_type();
- }
-
- int64_t numel = iter.numel();
- if (iter.is_trivial_1d()) {
- auto inner_strides = iter.get_inner_strides();
- at::detail::Array<int, ntensors> strides;
- for (int i = 0; i < ntensors; i++) {
- strides[i] = inner_strides[i];
- }
-
- if (needs_dynamic_casting<func_t>::check(iter)) {
- legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
- void* out = data[0] + strides[0] * idx;
- arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
- c10::cast_and_store<arg0_t>(dtypes[0], out, result);
- });
- } else if (iter.has_contiguous_first_dim() && modern::detail::has_same_arg_types<func_t>::value) {
- modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
- } else {
- legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
- arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
- *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
- });
- }
- } else {
- auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
- if (needs_dynamic_casting<func_t>::check(iter)) {
- legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
- auto offsets = offset_calc.get(idx);
- void* out = data[0] + offsets[0];
- arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
- c10::cast_and_store<arg0_t>(dtypes[0], out, result);
- });
- } else {
- legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
- auto offsets = offset_calc.get(idx);
- arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
- *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
- });
- }
- }
-}
-
-template <typename func_t>
-void gpu_kernel(TensorIterator& iter, const func_t& f) {
- ASSERT_HOST_DEVICE_LAMBDA(func_t);
-
- for (int arg = 0; arg < iter.ntensors(); arg++) {
- TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
- }
-
- if (iter.numel() == 0) {
- return;
- }
-
- if (!iter.can_use_32bit_indexing()) {
- for (auto& sub_iter : iter.with_32bit_indexing()) {
- gpu_kernel(sub_iter, f);
- }
- return;
- }
-
- gpu_kernel_impl(iter, f);
-}
-
-template <typename func_t>
-void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
- ASSERT_HOST_DEVICE_LAMBDA(func_t);
- TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
-
- using traits = function_traits<func_t>;
- static_assert(
- traits::arity == 2,
- "gpu_kernel_with_scalars only supports two input arguments");
-
- if (iter.is_cpu_scalar(1)) {
- using arg1_t = typename traits::template arg<0>::type;
- using arg2_t = typename traits::template arg<1>::type;
- auto a = iter.scalar_value<arg1_t>(1);
- iter.remove_operand(1);
- gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
- return f(a, b);
- });
- } else if (iter.is_cpu_scalar(2)) {
- using arg1_t = typename traits::template arg<0>::type;
- using arg2_t = typename traits::template arg<1>::type;
- auto b = iter.scalar_value<arg2_t>(2);
- iter.remove_operand(2);
- gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
- return f(a, b);
- });
- } else {
- gpu_kernel(iter, f);
- }
-}
-
-template <typename func_t>
-void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
- using traits = function_traits<func_t>;
- using arg0_t = typename traits::result_type;
-
-
- // Note:
- // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
- // of having an arbitrary number of tensors as arguments. This support was removed
- // during the process of refactoring Loops.cuh to support vectorized memory access
- // in PR #32777 (See also issue #31975). The removal of this support is soly because
- // at that time, there is no operator using that functionality. If you need this
- // functionality, feel free to add it back.
- static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
-
- TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
-
- char* data = (char*)iter.data_ptr(0);
-
- int64_t numel = iter.numel();
- if (iter.is_trivial_1d()) {
- int stride = iter.get_inner_strides()[0];
- legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
- arg0_t* out = (arg0_t*)(data + stride * idx);
- *out = f(idx);
- });
- } else {
- auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
- legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
- auto offsets = offset_calc.get(idx);
- arg0_t* out = (arg0_t*)(data + offsets[0]);
- *out = f(idx);
- });
- }
-}
-
-template <typename func_t>
-void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
- ASSERT_HOST_DEVICE_LAMBDA(func_t);
-
- TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
-
- if (iter.numel() == 0) {
- return;
- }
-
- // Split will change index, thus is not supported
- // The caller should handle the split and pass in different func
- TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
-
- gpu_kernel_with_index_impl(iter, f);
-}
-
-}} //namespace at::native
\ No newline at end of file
diff --git a/aten/src/ATen/native/cuda/ROCmLoops.cuh b/aten/src/ATen/native/cuda/ROCmLoops.cuh
index 19fc35a..a2a97c3 100644
--- a/aten/src/ATen/native/cuda/ROCmLoops.cuh
+++ b/aten/src/ATen/native/cuda/ROCmLoops.cuh
@@ -240,7 +240,7 @@
}
// TODO (@zasdfgbnm): this function assume trivial 1d and no dynamic casting
-template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
+template<int nt, int vt, typename func_t, typename array_t>
static void launch_kernel(int64_t N, const func_t& f, array_t data) {
TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
if (N == 0) {
@@ -253,9 +253,191 @@
AT_CUDA_CHECK(cudaGetLastError());
}
-template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
-static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
-
} // namespace modern
+template<typename func_t, int nargs=function_traits<func_t>::arity>
+struct needs_dynamic_casting {
+ static bool check(TensorIterator& iter) {
+ using traits = function_traits<func_t>;
+ if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
+ return true;
+ }
+ return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
+ }
+};
+
+template<typename func_t>
+struct needs_dynamic_casting<func_t, 0> {
+ static bool check(TensorIterator& iter) {
+ using traits = function_traits<func_t>;
+ return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
+ }
+};
+
+template <typename func_t>
+void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
+ using traits = function_traits<func_t>;
+ using arg0_t = typename traits::result_type;
+ constexpr int ntensors = traits::arity + 1;
+
+ TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+ TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
+
+ at::detail::Array<char*, ntensors> data;
+ for (int i = 0; i < ntensors; i++) {
+ data[i] = (char*)iter.data_ptr(i);
+ }
+
+ at::detail::Array<ScalarType, ntensors> dtypes;
+ for (int i = 0; i < ntensors; i++) {
+ dtypes[i] = iter.tensor(i).scalar_type();
+ }
+
+ int64_t numel = iter.numel();
+ if (iter.is_trivial_1d()) {
+ auto inner_strides = iter.get_inner_strides();
+ at::detail::Array<int, ntensors> strides;
+ for (int i = 0; i < ntensors; i++) {
+ strides[i] = inner_strides[i];
+ }
+
+ if (needs_dynamic_casting<func_t>::check(iter)) {
+ legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+ void* out = data[0] + strides[0] * idx;
+ arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
+ c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+ });
+ } else if (iter.has_contiguous_first_dim()) {
+ modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
+ } else {
+ legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+ arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
+ *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
+ });
+ }
+ } else {
+ auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
+ if (needs_dynamic_casting<func_t>::check(iter)) {
+ legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+ auto offsets = offset_calc.get(idx);
+ void* out = data[0] + offsets[0];
+ arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
+ c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+ });
+ } else {
+ legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+ auto offsets = offset_calc.get(idx);
+ arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+ *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
+ });
+ }
+ }
+}
+
+template <typename func_t>
+void gpu_kernel(TensorIterator& iter, const func_t& f) {
+ ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+ for (int arg = 0; arg < iter.ntensors(); arg++) {
+ TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
+ }
+
+ if (iter.numel() == 0) {
+ return;
+ }
+
+ if (!iter.can_use_32bit_indexing()) {
+ for (auto& sub_iter : iter.with_32bit_indexing()) {
+ gpu_kernel(sub_iter, f);
+ }
+ return;
+ }
+
+ gpu_kernel_impl(iter, f);
+}
+
+template <typename func_t>
+void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
+ ASSERT_HOST_DEVICE_LAMBDA(func_t);
+ TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
+
+ using traits = function_traits<func_t>;
+ static_assert(
+ traits::arity == 2,
+ "gpu_kernel_with_scalars only supports two input arguments");
+
+ if (iter.is_cpu_scalar(1)) {
+ using arg1_t = typename traits::template arg<0>::type;
+ using arg2_t = typename traits::template arg<1>::type;
+ auto a = iter.scalar_value<arg1_t>(1);
+ iter.remove_operand(1);
+ gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
+ return f(a, b);
+ });
+ } else if (iter.is_cpu_scalar(2)) {
+ using arg1_t = typename traits::template arg<0>::type;
+ using arg2_t = typename traits::template arg<1>::type;
+ auto b = iter.scalar_value<arg2_t>(2);
+ iter.remove_operand(2);
+ gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
+ return f(a, b);
+ });
+ } else {
+ gpu_kernel(iter, f);
+ }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
+ using traits = function_traits<func_t>;
+ using arg0_t = typename traits::result_type;
+
+
+ // Note:
+ // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
+ // for having an arbitrary number of tensors as arguments. This support was removed
+ // during the process of refactoring Loops.cuh to support vectorized memory access
+ // in PR #32777 (see also issue #31975). The removal of this support is solely because,
+ // at that time, there was no operator using that functionality. If you need this
+ // functionality, feel free to add it back.
+ static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
+
+ TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
+
+ char* data = (char*)iter.data_ptr(0);
+
+ int64_t numel = iter.numel();
+ if (iter.is_trivial_1d()) {
+ int stride = iter.get_inner_strides()[0];
+ legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+ arg0_t* out = (arg0_t*)(data + stride * idx);
+ *out = f(idx);
+ });
+ } else {
+ auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
+ legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+ auto offsets = offset_calc.get(idx);
+ arg0_t* out = (arg0_t*)(data + offsets[0]);
+ *out = f(idx);
+ });
+ }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
+ ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+ TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only supports CUDA tensors.");
+
+ if (iter.numel() == 0) {
+ return;
+ }
+
+ // Splitting would change the index, so it is not supported.
+ // The caller should handle the split and pass in a different functor.
+ TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only supports 32-bit indexing.");
+
+ gpu_kernel_with_index_impl(iter, f);
+}
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu
index 0dbedcf..e1c3e73 100644
--- a/aten/src/ATen/native/cuda/TensorCompare.cu
+++ b/aten/src/ATen/native/cuda/TensorCompare.cu
@@ -1,38 +1,56 @@
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
-#include <ATen/native/DispatchStub.h>
-#include <ATen/native/cuda/Loops.cuh>
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
+namespace {
+template <typename scalar_t>
+void where_cuda(
+ at::Tensor& ret,
+ const at::Tensor& condition,
+ const at::Tensor& self,
+ const at::Tensor& other) {
+ if (condition.scalar_type() == at::ScalarType::Byte) {
+ // Yes, this name is repetitive, but the CPU version is called
+ // CPU_tensor_apply4 and we don't have a CPU namespace or directory.
+ at::cuda::CUDA_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
+ ret,
+ condition,
+ self,
+ other,
+ [] __device__(
+ scalar_t & ret_val,
+ const uint8_t& cond_val,
+ const scalar_t& self_val,
+ const scalar_t& other_val) {
+ ret_val = cond_val ? self_val : other_val;
+ });
+ } else {
+ at::cuda::CUDA_tensor_apply4<scalar_t, bool, scalar_t, scalar_t>(
+ ret,
+ condition,
+ self,
+ other,
+ [] __device__(
+ scalar_t & ret_val,
+ const bool& cond_val,
+ const scalar_t& self_val,
+ const scalar_t& other_val) {
+ ret_val = cond_val ? self_val : other_val;
+ });
+ }
+}
+} // namespace
namespace at { namespace native {
-
-using where_fn = void (*)(TensorIterator &, ScalarType);
-DECLARE_DISPATCH(where_fn, where_kernel);
-
-namespace {
-
-void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.dtype(), "where_cuda", [&] {
- if (condition_type == at::ScalarType::Byte) {
- gpu_kernel(
- iter,
- [=] GPU_LAMBDA (uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
- return cond_val ? self_val : other_val;
- });
- } else {
- gpu_kernel(
- iter,
- [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
- return cond_val ? self_val : other_val;
- });
- }
+Tensor _s_where_cuda(
+ const Tensor& condition,
+ const Tensor& self,
+ const Tensor& other) {
+ Tensor ret = at::empty(self.sizes(), self.options());
+ AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, ret.scalar_type(), "where_cuda", [&] {
+ where_cuda<scalar_t>(ret, condition, self, other);
});
+ return ret;
}
-
-} // anonymous namespace
-
-
-REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
-
}} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 98c061b..d33a9d0 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2942,6 +2942,9 @@
- func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: function
+ dispatch:
+ CPU: _s_where_cpu
+ CUDA: _s_where_cuda
- func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor
variants: function
diff --git a/aten/src/ATen/test/cuda_vectorized_test.cu b/aten/src/ATen/test/cuda_vectorized_test.cu
index ac57be0..96b70d9 100644
--- a/aten/src/ATen/test/cuda_vectorized_test.cu
+++ b/aten/src/ATen/test/cuda_vectorized_test.cu
@@ -1,7 +1,6 @@
#include <gtest/gtest.h>
#include <ATen/ATen.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
-#include <ATen/native/cuda/Loops.cuh>
#include <ATen/cuda/CUDAContext.h>
using namespace at::native::memory;
@@ -22,21 +21,6 @@
}
}
-TEST(TestLoops, HasSameArgTypes) {
- // This is a compile-time unit test. If this file compiles without error,
- // then the test passes and during runtime, we just need to return.
- using namespace at::native::modern::detail;
- using func1_t = int (*)(float, float);
- using func2_t = int (*)(bool, float, float);
- using func3_t = int (*)(float);
- using func4_t = int (*)();
- static_assert(has_same_arg_types<func1_t>::value, "func1_t has the same argument types");
- static_assert(!has_same_arg_types<func2_t>::value, "func2_t does not have the same argument types");
- static_assert(has_same_arg_types<func3_t>::value, "func3_t has the same argument types");
- static_assert(has_same_arg_types<func4_t>::value, "func4_t has the same argument types");
- return;
-}
-
TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
char *ptr = reinterpret_cast<char *>(buffer1);