Support `dtype` kwarg in `_foreach_norm` (#125665)

Fixes #125040

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125665
Approved by: https://github.com/janeyx99
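
For context, the change threads an optional `ScalarType` through `_foreach_norm`: the slow path forwards it to `at::linalg_vector_norm`, while the CUDA fast path dispatches its multi-tensor kernels on both the input dtype and the requested output dtype and validates the combination in `check_foreach_norm_dtype`. Below is a minimal usage sketch of the new keyword, assuming a build that contains this change; shapes and values are arbitrary.

```python
import torch

# Per-tensor L2 norms of bfloat16 tensors, with the results returned in
# float32 via the new `dtype` keyword (same semantics as the dtype argument
# of linalg.vector_norm, which the slow path forwards to).
tensors = [torch.randn(3, 4, dtype=torch.bfloat16) for _ in range(5)]
norms = torch._foreach_norm(tensors, ord=2, dtype=torch.float32)
assert all(n.ndim == 0 and n.dtype == torch.float32 for n in norms)

# On CUDA, ord in (1, 2, inf) with floating-point inputs takes the fused
# multi-tensor kernel; other ord values or int/complex inputs fall back to
# the per-tensor at::linalg_vector_norm slow path.
# Invalid dtype combinations are rejected up front, e.g.:
#   torch._foreach_norm(tensors, ord=2, dtype=torch.int64)  # raises RuntimeError
```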
diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp
index 790c7a5..34c71a8 100644
--- a/aten/src/ATen/native/ForeachOpsKernels.cpp
+++ b/aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -438,11 +438,12 @@
 
 std::vector<Tensor> foreach_tensor_norm_slow(
     TensorList tensors,
-    const Scalar& ord) {
+    const Scalar& ord,
+    c10::optional<ScalarType> dtype) {
   check_foreach_api_restrictions(tensors);
   std::vector<Tensor> result;
   for (const auto& t : tensors) {
-    result.emplace_back(at::linalg_vector_norm(t, ord));
+    result.emplace_back(at::linalg_vector_norm(t, ord, {}, false, dtype));
   }
   return result;
 }
diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
index eed9656..885c5d0 100644
--- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
@@ -1,3 +1,6 @@
+#include <c10/core/ScalarType.h>
+#include <c10/util/irange.h>
+#include <limits>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
@@ -44,15 +47,16 @@
 template <
     typename T,
     NormType norm_type,
+    typename out_t,
     int depth = 1,
     int r_args_depth = 1,
     int res_arg_index = 0>
 struct LpNormFunctor {
-  using opmath_t = typename at::opmath_type<T>;
+  using out_opmath_t = typename at::opmath_type<out_t>;
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,
-      opmath_t* output_per_tensor,
+      out_opmath_t* output_per_tensor_ptr,
       const int max_chunks_per_tensor) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -62,11 +66,11 @@
     x += chunk_idx * chunk_size;
     n -= chunk_idx * chunk_size;
 
-    __shared__ opmath_t s_vals[512];
-    opmath_t vals[kILP];
+    __shared__ out_opmath_t s_vals[512];
+    out_opmath_t vals[kILP];
     T r_x[kILP];
     for (int64_t i = 0; i < kILP; i++) {
-      vals[i] = opmath_t(0);
+      vals[i] = out_opmath_t(0);
       r_x[i] = T(0);
     }
 
@@ -78,7 +82,7 @@
         load_store(r_x, x, 0, i_start);
 #pragma unroll
         for (int ii = 0; ii < kILP; ii++) {
-          opmath_t next = static_cast<opmath_t>(r_x[ii]);
+          const auto next = static_cast<out_opmath_t>(r_x[ii]);
           if constexpr (norm_type == NormType::LInf) {
             vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
           } else {
@@ -93,7 +97,7 @@
         for (int ii = 0; ii < kILP; ii++) {
           int i = i_start + threadIdx.x + ii * blockDim.x;
           if (i < n && i < chunk_size) {
-            opmath_t next = static_cast<opmath_t>(x[i]);
+            const auto next = static_cast<out_opmath_t>(x[i]);
             if constexpr (norm_type == NormType::LInf) {
               vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
             } else {
@@ -104,7 +108,7 @@
       }
     }
 
-    auto val = opmath_t(0);
+    auto val = out_opmath_t(0);
     for (int i = 0; i < kILP; i++) {
       if constexpr (norm_type == NormType::LInf) {
         val = max_propagate_nan(val, vals[i]);
@@ -117,7 +121,7 @@
         : at::native::cuda_utils::BlockReduceMax(val, s_vals);
 
     if (threadIdx.x == 0) {
-      output_per_tensor
+      output_per_tensor_ptr
           [(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor +
            chunk_idx] = final_val;
     }
@@ -127,16 +131,17 @@
 template <
     typename T,
     NormType norm_type,
-    typename opmath_t = at::opmath_type<T>>
+    typename out_t,
+    typename out_opmath_t = at::opmath_type<out_t>>
 __global__ void lpnorm_cleanup(
-    const opmath_t* output_per_tensor,
+    const out_opmath_t* output_per_tensor,
     TensorListAddresses addr_struct,
     int max_chunks_per_tensor) {
-  __shared__ opmath_t vals[512];
+  __shared__ out_opmath_t vals[512];
 
-  const opmath_t* output_this_tensor =
+  const out_opmath_t* output_this_tensor =
       output_per_tensor + blockIdx.x * max_chunks_per_tensor;
-  opmath_t val = 0;
+  out_opmath_t val = 0;
   for (size_t i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
     if constexpr (norm_type == NormType::LInf) {
       val = max_propagate_nan(val, output_this_tensor[i]);
@@ -144,33 +149,85 @@
       val += output_this_tensor[i];
     }
   }
-  opmath_t final_val = norm_type == NormType::L1 || norm_type == NormType::L2
-      ? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
+  out_opmath_t final_val =
+      norm_type == NormType::L1 || norm_type == NormType::L2
+      ? at::native::cuda_utils::BlockReduceSum<out_opmath_t>(val, vals)
       : at::native::cuda_utils::BlockReduceMax(val, vals);
   if (threadIdx.x == 0) {
-    *(T*)addr_struct.addresses[blockIdx.x] =
+    *(out_t*)addr_struct.addresses[blockIdx.x] =
         norm_type == NormType::L1 || norm_type == NormType::LInf
         ? final_val
         : ::sqrt(final_val);
   }
 }
 
+namespace {
+inline void check_foreach_norm_dtype(
+    optional<ScalarType> opt_dtype,
+    ScalarType self_dtype,
+    const char* const name) {
+  if (opt_dtype.has_value()) {
+    auto dtype = opt_dtype.value();
+    TORCH_CHECK(
+        isFloatingType(dtype) || isComplexType(dtype),
+        name,
+        ": dtype should"
+        " be floating point or complex, but got ",
+        dtype);
+    TORCH_CHECK(
+        isComplexType(self_dtype) == isComplexType(dtype),
+        name,
+        ": dtype should be ",
+        isComplexType(self_dtype) ? "complex" : "real",
+        " for ",
+        isComplexType(self_dtype) ? "complex" : "real",
+        " inputs, but got ",
+        dtype);
+    TORCH_CHECK(
+        promoteTypes(self_dtype, dtype) == dtype,
+        name,
+        ": the dtype of the input ",
+        "(",
+        self_dtype,
+        ") should be convertible ",
+        "without narrowing to the specified dtype (",
+        dtype,
+        ")");
+  }
+}
+} // anonymous namespace
+
+#define AT_DISPATCH_OUT_DTYPES(TYPE, NAME, ...)             \
+  AT_DISPATCH_SWITCH(                                       \
+      TYPE,                                                 \
+      NAME,                                                 \
+      AT_PRIVATE_CASE_TYPE_USING_HINT(                      \
+          at::ScalarType::Double, out_t, __VA_ARGS__)       \
+          AT_PRIVATE_CASE_TYPE_USING_HINT(                  \
+              at::ScalarType::Float, out_t, __VA_ARGS__)    \
+              AT_PRIVATE_CASE_TYPE_USING_HINT(              \
+                  at::ScalarType::Half, out_t, __VA_ARGS__) \
+                  AT_PRIVATE_CASE_TYPE_USING_HINT(          \
+                      at::ScalarType::BFloat16, out_t, __VA_ARGS__))
+
 // note(mkozuki): Why excluding Int and Complex from fast path
 // - Int: at::norm does not support.
 // - Complex: __shfl_down_sync does not support complex and foreach does not
 // support functions whose inputs dtypes and output dtype are different.
 std::vector<Tensor> foreach_tensor_norm_cuda(
     TensorList tensors,
-    const Scalar& ord) {
-  double p;
-  if (ord.isIntegral(false)) {
-    p = ord.to<int64_t>();
-  } else if (ord.isFloatingPoint()) {
-    p = ord.to<double>();
-  } else {
-    TORCH_CHECK(
-        false, "foreach_tensor_norm_cuda expects ord to be integer or float");
-  }
+    const Scalar& ord,
+    c10::optional<ScalarType> dtype) {
+  const auto p = [&]() -> double {
+    if (ord.isIntegral(false)) {
+      return ord.to<int64_t>();
+    } else if (ord.isFloatingPoint()) {
+      return ord.to<double>();
+    } else {
+      TORCH_CHECK(
+          false, "foreach_tensor_norm_cuda expects ord to be integer or float");
+    }
+  }();
   check_foreach_api_restrictions(tensors);
   const bool has_int_or_complex =
       std::any_of(tensors.begin(), tensors.end(), [](const auto& t) {
@@ -181,8 +238,10 @@
   if (!can_use_fast_route(tensors) || has_int_or_complex ||
       !(p == static_cast<double>(1) || p == static_cast<double>(2) ||
         p == std::numeric_limits<double>::infinity())) {
-    return foreach_tensor_norm_slow(tensors, ord);
+    return foreach_tensor_norm_slow(tensors, ord, dtype);
   }
+  check_foreach_norm_dtype(
+      dtype, tensors[0].scalar_type(), "_foreach_tensor_norm_cuda");
 
   const size_t ntensors = tensors.size();
   int max_chunks_per_tensor = -1;
@@ -195,143 +254,101 @@
     }
   }
   const auto options = tensors[0].options();
+  const ScalarType output_dtype =
+      dtype.has_value() ? dtype.value() : tensors[0].scalar_type();
+  const ScalarType output_per_tensor_dtype = toOpMathType(output_dtype);
   auto output_per_tensor = at::zeros(
       {static_cast<int64_t>(ntensors) * max_chunks_per_tensor},
-      options.dtype(toOpMathType(tensors[0].scalar_type())));
+      options.dtype(output_per_tensor_dtype));
 
   std::vector<at::Tensor> vec_res;
   vec_res.reserve(ntensors);
+  const auto res_option = options.dtype(output_dtype);
   for (const auto i : c10::irange(ntensors)) {
-    vec_res.push_back(at::empty({}, options));
+    vec_res.push_back(at::empty({}, res_option));
   }
 
   auto tensor_lists = std::vector<std::vector<Tensor>>{tensors.vec()};
-  if (p == static_cast<double>(1)) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        kHalf,
-        kBFloat16,
-        tensor_lists[0][0].scalar_type(),
-        "foreach_tensor_norm_cuda",
-        [&]() {
-          using opmath_t = typename at::opmath_type<scalar_t>;
-          multi_tensor_apply<1>(
-              tensor_lists,
-              LpNormFunctor<scalar_t, NormType::L1>(),
-              output_per_tensor.mutable_data_ptr<opmath_t>(),
-              max_chunks_per_tensor);
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
-          const at::cuda::OptionalCUDAGuard device_guard(
-              device_of(output_per_tensor));
-          auto stream = at::cuda::getCurrentCUDAStream();
 
-          const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
-          for (const auto i : c10::irange(num_kernels)) {
-            const size_t num_tensors_this_kernel =
-                (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
-                ? MAX_TENSORS_PER_KERNEL
-                : (ntensors % MAX_TENSORS_PER_KERNEL);
-
-            TensorListAddresses addr_struct;
-            for (const auto j : c10::irange(num_tensors_this_kernel)) {
-              addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
-                                             .mutable_data_ptr<scalar_t>();
-            }
-
-            lpnorm_cleanup<scalar_t, NormType::L1>
-                <<<num_tensors_this_kernel, 512, 0, stream>>>(
-                    output_per_tensor.const_data_ptr<opmath_t>() +
-                        i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
-                    addr_struct,
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kHalf,
+      c10::kBFloat16,
+      tensor_lists[0][0].scalar_type(),
+      "foreach_tensor_norm_cuda_scalar_type",
+      [&]() {
+        // using opmath_t = typename at::opmath_type<scalar_t>;
+        AT_DISPATCH_OUT_DTYPES(
+            output_dtype, "foreach_tensor_norm_cuda_out_dtype", [&]() {
+              using out_opmath_t = typename at::opmath_type<out_t>;
+              if (p == static_cast<double>(1)) {
+                multi_tensor_apply<1>(
+                    tensor_lists,
+                    LpNormFunctor<scalar_t, NormType::L1, out_t>(),
+                    output_per_tensor.mutable_data_ptr<out_opmath_t>(),
                     max_chunks_per_tensor);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
-          }
-        });
-  } else if (p == static_cast<double>(2)) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        kHalf,
-        kBFloat16,
-        tensor_lists[0][0].scalar_type(),
-        "foreach_tensor_norm_cuda",
-        [&]() {
-          using opmath_t = typename at::opmath_type<scalar_t>;
-          multi_tensor_apply<1>(
-              tensor_lists,
-              LpNormFunctor<scalar_t, NormType::L2>(),
-              output_per_tensor.mutable_data_ptr<opmath_t>(),
-              max_chunks_per_tensor);
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
-          const at::cuda::OptionalCUDAGuard device_guard(
-              device_of(output_per_tensor));
-          auto stream = at::cuda::getCurrentCUDAStream();
-
-          const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
-          for (const auto i : c10::irange(num_kernels)) {
-            const size_t num_tensors_this_kernel =
-                (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
-                ? MAX_TENSORS_PER_KERNEL
-                : (ntensors % MAX_TENSORS_PER_KERNEL);
-
-            TensorListAddresses addr_struct;
-            for (const auto j : c10::irange(num_tensors_this_kernel)) {
-              addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
-                                             .mutable_data_ptr<scalar_t>();
-            }
-
-            lpnorm_cleanup<scalar_t, NormType::L2>
-                <<<num_tensors_this_kernel, 512, 0, stream>>>(
-                    output_per_tensor.const_data_ptr<opmath_t>() +
-                        i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
-                    addr_struct,
+              } else if (p == static_cast<double>(2)) {
+                multi_tensor_apply<1>(
+                    tensor_lists,
+                    LpNormFunctor<scalar_t, NormType::L2, out_t>(),
+                    output_per_tensor.mutable_data_ptr<out_opmath_t>(),
                     max_chunks_per_tensor);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
-          }
-        });
-  } else if (p == std::numeric_limits<double>::infinity()) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        kHalf,
-        kBFloat16,
-        tensor_lists[0][0].scalar_type(),
-        "foreach_tensor_norm_cuda",
-        [&]() {
-          using opmath_t = typename at::opmath_type<scalar_t>;
-          multi_tensor_apply<1>(
-              tensor_lists,
-              LpNormFunctor<scalar_t, NormType::LInf>(),
-              output_per_tensor.mutable_data_ptr<opmath_t>(),
-              max_chunks_per_tensor);
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
-          const at::cuda::OptionalCUDAGuard device_guard(
-              device_of(output_per_tensor));
-          auto stream = at::cuda::getCurrentCUDAStream();
-
-          const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
-          for (const auto i : c10::irange(num_kernels)) {
-            const size_t num_tensors_this_kernel =
-                (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
-                ? MAX_TENSORS_PER_KERNEL
-                : (ntensors % MAX_TENSORS_PER_KERNEL);
-
-            TensorListAddresses addr_struct;
-            for (const auto j : c10::irange(num_tensors_this_kernel)) {
-              addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
-                                             .mutable_data_ptr<scalar_t>();
-            }
-
-            lpnorm_cleanup<scalar_t, NormType::LInf>
-                <<<num_tensors_this_kernel, 512, 0, stream>>>(
-                    output_per_tensor.const_data_ptr<opmath_t>() +
-                        i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
-                    addr_struct,
+              } else if (p == std::numeric_limits<double>::infinity()) {
+                multi_tensor_apply<1>(
+                    tensor_lists,
+                    LpNormFunctor<scalar_t, NormType::LInf, out_t>(),
+                    output_per_tensor.mutable_data_ptr<out_opmath_t>(),
                     max_chunks_per_tensor);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
-          }
-        });
-  } else {
-    TORCH_CHECK(
-        false,
-        "foreach_tensor_norm_cuda fast path got unexpected ord value: ",
-        p);
-  }
+              }
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+              const at::cuda::OptionalCUDAGuard device_guard(
+                  device_of(output_per_tensor));
+              auto stream = at::cuda::getCurrentCUDAStream();
+
+              const size_t num_kernels =
+                  ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
+              for (const auto i : c10::irange(num_kernels)) {
+                const size_t num_tensors_this_kernel =
+                    (i < num_kernels - 1 ||
+                     ntensors % MAX_TENSORS_PER_KERNEL == 0)
+                    ? MAX_TENSORS_PER_KERNEL
+                    : (ntensors % MAX_TENSORS_PER_KERNEL);
+
+                TensorListAddresses addr_struct;
+                for (const auto j : c10::irange(num_tensors_this_kernel)) {
+                  addr_struct.addresses[j] =
+                      vec_res[i * MAX_TENSORS_PER_KERNEL + j]
+                          .mutable_data_ptr<out_t>();
+                }
+
+                if (p == static_cast<double>(1)) {
+                  lpnorm_cleanup<scalar_t, NormType::L1, out_t>
+                      <<<num_tensors_this_kernel, 512, 0, stream>>>(
+                          output_per_tensor.const_data_ptr<out_opmath_t>() +
+                              i * MAX_TENSORS_PER_KERNEL *
+                                  max_chunks_per_tensor,
+                          addr_struct,
+                          max_chunks_per_tensor);
+                } else if (p == static_cast<double>(2)) {
+                  lpnorm_cleanup<scalar_t, NormType::L2, out_t>
+                      <<<num_tensors_this_kernel, 512, 0, stream>>>(
+                          output_per_tensor.const_data_ptr<out_opmath_t>() +
+                              i * MAX_TENSORS_PER_KERNEL *
+                                  max_chunks_per_tensor,
+                          addr_struct,
+                          max_chunks_per_tensor);
+                } else if (p == std::numeric_limits<double>::infinity()) {
+                  lpnorm_cleanup<scalar_t, NormType::LInf, out_t>
+                      <<<num_tensors_this_kernel, 512, 0, stream>>>(
+                          output_per_tensor.const_data_ptr<out_opmath_t>() +
+                              i * MAX_TENSORS_PER_KERNEL *
+                                  max_chunks_per_tensor,
+                          addr_struct,
+                          max_chunks_per_tensor);
+                }
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+              }
+            });
+      });
 
   // correctly assign values to only non-empty slots, as the empty slots should
   // get skipped
@@ -343,10 +360,12 @@
       result.emplace_back(vec_res[i]);
       i++;
     } else {
-      result.emplace_back(at::zeros({}, options));
+      result.emplace_back(at::zeros({}, res_option));
     }
   }
   return result;
 }
 
+#undef AT_DISPATCH_OUT_DTYPES
+
 } // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 10d8b1a..6226ca1 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11134,7 +11134,7 @@
     CUDA: foreach_tensor_neg_cuda_
   autogen: _foreach_neg.out
 
-- func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
+- func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[]
   device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
   variants: function
   dispatch:
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 1fb92aa..f9e8705 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -3154,6 +3154,6 @@
 
 # note(crcrpar): forward-mode AD is tricky for a simple string replace to handle:
 #   formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)`
-- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
+- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[]
   self: norm_backward(grads[i], self[i], ord, result[i])
   result: norm_jvp(self_p, self_t, ord, result[i])
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 93e45bf..89b452b 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -3177,7 +3177,6 @@
         aten._foreach_log1p,
         aten._foreach_log2,
         aten._foreach_neg,
-        aten._foreach_norm,
         aten._foreach_reciprocal,
         aten._foreach_round,
         aten._foreach_sigmoid,
@@ -3307,6 +3306,30 @@
     return [torch.empty_like(e) for e in exponent]
 
 
+@register_meta([aten._foreach_norm])
+def meta__foreach_norm(self, ord=2, dtype=None):
+    torch._check(
+        isinstance(self, list),
+        lambda: f"self must be a tensor list but got {type(self)}",
+    )
+    torch._check(
+        isinstance(ord, Number),
+        lambda: f"ord must be an integer but got {type(ord)}",
+    )
+    torch._check(
+        dtype is None or isinstance(dtype, torch.dtype),
+        lambda: f"dtype must be either None or torch.dtype but got {type(dtype)}",
+    )
+    return [
+        torch.empty(
+            (),
+            device=t.device,
+            dtype=t.dtype.to_real() if dtype is None else dtype.to_real(),
+        )
+        for t in self
+    ]
+
+
 def _check_foreach_binop_tensor_lists(self, other):
     torch._check(
         isinstance(self, List) and isinstance(other, List),
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index d456bb5..cec308b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9373,12 +9373,16 @@
         _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
         _foreach_inputs_kwargs["requires_grad"] = requires_grad
 
-        for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2, float('inf'), float('-inf'))):
+        for num_tensors, ord, out_dtype in product(
+            num_input_tensors,
+            (0, 1, 2, -1, -2, float('inf'), float('-inf')),
+            (None,) + (torch.complex128,) if dtype in complex_types() else (torch.float64,),
+        ):
             input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
             disable_fastpath = True
             if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
                 disable_fastpath = False
-            yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
+            yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath, dtype=out_dtype)
 
         # Also test nan propagation with a single tensor, but skip autograd testing
         if not requires_grad:
@@ -9398,8 +9402,6 @@
                 yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath)
 
 
-
-
 class foreach_lerp_sample_func(foreach_inputs_sample_func):
     def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs):
         if rightmost_arg_type == ForeachRightmostArgType.TensorList: