Revert D19825127: [pytorch][PR] Move where cuda implementation to TensorIterator

Test Plan: revert-hammer

Differential Revision:
D19825127

Original commit changeset: bbf4682349d9

fbshipit-source-id: 0c439b8c9a00a5aa46fd196396cf7cc83cddb1b4
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 595f8ab..6c2517f 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -6,12 +6,42 @@
 #include <ATen/native/ReduceOpsUtils.h>
 #include <c10/util/Exception.h>
 #include <ATen/native/cpu/TensorCompareKernel.h>
+#include <ATen/native/cpu/Loops.h>
 #include <ATen/NamedTensorUtils.h>
 
+namespace {
+template <typename scalar_t>
+void where_cpu(
+    at::Tensor& ret,
+    const at::Tensor& condition,
+    const at::Tensor& self,
+    const at::Tensor& other) {
+  auto iter = at::TensorIterator();
+  iter.set_check_mem_overlap(true);
+  iter.add_output(ret);
+  iter.add_input(condition);
+  iter.add_input(self);
+  iter.add_input(other);
+  iter.dont_compute_common_dtype();
+  iter.build();
+  if (condition.scalar_type() == at::ScalarType::Byte) {
+    at::native::cpu_kernel(
+      iter,
+      [=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
+        return cond_val ? self_val : other_val;
+      });
+  } else {
+    at::native::cpu_kernel(
+      iter,
+      [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
+        return cond_val ? self_val : other_val;
+      });
+  }
+}
+} // namespace
 
 namespace at { namespace native {
 
-DEFINE_DISPATCH(where_kernel);
 DEFINE_DISPATCH(max_kernel);
 DEFINE_DISPATCH(min_kernel);
 
@@ -118,18 +148,12 @@
   return condition.nonzero_numpy();
 }
 
-Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
+Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
   TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
   Tensor ret = at::empty(self.sizes(), self.options());
-  auto iter = at::TensorIterator();
-  iter.set_check_mem_overlap(true);
-  iter.add_output(ret);
-  iter.add_input(condition);
-  iter.add_input(self);
-  iter.add_input(other);
-  iter.dont_compute_common_dtype();
-  iter.build();
-  where_kernel(iter.device_type(), iter, condition.scalar_type());
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(ret.scalar_type(), "where_cpu", [&] {
+    where_cpu<scalar_t>(ret, condition, self, other);
+  });
   return ret;
 }
 
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index 641195a..fb7d93c 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -1,5 +1,4 @@
 #include <ATen/native/cpu/TensorCompareKernel.h>
-#include <ATen/native/cpu/Loops.h>
 
 #include <numeric>
 #include <iterator>
@@ -102,28 +101,9 @@
   });
 }
 
-static void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "where_cpu", [&] {
-    if (condition_type == at::ScalarType::Byte) {
-      at::native::cpu_kernel(
-        iter,
-        [=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
-          return cond_val ? self_val : other_val;
-        });
-    } else {
-      at::native::cpu_kernel(
-        iter,
-        [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
-          return cond_val ? self_val : other_val;
-        });
-    }
-  });
-}
-
 } // anonymous namespace
 
 REGISTER_DISPATCH(max_kernel, &max_kernel_impl);
 REGISTER_DISPATCH(min_kernel, &min_kernel_impl);
-REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.h b/aten/src/ATen/native/cpu/TensorCompareKernel.h
index d1cb033..a23792d 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.h
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.h
@@ -3,7 +3,6 @@
 #include <ATen/ATen.h>
 #include <ATen/native/DispatchStub.h>
 #include <c10/util/Optional.h>
