Move `where` CUDA implementation to TensorIterator (#32984)

Summary:
`where` is special because its arguments do not all have the same type, which violates an assumption made by the modern code path introduced in https://github.com/pytorch/pytorch/pull/32383. I migrate it to TensorIterator so that there is an operator exercising this case and we can verify it does not break. Currently, this case falls back to the legacy (not vectorized, not unrolled) code path. It should be supported by the modern path in the future, once I clean up `Loops.cuh`.
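
For reference, the shared entry point after this change looks roughly like the sketch below (simplified from the `TensorCompare.cpp` hunk in this diff; illustrative, not a drop-in): the op builds a `TensorIterator` without computing a common dtype and defers to a per-device `where_kernel` dispatch.

```cpp
// Sketch mirroring the _s_where added in this PR (simplified; see the diff for the real code).
// The condition operand stays uint8_t/bool while self/other are scalar_t, so the kernel
// lambda has mixed argument types -- the case the modern CUDA path does not vectorize yet.
Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
  Tensor ret = at::empty(self.sizes(), self.options());
  auto iter = at::TensorIterator();
  iter.set_check_mem_overlap(true);
  iter.add_output(ret);
  iter.add_input(condition);
  iter.add_input(self);
  iter.add_input(other);
  iter.dont_compute_common_dtype();  // keep the condition's dtype distinct from self/other
  iter.build();
  where_kernel(iter.device_type(), iter, condition.scalar_type());
  return ret;
}
```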

I also move some of the code shared between `CUDALoops.cuh` and `ROCmLoops.cuh` into `Loops.cuh`, so that the logic for checking whether `func_t` has the same argument types can be shared.
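
The shared piece is the `has_same_arg_types` trait that now lives in `Loops.cuh`; a minimal usage sketch (matching the compile-time test this PR adds to `cuda_vectorized_test.cu`, and assuming a CUDA build so the header compiles):

```cpp
#include <ATen/native/cuda/Loops.cuh>

// has_same_arg_types folds std::is_same over adjacent argument types of func_t, so the
// launcher can route functors with mixed argument types (like where's bool/uint8_t
// condition plus scalar_t values) to the legacy, non-vectorized path.
using same_t  = int (*)(float, float);        // all arguments share a type -> modern path eligible
using mixed_t = int (*)(bool, float, float);  // condition type differs -> legacy fallback
static_assert(at::native::modern::detail::has_same_arg_types<same_t>::value, "");
static_assert(!at::native::modern::detail::has_same_arg_types<mixed_t>::value, "");
```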
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32984

Differential Revision: D19825127

Pulled By: ngimel

fbshipit-source-id: bbf4682349d96b4480c4d657f3c18a3a67a9bf17
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 6c2517f..595f8ab 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -6,42 +6,12 @@
 #include <ATen/native/ReduceOpsUtils.h>
 #include <c10/util/Exception.h>
 #include <ATen/native/cpu/TensorCompareKernel.h>
-#include <ATen/native/cpu/Loops.h>
 #include <ATen/NamedTensorUtils.h>
 
-namespace {
-template <typename scalar_t>
-void where_cpu(
-    at::Tensor& ret,
-    const at::Tensor& condition,
-    const at::Tensor& self,
-    const at::Tensor& other) {
-  auto iter = at::TensorIterator();
-  iter.set_check_mem_overlap(true);
-  iter.add_output(ret);
-  iter.add_input(condition);
-  iter.add_input(self);
-  iter.add_input(other);
-  iter.dont_compute_common_dtype();
-  iter.build();
-  if (condition.scalar_type() == at::ScalarType::Byte) {
-    at::native::cpu_kernel(
-      iter,
-      [=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
-        return cond_val ? self_val : other_val;
-      });
-  } else {
-    at::native::cpu_kernel(
-      iter,
-      [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
-        return cond_val ? self_val : other_val;
-      });
-  }
-}
-} // namespace
 
 namespace at { namespace native {
 
+DEFINE_DISPATCH(where_kernel);
 DEFINE_DISPATCH(max_kernel);
 DEFINE_DISPATCH(min_kernel);
 
@@ -148,12 +118,18 @@
   return condition.nonzero_numpy();
 }
 
-Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& other) {
+Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
   TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
   Tensor ret = at::empty(self.sizes(), self.options());
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(ret.scalar_type(), "where_cpu", [&] {
-    where_cpu<scalar_t>(ret, condition, self, other);
-  });
+  auto iter = at::TensorIterator();
+  iter.set_check_mem_overlap(true);
+  iter.add_output(ret);
+  iter.add_input(condition);
+  iter.add_input(self);
+  iter.add_input(other);
+  iter.dont_compute_common_dtype();
+  iter.build();
+  where_kernel(iter.device_type(), iter, condition.scalar_type());
   return ret;
 }
 
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index fb7d93c..641195a 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -1,4 +1,5 @@
 #include <ATen/native/cpu/TensorCompareKernel.h>
+#include <ATen/native/cpu/Loops.h>
 
 #include <numeric>
 #include <iterator>
@@ -101,9 +102,28 @@
   });
 }
 
+static void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "where_cpu", [&] {
+    if (condition_type == at::ScalarType::Byte) {
+      at::native::cpu_kernel(
+        iter,
+        [=](uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
+          return cond_val ? self_val : other_val;
+        });
+    } else {
+      at::native::cpu_kernel(
+        iter,
+        [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
+          return cond_val ? self_val : other_val;
+        });
+    }
+  });
+}
+
 } // anonymous namespace
 
 REGISTER_DISPATCH(max_kernel, &max_kernel_impl);
 REGISTER_DISPATCH(min_kernel, &min_kernel_impl);
+REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.h b/aten/src/ATen/native/cpu/TensorCompareKernel.h
index a23792d..d1cb033 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.h
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.h
@@ -3,6 +3,7 @@
 #include <ATen/ATen.h>
 #include <ATen/native/DispatchStub.h>
 #include <c10/util/Optional.h>
+#include <ATen/native/TensorIterator.h>
 
 namespace at { namespace native {
 
@@ -12,4 +13,7 @@
 DECLARE_DISPATCH(reduce_fn, max_kernel);
 DECLARE_DISPATCH(reduce_fn, min_kernel);
 
+using where_fn = void (*)(TensorIterator &, ScalarType);
+DECLARE_DISPATCH(where_fn, where_kernel);
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index 0c1c4f4..aa396b2 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -255,7 +255,7 @@
 }
 
 // TODO (@zasdfgbnm): this function assume trivial 1d and no dynamic casting
-template<int nt, int vt, typename func_t, typename array_t>
+template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
 static void launch_kernel(int64_t N, const func_t& f, array_t data) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
@@ -281,191 +281,9 @@
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
+template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
+static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
+
 } // namespace modern
 
-template<typename func_t, int nargs=function_traits<func_t>::arity>
-struct needs_dynamic_casting {
-  static bool check(TensorIterator& iter) {
-    using traits = function_traits<func_t>;
-    if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
-      return true;
-    }
-    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
-  }
-};
-
-template<typename func_t>
-struct needs_dynamic_casting<func_t, 0> {
-  static bool check(TensorIterator& iter) {
-    using traits = function_traits<func_t>;
-    return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
-  }
-};
-
-template <typename func_t>
-void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
-  using traits = function_traits<func_t>;
-  using arg0_t = typename traits::result_type;
-  constexpr int ntensors = traits::arity + 1;
-
-  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
-  TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
-
-  at::detail::Array<char*, ntensors> data;
-  for (int i = 0; i < ntensors; i++) {
-    data[i] = (char*)iter.data_ptr(i);
-  }
-
-  at::detail::Array<ScalarType, ntensors> dtypes;
-  for (int i = 0; i < ntensors; i++) {
-    dtypes[i] = iter.tensor(i).scalar_type();
-  }
-
-  int64_t numel = iter.numel();
-  if (iter.is_trivial_1d()) {
-    auto inner_strides = iter.get_inner_strides();
-    at::detail::Array<int, ntensors> strides;
-    for (int i = 0; i < ntensors; i++) {
-      strides[i] = inner_strides[i];
-    }
-
-    if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-        void* out = data[0] + strides[0] * idx;
-        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
-        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
-      });
-    } else if (iter.has_contiguous_first_dim()) {
-      modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
-    } else {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
-        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
-      });
-    }
-  } else {
-    auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
-    if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-        auto offsets = offset_calc.get(idx);
-        void* out = data[0] + offsets[0];
-        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
-        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
-      });
-    } else {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-        auto offsets = offset_calc.get(idx);
-        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
-        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
-      });
-    }
-  }
-}
-
-template <typename func_t>
-void gpu_kernel(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
-
-  for (int arg = 0; arg < iter.ntensors(); arg++) {
-    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
-  }
-
-  if (iter.numel() == 0) {
-    return;
-  }
-
-  if (!iter.can_use_32bit_indexing()) {
-    for (auto& sub_iter : iter.with_32bit_indexing()) {
-      gpu_kernel(sub_iter, f);
-    }
-    return;
-  }
-
-  gpu_kernel_impl(iter, f);
-}
-
-template <typename func_t>
-void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
-  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
-
-  using traits = function_traits<func_t>;
-  static_assert(
-      traits::arity == 2,
-      "gpu_kernel_with_scalars only supports two input arguments");
-
-  if (iter.is_cpu_scalar(1)) {
-    using arg1_t = typename traits::template arg<0>::type;
-    using arg2_t = typename traits::template arg<1>::type;
-    auto a = iter.scalar_value<arg1_t>(1);
-    iter.remove_operand(1);
-    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
-      return f(a, b);
-    });
-  } else if (iter.is_cpu_scalar(2)) {
-    using arg1_t = typename traits::template arg<0>::type;
-    using arg2_t = typename traits::template arg<1>::type;
-    auto b = iter.scalar_value<arg2_t>(2);
-    iter.remove_operand(2);
-    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
-      return f(a, b);
-    });
-  } else {
-    gpu_kernel(iter, f);
-  }
-}
-
-template <typename func_t>
-void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
-  using traits = function_traits<func_t>;
-  using arg0_t = typename traits::result_type;
-
-
-  // Note:
-  // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
-  // of having an arbitrary number of tensors as arguments. This support was removed
-  // during the process of refactoring Loops.cuh to support vectorized memory access
-  // in PR #32777 (See also issue #31975). The removal of this support is soly because
-  // at that time, there is no operator using that functionality. If you need this
-  // functionality, feel free to add it back.
-  static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
-
-  TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
-
-  char* data = (char*)iter.data_ptr(0);
-
-  int64_t numel = iter.numel();
-  if (iter.is_trivial_1d()) {
-    int stride = iter.get_inner_strides()[0];
-    legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-      arg0_t* out = (arg0_t*)(data + stride * idx);
-      *out = f(idx);
-    });
-  } else {
-    auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
-    legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-      auto offsets = offset_calc.get(idx);
-      arg0_t* out = (arg0_t*)(data + offsets[0]);
-      *out = f(idx);
-    });
-  }
-}
-
-template <typename func_t>
-void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
-
-  TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
-
-  if (iter.numel() == 0) {
-    return;
-  }
-
-  // Split will change index, thus is not supported
-  // The caller should handle the split and pass in different func
-  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
-
-  gpu_kernel_with_index_impl(iter, f);
-}
-
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index e2271f5..856487a 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -1,3 +1,31 @@
+
+#pragma once
+
+#include <ATen/detail/FunctionTraits.h>
+
+namespace at { namespace native { namespace modern { namespace detail {
+
+template<typename func_t, int remaining=function_traits<func_t>::arity-1>
+struct has_same_arg_types {
+  using traits = function_traits<func_t>;
+  static constexpr bool value = std::is_same<
+      typename traits::template arg<remaining>::type,
+      typename traits::template arg<remaining-1>::type
+    >::value && has_same_arg_types<func_t, remaining-1>::value;
+};
+
+template<typename func_t>
+struct has_same_arg_types<func_t, 0> {
+  static constexpr bool value = true;
+};
+
+template<typename func_t>
+struct has_same_arg_types<func_t, -1> {
+  static constexpr bool value = true;
+};
+
+}}}} // namespace at::native::modern::detail
+
 // Note:
 // CUDA and ROCm get diverged in this PR:
 //   https://github.com/pytorch/pytorch/pull/32383
@@ -9,3 +37,195 @@
 #else
 #include <ATen/native/cuda/ROCmLoops.cuh>
 #endif
+
+namespace at { namespace native {
+
+// `needs_dynamic_casting` compares the types expected by iterator
+// (i.e. dtypes of the operands) with the actual type of the arguments
+// of func_t
+template<typename func_t, int nargs=function_traits<func_t>::arity>
+struct needs_dynamic_casting {
+  static bool check(TensorIterator& iter) {
+    using traits = function_traits<func_t>;
+    if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
+      return true;
+    }
+    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
+  }
+};
+
+template<typename func_t>
+struct needs_dynamic_casting<func_t, 0> {
+  static bool check(TensorIterator& iter) {
+    using traits = function_traits<func_t>;
+    return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
+  }
+};
+
+template <typename func_t>
+void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using arg0_t = typename traits::result_type;
+  constexpr int ntensors = traits::arity + 1;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
+
+  at::detail::Array<char*, ntensors> data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  at::detail::Array<ScalarType, ntensors> dtypes;
+  for (int i = 0; i < ntensors; i++) {
+    dtypes[i] = iter.tensor(i).scalar_type();
+  }
+
+  int64_t numel = iter.numel();
+  if (iter.is_trivial_1d()) {
+    auto inner_strides = iter.get_inner_strides();
+    at::detail::Array<int, ntensors> strides;
+    for (int i = 0; i < ntensors; i++) {
+      strides[i] = inner_strides[i];
+    }
+
+    if (needs_dynamic_casting<func_t>::check(iter)) {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+        void* out = data[0] + strides[0] * idx;
+        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      });
+    } else if (iter.has_contiguous_first_dim() && modern::detail::has_same_arg_types<func_t>::value) {
+      modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
+    } else {
+      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
+        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
+      });
+    }
+  } else {
+    auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
+    if (needs_dynamic_casting<func_t>::check(iter)) {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+        auto offsets = offset_calc.get(idx);
+        void* out = data[0] + offsets[0];
+        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
+        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
+      });
+    } else {
+      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+        auto offsets = offset_calc.get(idx);
+        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
+      });
+    }
+  }
+}
+
+template <typename func_t>
+void gpu_kernel(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      gpu_kernel(sub_iter, f);
+    }
+    return;
+  }
+
+  gpu_kernel_impl(iter, f);
+}
+
+template <typename func_t>
+void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
+
+  using traits = function_traits<func_t>;
+  static_assert(
+      traits::arity == 2,
+      "gpu_kernel_with_scalars only supports two input arguments");
+
+  if (iter.is_cpu_scalar(1)) {
+    using arg1_t = typename traits::template arg<0>::type;
+    using arg2_t = typename traits::template arg<1>::type;
+    auto a = iter.scalar_value<arg1_t>(1);
+    iter.remove_operand(1);
+    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
+      return f(a, b);
+    });
+  } else if (iter.is_cpu_scalar(2)) {
+    using arg1_t = typename traits::template arg<0>::type;
+    using arg2_t = typename traits::template arg<1>::type;
+    auto b = iter.scalar_value<arg2_t>(2);
+    iter.remove_operand(2);
+    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
+      return f(a, b);
+    });
+  } else {
+    gpu_kernel(iter, f);
+  }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using arg0_t = typename traits::result_type;
+
+
+  // Note:
+  // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
+  // for having an arbitrary number of tensors as arguments. This support was removed
+  // during the process of refactoring Loops.cuh to support vectorized memory access
+  // in PR #32777 (see also issue #31975). The removal of this support is solely because,
+  // at that time, there was no operator using that functionality. If you need this
+  // functionality, feel free to add it back.
+  static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
+
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
+
+  char* data = (char*)iter.data_ptr(0);
+
+  int64_t numel = iter.numel();
+  if (iter.is_trivial_1d()) {
+    int stride = iter.get_inner_strides()[0];
+    legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
+      arg0_t* out = (arg0_t*)(data + stride * idx);
+      *out = f(idx);
+    });
+  } else {
+    auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
+    legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
+      auto offsets = offset_calc.get(idx);
+      arg0_t* out = (arg0_t*)(data + offsets[0]);
+      *out = f(idx);
+    });
+  }
+}
+
+template <typename func_t>
+void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+  TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  // Split will change index, thus is not supported
+  // The caller should handle the split and pass in different func
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
+
+  gpu_kernel_with_index_impl(iter, f);
+}
+
+}} //namespace at::native
\ No newline at end of file
diff --git a/aten/src/ATen/native/cuda/ROCmLoops.cuh b/aten/src/ATen/native/cuda/ROCmLoops.cuh
index a2a97c3..19fc35a 100644
--- a/aten/src/ATen/native/cuda/ROCmLoops.cuh
+++ b/aten/src/ATen/native/cuda/ROCmLoops.cuh
@@ -240,7 +240,7 @@
 }
 
 // TODO (@zasdfgbnm): this function assume trivial 1d and no dynamic casting
-template<int nt, int vt, typename func_t, typename array_t>
+template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<detail::has_same_arg_types<func_t>::value, int> = 0>
 static void launch_kernel(int64_t N, const func_t& f, array_t data) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
@@ -253,191 +253,9 @@
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
+template<int nt, int vt, typename func_t, typename array_t, std::enable_if_t<!detail::has_same_arg_types<func_t>::value, int> = 0>
+static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
+
 } // namespace modern
 
-template<typename func_t, int nargs=function_traits<func_t>::arity>
-struct needs_dynamic_casting {
-  static bool check(TensorIterator& iter) {
-    using traits = function_traits<func_t>;
-    if (iter.dtype(nargs) != c10::impl::CPPTypeToScalarType<typename traits::template arg<nargs - 1>::type>::value) {
-      return true;
-    }
-    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
-  }
-};
-
-template<typename func_t>
-struct needs_dynamic_casting<func_t, 0> {
-  static bool check(TensorIterator& iter) {
-    using traits = function_traits<func_t>;
-    return iter.dtype(0) != c10::impl::CPPTypeToScalarType<typename traits::result_type>::value;
-  }
-};
-
-template <typename func_t>
-void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
-  using traits = function_traits<func_t>;
-  using arg0_t = typename traits::result_type;
-  constexpr int ntensors = traits::arity + 1;
-
-  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
-  TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
-
-  at::detail::Array<char*, ntensors> data;
-  for (int i = 0; i < ntensors; i++) {
-    data[i] = (char*)iter.data_ptr(i);
-  }
-
-  at::detail::Array<ScalarType, ntensors> dtypes;
-  for (int i = 0; i < ntensors; i++) {
-    dtypes[i] = iter.tensor(i).scalar_type();
-  }
-
-  int64_t numel = iter.numel();
-  if (iter.is_trivial_1d()) {
-    auto inner_strides = iter.get_inner_strides();
-    at::detail::Array<int, ntensors> strides;
-    for (int i = 0; i < ntensors; i++) {
-      strides[i] = inner_strides[i];
-    }
-
-    if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-        void* out = data[0] + strides[0] * idx;
-        arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
-        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
-      });
-    } else if (iter.has_contiguous_first_dim()) {
-      modern::launch_kernel<C10_WARP_SIZE * 2, 4>(numel, f, data);
-    } else {
-      legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-        arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
-        *out = legacy::invoke(f, &data.data[1], &strides.data[1], idx);
-      });
-    }
-  } else {
-    auto offset_calc = legacy::make_offset_calculator<traits::arity + 1>(iter);
-    if (needs_dynamic_casting<func_t>::check(iter)) {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-        auto offsets = offset_calc.get(idx);
-        void* out = data[0] + offsets[0];
-        arg0_t result = legacy::invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
-        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
-      });
-    } else {
-      legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-        auto offsets = offset_calc.get(idx);
-        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
-        *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
-      });
-    }
-  }
-}
-
-template <typename func_t>
-void gpu_kernel(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
-
-  for (int arg = 0; arg < iter.ntensors(); arg++) {
-    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
-  }
-
-  if (iter.numel() == 0) {
-    return;
-  }
-
-  if (!iter.can_use_32bit_indexing()) {
-    for (auto& sub_iter : iter.with_32bit_indexing()) {
-      gpu_kernel(sub_iter, f);
-    }
-    return;
-  }
-
-  gpu_kernel_impl(iter, f);
-}
-
-template <typename func_t>
-void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
-  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
-
-  using traits = function_traits<func_t>;
-  static_assert(
-      traits::arity == 2,
-      "gpu_kernel_with_scalars only supports two input arguments");
-
-  if (iter.is_cpu_scalar(1)) {
-    using arg1_t = typename traits::template arg<0>::type;
-    using arg2_t = typename traits::template arg<1>::type;
-    auto a = iter.scalar_value<arg1_t>(1);
-    iter.remove_operand(1);
-    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
-      return f(a, b);
-    });
-  } else if (iter.is_cpu_scalar(2)) {
-    using arg1_t = typename traits::template arg<0>::type;
-    using arg2_t = typename traits::template arg<1>::type;
-    auto b = iter.scalar_value<arg2_t>(2);
-    iter.remove_operand(2);
-    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
-      return f(a, b);
-    });
-  } else {
-    gpu_kernel(iter, f);
-  }
-}
-
-template <typename func_t>
-void gpu_kernel_with_index_impl(TensorIterator& iter, const func_t& f) {
-  using traits = function_traits<func_t>;
-  using arg0_t = typename traits::result_type;
-
-
-  // Note:
-  // `gpu_kernel_with_index` was originally implemented in PR #28175 with support
-  // of having an arbitrary number of tensors as arguments. This support was removed
-  // during the process of refactoring Loops.cuh to support vectorized memory access
-  // in PR #32777 (See also issue #31975). The removal of this support is soly because
-  // at that time, there is no operator using that functionality. If you need this
-  // functionality, feel free to add it back.
-  static_assert(traits::arity == 1, "Functor for gpu_kernel_with_index can only have one argument which is the index");
-
-  TORCH_INTERNAL_ASSERT(iter.ntensors() == 1);
-
-  char* data = (char*)iter.data_ptr(0);
-
-  int64_t numel = iter.numel();
-  if (iter.is_trivial_1d()) {
-    int stride = iter.get_inner_strides()[0];
-    legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
-      arg0_t* out = (arg0_t*)(data + stride * idx);
-      *out = f(idx);
-    });
-  } else {
-    auto offset_calc = legacy::make_offset_calculator<traits::arity>(iter);
-    legacy::launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
-      auto offsets = offset_calc.get(idx);
-      arg0_t* out = (arg0_t*)(data + offsets[0]);
-      *out = f(idx);
-    });
-  }
-}
-
-template <typename func_t>
-void gpu_kernel_with_index(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
-
-  TORCH_INTERNAL_ASSERT(iter.device(0).is_cuda(), "gpu_kernel_with_index only support cuda tensor.");
-
-  if (iter.numel() == 0) {
-    return;
-  }
-
-  // Split will change index, thus is not supported
-  // The caller should handle the split and pass in different func
-  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing(), "gpu_kernel_with_index only support 32-bit indexing.");
-
-  gpu_kernel_with_index_impl(iter, f);
-}
-
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu
index e1c3e73..0dbedcf 100644
--- a/aten/src/ATen/native/cuda/TensorCompare.cu
+++ b/aten/src/ATen/native/cuda/TensorCompare.cu
@@ -1,56 +1,38 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/Dispatch.h>
-
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/cuda/Loops.cuh>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 
-namespace {
-template <typename scalar_t>
-void where_cuda(
-    at::Tensor& ret,
-    const at::Tensor& condition,
-    const at::Tensor& self,
-    const at::Tensor& other) {
-  if (condition.scalar_type() == at::ScalarType::Byte) {
-    // Yes this name is repetitive, but the CPU version is called
-    // CPU_tensor_apply4 and we don't have a CPU namespace or directory.
-    at::cuda::CUDA_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
-        ret,
-        condition,
-        self,
-        other,
-        [] __device__(
-            scalar_t & ret_val,
-            const uint8_t& cond_val,
-            const scalar_t& self_val,
-            const scalar_t& other_val) {
-          ret_val = cond_val ? self_val : other_val;
-        });
-   } else {
-     at::cuda::CUDA_tensor_apply4<scalar_t, bool, scalar_t, scalar_t>(
-         ret,
-         condition,
-         self,
-         other,
-         [] __device__(
-             scalar_t & ret_val,
-             const bool& cond_val,
-             const scalar_t& self_val,
-             const scalar_t& other_val) {
-           ret_val = cond_val ? self_val : other_val;
-         });
-   }
-}
-} // namespace
 
 namespace at { namespace native {
-Tensor _s_where_cuda(
-    const Tensor& condition,
-    const Tensor& self,
-    const Tensor& other) {
-  Tensor ret = at::empty(self.sizes(), self.options());
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, ret.scalar_type(), "where_cuda", [&] {
-    where_cuda<scalar_t>(ret, condition, self, other);
+
+using where_fn = void (*)(TensorIterator &, ScalarType);
+DECLARE_DISPATCH(where_fn, where_kernel);
+
+namespace {
+
+void where_kernel_impl(TensorIterator &iter, ScalarType condition_type) {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.dtype(), "where_cuda", [&] {
+    if (condition_type == at::ScalarType::Byte) {
+      gpu_kernel(
+        iter,
+        [=] GPU_LAMBDA (uint8_t cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
+          return cond_val ? self_val : other_val;
+        });
+    } else {
+      gpu_kernel(
+        iter,
+        [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
+          return cond_val ? self_val : other_val;
+        });
+    }
   });
-  return ret;
 }
+
+} // anonymous namespace
+
+
+REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 88123d3..e2f685d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2942,9 +2942,6 @@
 - func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
   use_c10_dispatcher: full
   variants: function
-  dispatch:
-    CPU: _s_where_cpu
-    CUDA: _s_where_cuda
 
 - func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor
   variants: function
diff --git a/aten/src/ATen/test/cuda_vectorized_test.cu b/aten/src/ATen/test/cuda_vectorized_test.cu
index 96b70d9..ac57be0 100644
--- a/aten/src/ATen/test/cuda_vectorized_test.cu
+++ b/aten/src/ATen/test/cuda_vectorized_test.cu
@@ -1,6 +1,7 @@
 #include <gtest/gtest.h>
 #include <ATen/ATen.h>
 #include <ATen/native/cuda/MemoryAccess.cuh>
+#include <ATen/native/cuda/Loops.cuh>
 #include <ATen/cuda/CUDAContext.h>
 
 using namespace at::native::memory;
@@ -21,6 +22,21 @@
   }
 }
 
+TEST(TestLoops, HasSameArgTypes) {
+  // This is a compile-time unit test. If this file compiles without error,
+  // then the test passes and during runtime, we just need to return.
+  using namespace at::native::modern::detail;
+  using func1_t = int (*)(float, float);
+  using func2_t = int (*)(bool, float, float);
+  using func3_t = int (*)(float);
+  using func4_t = int (*)();
+  static_assert(has_same_arg_types<func1_t>::value, "func1_t has the same argument types");
+  static_assert(!has_same_arg_types<func2_t>::value, "func2_t does not have the same argument types");
+  static_assert(has_same_arg_types<func3_t>::value, "func3_t has the same argument types");
+  static_assert(has_same_arg_types<func4_t>::value, "func4_t has the same argument types");
+  return;
+}
+
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);