-#include <ATen/native/TensorIterator.h>
 
 namespace at { namespace native {
 
@@ -13,7 +12,4 @@
 DECLARE_DISPATCH(reduce_fn, max_kernel);
 DECLARE_DISPATCH(reduce_fn, min_kernel);
 
-using where_fn = void (*)(TensorIterator &, ScalarType);
-DECLARE_DISPATCH(where_fn, where_kernel);
-
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index aa396b2..0c1c4f4 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -255,7 +255,7 @@
 }
 
 // TODO (@zasdfgbnm): this function assume trivial 1d and no dynamic casting
-template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
+template<int nt, int vt, typename func_t, typename array_t>
 static void launch_kernel(int64_t N, const func_t& f, array_t data) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
@@ -281,9 +281,191 @@
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
-static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
-
 } // namespace modern
 
+template<typename func_t, int nargs=function_traits<func_t>::arity>
+struct needs_dynamic_casting {
+  static bool check(TensorIterator& iter) {
+    using traits = function_traits<func_t>;
+    if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
+      return true;
+    }
+    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
+  }
+};
+
+template<typename func_t>
+struct needs_dynamic_casting<func_t, 0> {
+  static bool check(TensorIterator& iter) {
+    using traits = function_traits<func_t>;
+    return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
+  }
+};
+
+template <typename func_t>
+void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using arg0_t = typename traits::result_type;
+  constexpr int ntensors = traits::arity + 1;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
+
+  at::detail::Array<char*, ntensors> data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  at::detail::Array<ScalarType, ntensors> dtypes;
+  for (int i = 0; i < ntensors; i++) {
+    dtypes[i] = iter.tensor(i).scalar_type();
+  }
+
+  int64_t numel = iter.numel();
+  if (iter.is_trivial_1d()) {
+    auto inner_strides = iter.get_inner_strides();
+    at::detail::Array<int, ntensors> strides;
+    for (int i = 0; i < ntensors; i++) {
+      strides[i] = inner_strides[i];
+    }
+
+    if (needs_dynamic_casting<func_t>::check(iter)) {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+        void* out = data[0] + strides[0] * idx;
+        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      });
+    } else if (iter.has_contiguous_first_dim()) {
+      modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
+    } else {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
+        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
+      });
+    }
+  } else {
+    auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
+    if (needs_dynamic_casting<func_t>::check(iter)) {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+        auto offsets = offset_calc.get(idx);
+        void* out = data[0] + offsets[0];
+        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      });
+    } else {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
+      });
+    }
+  }
+}
+
+template <typename func_t>
+void gpu_kernel(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      gpu_kernel(sub_iter, f);
+    }
+    return;
+  }
+
+  gpu_kernel_impl(iter, f);
+}
+
+template <typename func_t>
+void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
+
+  using traits = function_traits<func_t>;
+  static_assert(
+      traits::arity == 2,
+      "gpu_kernel_with_scalars only supports two input arguments");
+
+  if (iter.is_cpu_scalar(1)) {
+    using arg1_t = typename traits::template arg<0>::type;
+    using arg2_t = typename traits::template arg<1>::type;
+    auto a = iter.scalar_value<arg1_t>(1);
+    iter.remove_operand(1);
+    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
+      return f(a, b);
+    });
+  } else if (iter.is_cpu_scalar(2)) {
+    using arg1_t = typename traits::template arg<0>::type;
+    using arg2_t = typename traits::template arg<1>::type;
+    auto b = iter.scalar_value<arg2_t>(2);
+    iter.remove_operand(2);
+    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
+      return f(a, b);
+    });
+  } else {
+    gpu_kernel(iter, f);
+  }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using arg0_t = typename traits::result_type;
+
+
+  // Note:
+  // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
+  // of having an arbitrary number of tensors as arguments. This support was removed
+  // during the process of refactoring Loops.cuh to support vectorized memory access
+  // in PR #32777 (See also issue #31975). The removal of this support is soly because
+  // at that time, there is no operator using that functionality. If you need this
+  // functionality, feel free to add it back.
+  static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
+
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
+
+  char* data = (char*)iter.data_ptr(0);
+
+  int64_t numel = iter.numel();
+  if (iter.is_trivial_1d()) {
+    int stride = iter.get_inner_strides()[0];
+    legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      arg0_t* out = (arg0_t*)(data + stride * idx);
+      *out = f(idx);
+    });
+  } else {
+    auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
+    legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+      auto offsets = offset_calc.get(idx);
+      arg0_t* out = (arg0_t*)(data + offsets[0]);
+      *out = f(idx);
+    });
+  }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+  TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  // Split will change index, thus is not supported
+  // The caller should handle the split and pass in different func
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
+
+  gpu_kernel_with_index_impl(iter, f);
+}
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index 856487a..e2271f5 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -1,31 +1,3 @@
-
-#pragma once
-
-#include <ATen/detail/FunctionTraits.h>
-
-namespace at { namespace native { namespace modern { namespace detail {
-
-template<typename func_t, int remaining=function_traits<func_t>::arity-1>
-struct has_same_arg_types {
-  using traits = function_traits<func_t>;
-  static constexpr bool value = std::is_same<
-      typename traits::template arg<remaining>::type,
-      typename traits::template arg<remaining-1>::type
-    >::value && has_same_arg_types<func_t, remaining-1>::value;
-};
-
-template<typename func_t>
-struct has_same_arg_types<func_t, 0> {
-  static constexpr bool value = true;
-};
-
-template<typename func_t>
-struct has_same_arg_types<func_t, -1> {
-  static constexpr bool value = true;
-};
-
-}}}} // namespace at::native::modern::detail
-
 // Note:
 // CUDA and ROCm get diverged in this PR:
 //   https://github.com/pytorch/pytorch/pull/32383
@@ -37,195 +9,3 @@
 #else
 #include <ATen/native/cuda/ROCmLoops.cuh>
 #endif
-
-namespace at { namespace native {
-
-// `needs_dynamic_casting` compares the types expected by iterator
-// (i.e. dtypes of the operands) with the actual type of the arguments
-// of func_t
-template<typename func_t, int nargs=function_traits<func_t>::arity>
-struct needs_dynamic_casting {
-  static bool check(TensorIterator& iter) {
-    using traits = function_traits<func_t>;
-    if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
-      return true;
-    }
-    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
-  }
-};
-
-template<typename func_t>
-struct needs_dynamic_casting<func_t, 0> {
-  static bool check(TensorIterator& iter) {
-    using traits = function_traits<func_t>;
-    return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
-  }
-};
-
-template <typename func_t>
-void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
-  using traits = function_traits<func_t>;
-  using arg0_t = typename traits::result_type;
-  constexpr int ntensors = traits::arity + 1;
-
-  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
-  TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
-
-  at::detail::Array<char*, ntensors> data;
-  for (int i = 0; i < ntensors; i++) {
-    data[i] = (char*)iter.data_ptr(i);
-  }
-
-  at::detail::Array<ScalarType, ntensors> dtypes;
-  for (int i = 0; i < ntensors; i++) {
-    dtypes[i] = iter.tensor(i).scalar_type();
-  }
-
-  int64_t numel = iter.numel();
-  if (iter.is_trivial_1d()) {
-    auto inner_strides = iter.get_inner_strides();
-    at::detail::Array<int, ntensors> strides;
-    for (int i = 0; i < ntensors; i++) {
-      strides[i] = inner_strides[i];
-    }
-
-    if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-        void* out = data[0] + strides[0] * idx;
-        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
-        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
-      });
-    } else if (iter.has_contiguous_first_dim() && modern::detail::has_same_arg_types<func_t>::value) {
-      modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
-    } else {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
-        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
-      });
-    }
-  } else {
-    auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
-    if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-        auto offsets = offset_calc.get(idx);
-        void* out = data[0] + offsets[0];
-        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
-        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
-      });
-    } else {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-        auto offsets = offset_calc.get(idx);
-        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
-        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
-      });
-    }
-  }
-}
-
-template <typename func_t>
-void gpu_kernel(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
-
-  for (int arg = 0; arg < iter.ntensors(); arg++) {
-    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
-  }
-
-  if (iter.numel() == 0) {
-    return;
-  }
-
-  if (!iter.can_use_32bit_indexing()) {
-    for (auto& sub_iter : iter.with_32bit_indexing()) {
-      gpu_kernel(sub_iter, f);
-    }
-    return;
-  }
-
-  gpu_kernel_impl(iter, f);
-}
-
-template <typename func_t>
-void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
-  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
-
-  using traits = function_traits<func_t>;
-  static_assert(
-      traits::arity == 2,
-      "gpu_kernel_with_scalars only supports two input arguments");
-
-  if (iter.is_cpu_scalar(1)) {
-    using arg1_t = typename traits::template arg<0>::type;
-    using arg2_t = typename traits::template arg<1>::type;
-    auto a = iter.scalar_value<arg1_t>(1);
-    iter.remove_operand(1);
-    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
-      return f(a, b);
-    });
-  } else if (iter.is_cpu_scalar(2)) {
-    using arg1_t = typename traits::template arg<0>::type;
-    using arg2_t = typename traits::template arg<1>::type;
-    auto b = iter.scalar_value<arg2_t>(2);
-    iter.remove_operand(2);
-    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
-      return f(a, b);
-    });
-  } else {
-    gpu_kernel(iter, f);
-  }
-}
-
-template <typename func_t>
-void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
-  using traits = function_traits<func_t>;
-  using arg0_t = typename traits::result_type;
-
-
-  // Note:
-  // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
-  // of having an arbitrary number of tensors as arguments. This support was removed
-  // during the process of refactoring Loops.cuh to support vectorized memory access
-  // in PR #32777 (See also issue #31975). The removal of this support is soly because
-  // at that time, there is no operator using that functionality. If you need this
-  // functionality, feel free to add it back.
-  static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
-
-  TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
-
-  char* data = (char*)iter.data_ptr(0);
-
-  int64_t numel = iter.numel();
-  if (iter.is_trivial_1d()) {
-    int stride = iter.get_inner_strides()[0];
-    legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-      arg0_t* out = (arg0_t*)(data + stride * idx);
-      *out = f(idx);
-    });
-  } else {
-    auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
-    legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-      auto offsets = offset_calc.get(idx);
-      arg0_t* out = (arg0_t*)(data + offsets[0]);
-      *out = f(idx);
-    });
-  }
-}
-
-template <typename func_t>
-void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
-
-  TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
-
-  if (iter.numel() == 0) {
-    return;
-  }
-
-  // Split will change index, thus is not supported
-  // The caller should handle the split and pass in different func
-  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
-
-  gpu_kernel_with_index_impl(iter, f);
-}
-
-}} //namespace at::native
\ No newline at end of file
diff --git a/aten/src/ATen/native/cuda/ROCmLoops.cuh b/aten/src/ATen/native/cuda/ROCmLoops.cuh
index 19fc35a..a2a97c3 100644
--- a/aten/src/ATen/native/cuda/ROCmLoops.cuh
+++ b/aten/src/ATen/native/cuda/ROCmLoops.cuh
@@ -240,7 +240,7 @@
 }
 
 // TODO (@zasdfgbnm): this function assume trivial 1d and no dynamic casting
-template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
+template<int nt, int vt, typename func_t, typename array_t>
 static void launch_kernel(int64_t N, const func_t& f, array_t data) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
@@ -253,9 +253,191 @@
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
-static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
-
 } // namespace modern
 
+template<typename func_t, int nargs=function_traits<func_t>::arity>
+struct needs_dynamic_casting {
+  static bool check(TensorIterator& iter) {
+    using traits = function_traits<func_t>;
+    if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
+      return true;
+    }
+    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
+  }
+};
+
+template<typename func_t>
+struct needs_dynamic_casting<func_t, 0> {
+  static bool check(TensorIterator& iter) {
+    using traits = function_traits<func_t>;
+    return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
+  }
+};
+
+template <typename func_t>
+void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using arg0_t = typename traits::result_type;
+  constexpr int ntensors = traits::arity + 1;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
+
+  at::detail::Array<char*, ntensors> data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  at::detail::Array<ScalarType, ntensors> dtypes;
+  for (int i = 0; i < ntensors; i++) {
+    dtypes[i] = iter.tensor(i).scalar_type();
+  }
+
+  int64_t numel = iter.numel();
+  if (iter.is_trivial_1d()) {
+    auto inner_strides = iter.get_inner_strides();
+    at::detail::Array<int, ntensors> strides;
+    for (int i = 0; i < ntensors; i++) {
+      strides[i] = inner_strides[i];
+    }
+
+    if (needs_dynamic_casting<func_t>::check(iter)) {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+        void* out = data[0] + strides[0] * idx;
+        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      });
+    } else if (iter.has_contiguous_first_dim()) {
+      modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
+    } else {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
+        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
+      });
+    }
+  } else {
+    auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
+    if (needs_dynamic_casting<func_t>::check(iter)) {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+        auto offsets = offset_calc.get(idx);
+        void* out = data[0] + offsets[0];
+        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      });
+    } else {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
+      });
+    }
+  }
+}
+
+template <typename func_t>
+void gpu_kernel(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      gpu_kernel(sub_iter, f);
+    }
+    return;
+  }
+
+  gpu_kernel_impl(iter, f);
+}
+
+template <typename func_t>
+void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
+
+  using traits = function_traits<func_t>;
+  static_assert(
+      traits::arity == 2,
+      "gpu_kernel_with_scalars only supports two input arguments");
+
+  if (iter.is_cpu_scalar(1)) {
+    using arg1_t = typename traits::template arg<0>::type;
+    using arg2_t = typename traits::template arg<1>::type;
+    auto a = iter.scalar_value<arg1_t>(1);
+    iter.remove_operand(1);
+    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
+      return f(a, b);
+    });
+  } else if (iter.is_cpu_scalar(2)) {
+    using arg1_t = typename traits::template arg<0>::type;
+    using arg2_t = typename traits::template arg<1>::type;
+    auto b = iter.scalar_value<arg2_t>(2);
+    iter.remove_operand(2);
+    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
+      return f(a, b);
+    });
+  } else {
+    gpu_kernel(iter, f);
+  }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using arg0_t = typename traits::result_type;
+
+
+  // Note:
+  // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
+  // of having an arbitrary number of tensors as arguments. This support was removed
+  // during the process of refactoring Loops.cuh to support vectorized memory access
+  // in PR #32777 (See also issue #31975). The removal of this support is soly because
+  // at that time, there is no operator using that functionality. If you need this
+  // functionality, feel free to add it back.
+  static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
+
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
+
+  char* data = (char*)iter.data_ptr(0);
+
+  int64_t numel = iter.numel();
+  if (iter.is_trivial_1d()) {
+    int stride = iter.get_inner_strides()[0];
+    legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      arg0_t* out = (arg0_t*)(data + stride * idx);
+      *out = f(idx);
+    });
+  } else {
+    auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
+    legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+      auto offsets = offset_calc.get(idx);
+      arg0_t* out = (arg0_t*)(data + offsets[0]);
+      *out = f(idx);
+    });
+  }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+  TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  // Split will change index, thus is not supported
+  // The caller should handle the split and pass in different func
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
+
+  gpu_kernel_with_index_impl(iter, f);
+}
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu
index 0dbedcf..e1c3e73 100644
--- a/aten/src/ATen/native/cuda/TensorCompare.cu
+++ b/aten/src/ATen/native/cuda/TensorCompare.cu
@@ -1,38 +1,56 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/Dispatch.h>
-#include <ATen/native/DispatchStub.h>
-#include <ATen/native/cuda/Loops.cuh>
+
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 
+namespace {
+template <typename scalar_t>
+void where_cuda(
+    at::Tensor& ret,
+    const at::Tensor& condition,
+    const at::Tensor& self,
+    const at::Tensor& other) {
+  if (condition.scalar_type() == at::ScalarType::Byte) {
+    // Yes this name is repetitive, but the CPU version is called
+    // CPU_tensor_apply4 and we don't have a CPU namespace or directory.
+    at::cuda::CUDA_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
+        ret,
+        condition,
+        self,
+        other,
+        [] __device__(
+            scalar_t & ret_val,
+            const uint8_t& cond_val,
+            const scalar_t& self_val,
+            const scalar_t& other_val) {
+          ret_val = cond_val ? self_val : other_val;
+        });
+   } else {
+     at::cuda::CUDA_tensor_apply4<scalar_t, bool, scalar_t, scalar_t>(
+         ret,
+         condition,
+         self,
+         other,
+         [] __device__(
+             scalar_t & ret_val,
+             const bool& cond_val,
+             const scalar_t& self_val,
+             const scalar_t& other_val) {
+           ret_val = cond_val ? self_val : other_val;
+         });
+   }
+}
+} // namespace
 
 namespace at { namespace native {
-
-using where_fn = void (*)(TensorIterator &, ScalarType);
-DECLARE_DISPATCH(where_fn, where_kernel);
-
-namespace {
-
-void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.dtype(), "where_cuda", [&] {
-    if (condition_type == at::ScalarType::Byte) {
-      gpu_kernel(
-        iter,
-        [=] GPU_LAMBDA (uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
-          return cond_val ? self_val : other_val;
-        });
-    } else {
-      gpu_kernel(
-        iter,
-        [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
-          return cond_val ? self_val : other_val;
-        });
-    }
+Tensor _s_where_cuda(
+    const Tensor& condition,
+    const Tensor& self,
+    const Tensor& other) {
+  Tensor ret = at::empty(self.sizes(), self.options());
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, ret.scalar_type(), "where_cuda", [&] {
+    where_cuda<scalar_t>(ret, condition, self, other);
   });
+  return ret;
 }
-
-} // anonymous namespace
-
-
-REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
-
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 98c061b..d33a9d0 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2942,6 +2942,9 @@
 - func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
   use_c10_dispatcher: full
   variants: function
+  dispatch:
+    CPU: _s_where_cpu
+    CUDA: _s_where_cuda
 
 - func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor
   variants: function
diff --git a/aten/src/ATen/test/cuda_vectorized_test.cu b/aten/src/ATen/test/cuda_vectorized_test.cu
index ac57be0..96b70d9 100644
--- a/aten/src/ATen/test/cuda_vectorized_test.cu
+++ b/aten/src/ATen/test/cuda_vectorized_test.cu
@@ -1,7 +1,6 @@
 #include <gtest/gtest.h>
 #include <ATen/ATen.h>
 #include <ATen/native/cuda/MemoryAccess.cuh>
-#include <ATen/native/cuda/Loops.cuh>
 #include <ATen/cuda/CUDAContext.h>
 
 using namespace at::native::memory;
@@ -22,21 +21,6 @@
   }
 }
 
-TEST(TestLoops, HasSameArgTypes) {
-  // This is a compile-time unit test. If this file compiles without error,
-  // then the test passes and during runtime, we just need to return.
-  using namespace at::native::modern::detail;
-  using func1_t = int (*)(float, float);
-  using func2_t = int (*)(bool, float, float);
-  using func3_t = int (*)(float);
-  using func4_t = int (*)();
-  static_assert(has_same_arg_types<func1_t>::value, "func1_t has the same argument types");
-  static_assert(!has_same_arg_types<func2_t>::value, "func2_t does not have the same argument types");
-  static_assert(has_same_arg_types<func3_t>::value, "func3_t has the same argument types");
-  static_assert(has_same_arg_types<func4_t>::value, "func4_t has the same argument types");
-  return;
-}
-
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);