Support `dtype` kwarg in `_foreach_norm` (#125665)
Fixes #125040
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125665
Approved by: https://github.com/janeyx99
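
A minimal usage sketch of the new keyword (assuming a build that includes this change; the shapes and device fallback below are illustrative only): `dtype` selects the accumulation/output dtype of each per-tensor reduction, mirroring the `dtype` argument of `torch.linalg.vector_norm`, which the slow path now forwards to.

```python
import torch

# Illustrative sketch; assumes a PyTorch build with this PR applied.
device = "cuda" if torch.cuda.is_available() else "cpu"
tensors = [torch.randn(1024, device=device).half() for _ in range(4)]

# Accumulate and return the L2 norms in float32 rather than the inputs' float16.
norms = torch._foreach_norm(tensors, ord=2, dtype=torch.float32)

# Each result matches the per-tensor reference computed with the same dtype.
refs = [torch.linalg.vector_norm(t, 2, dtype=torch.float32) for t in tensors]
for got, ref in zip(norms, refs):
    assert got.dtype == torch.float32
    torch.testing.assert_close(got, ref)
```

On the CUDA fast path, the requested `dtype` is additionally validated by `check_foreach_norm_dtype` below: it must be floating point or complex, match the real/complex-ness of the inputs, and must not narrow the input dtype.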
diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp
index 790c7a5..34c71a8 100644
--- a/aten/src/ATen/native/ForeachOpsKernels.cpp
+++ b/aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -438,11 +438,12 @@
std::vector<Tensor> foreach_tensor_norm_slow(
TensorList tensors,
- const Scalar& ord) {
+ const Scalar& ord,
+ c10::optional<ScalarType> dtype) {
check_foreach_api_restrictions(tensors);
std::vector<Tensor> result;
for (const auto& t : tensors) {
- result.emplace_back(at::linalg_vector_norm(t, ord));
+ result.emplace_back(at::linalg_vector_norm(t, ord, {}, false, dtype));
}
return result;
}
diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
index eed9656..885c5d0 100644
--- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
@@ -1,3 +1,6 @@
+#include <c10/core/ScalarType.h>
+#include <c10/util/irange.h>
+#include <limits>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
@@ -44,15 +47,16 @@
template <
typename T,
NormType norm_type,
+ typename out_t,
int depth = 1,
int r_args_depth = 1,
int res_arg_index = 0>
struct LpNormFunctor {
- using opmath_t = typename at::opmath_type<T>;
+ using out_opmath_t = typename at::opmath_type<out_t>;
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,
- opmath_t* output_per_tensor,
+ out_opmath_t* output_per_tensor_ptr,
const int max_chunks_per_tensor) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -62,11 +66,11 @@
x += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
- __shared__ opmath_t s_vals[512];
- opmath_t vals[kILP];
+ __shared__ out_opmath_t s_vals[512];
+ out_opmath_t vals[kILP];
T r_x[kILP];
for (int64_t i = 0; i < kILP; i++) {
- vals[i] = opmath_t(0);
+ vals[i] = out_opmath_t(0);
r_x[i] = T(0);
}
@@ -78,7 +82,7 @@
load_store(r_x, x, 0, i_start);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
- opmath_t next = static_cast<opmath_t>(r_x[ii]);
+ const auto next = static_cast<out_opmath_t>(r_x[ii]);
if constexpr (norm_type == NormType::LInf) {
vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
} else {
@@ -93,7 +97,7 @@
for (int ii = 0; ii < kILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
- opmath_t next = static_cast<opmath_t>(x[i]);
+ const auto next = static_cast<out_opmath_t>(x[i]);
if constexpr (norm_type == NormType::LInf) {
vals[ii] = max_propagate_nan(vals[ii], ::abs(next));
} else {
@@ -104,7 +108,7 @@
}
}
- auto val = opmath_t(0);
+ auto val = out_opmath_t(0);
for (int i = 0; i < kILP; i++) {
if constexpr (norm_type == NormType::LInf) {
val = max_propagate_nan(val, vals[i]);
@@ -117,7 +121,7 @@
: at::native::cuda_utils::BlockReduceMax(val, s_vals);
if (threadIdx.x == 0) {
- output_per_tensor
+ output_per_tensor_ptr
[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor +
chunk_idx] = final_val;
}
@@ -127,16 +131,17 @@
template <
typename T,
NormType norm_type,
- typename opmath_t = at::opmath_type<T>>
+ typename out_t,
+ typename out_opmath_t = at::opmath_type<out_t>>
__global__ void lpnorm_cleanup(
- const opmath_t* output_per_tensor,
+ const out_opmath_t* output_per_tensor,
TensorListAddresses addr_struct,
int max_chunks_per_tensor) {
- __shared__ opmath_t vals[512];
+ __shared__ out_opmath_t vals[512];
- const opmath_t* output_this_tensor =
+ const out_opmath_t* output_this_tensor =
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
- opmath_t val = 0;
+ out_opmath_t val = 0;
for (size_t i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
if constexpr (norm_type == NormType::LInf) {
val = max_propagate_nan(val, output_this_tensor[i]);
@@ -144,33 +149,85 @@
val += output_this_tensor[i];
}
}
- opmath_t final_val = norm_type == NormType::L1 || norm_type == NormType::L2
- ? at::native::cuda_utils::BlockReduceSum<opmath_t>(val, vals)
+ out_opmath_t final_val =
+ norm_type == NormType::L1 || norm_type == NormType::L2
+ ? at::native::cuda_utils::BlockReduceSum<out_opmath_t>(val, vals)
: at::native::cuda_utils::BlockReduceMax(val, vals);
if (threadIdx.x == 0) {
- *(T*)addr_struct.addresses[blockIdx.x] =
+ *(out_t*)addr_struct.addresses[blockIdx.x] =
norm_type == NormType::L1 || norm_type == NormType::LInf
? final_val
: ::sqrt(final_val);
}
}
+namespace {
+inline void check_foreach_norm_dtype(
+ optional<ScalarType> opt_dtype,
+ ScalarType self_dtype,
+ const char* const name) {
+ if (opt_dtype.has_value()) {
+ auto dtype = opt_dtype.value();
+ TORCH_CHECK(
+ isFloatingType(dtype) || isComplexType(dtype),
+ name,
+ ": dtype should"
+ " be floating point or complex, but got ",
+ dtype);
+ TORCH_CHECK(
+ isComplexType(self_dtype) == isComplexType(dtype),
+ name,
+ ": dtype should be ",
+ isComplexType(self_dtype) ? "complex" : "real",
+ " for ",
+ isComplexType(self_dtype) ? "complex" : "real",
+ " inputs, but got ",
+ dtype);
+ TORCH_CHECK(
+ promoteTypes(self_dtype, dtype) == dtype,
+ name,
+ ": the dtype of the input ",
+ "(",
+ self_dtype,
+ ") should be convertible ",
+ "without narrowing to the specified dtype (",
+ dtype,
+ ")");
+ }
+}
+} // anonymous namespace
+
+#define AT_DISPATCH_OUT_DTYPES(TYPE, NAME, ...) \
+ AT_DISPATCH_SWITCH( \
+ TYPE, \
+ NAME, \
+ AT_PRIVATE_CASE_TYPE_USING_HINT( \
+ at::ScalarType::Double, out_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE_USING_HINT( \
+ at::ScalarType::Float, out_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE_USING_HINT( \
+ at::ScalarType::Half, out_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE_USING_HINT( \
+ at::ScalarType::BFloat16, out_t, __VA_ARGS__))
+
// note(mkozuki): Why excluding Int and Complex from fast path
// - Int: at::norm does not support.
// - Complex: __shfl_down_sync does not support complex and foreach does not
// support functions whose inputs dtypes and output dtype are different.
std::vector<Tensor> foreach_tensor_norm_cuda(
TensorList tensors,
- const Scalar& ord) {
- double p;
- if (ord.isIntegral(false)) {
- p = ord.to<int64_t>();
- } else if (ord.isFloatingPoint()) {
- p = ord.to<double>();
- } else {
- TORCH_CHECK(
- false, "foreach_tensor_norm_cuda expects ord to be integer or float");
- }
+ const Scalar& ord,
+ c10::optional<ScalarType> dtype) {
+ const auto p = [&]() -> double {
+ if (ord.isIntegral(false)) {
+ return ord.to<int64_t>();
+ } else if (ord.isFloatingPoint()) {
+ return ord.to<double>();
+ } else {
+ TORCH_CHECK(
+ false, "foreach_tensor_norm_cuda expects ord to be integer or float");
+ }
+ }();
check_foreach_api_restrictions(tensors);
const bool has_int_or_complex =
std::any_of(tensors.begin(), tensors.end(), [](const auto& t) {
@@ -181,8 +238,10 @@
if (!can_use_fast_route(tensors) || has_int_or_complex ||
!(p == static_cast<double>(1) || p == static_cast<double>(2) ||
p == std::numeric_limits<double>::infinity())) {
- return foreach_tensor_norm_slow(tensors, ord);
+ return foreach_tensor_norm_slow(tensors, ord, dtype);
}
+ check_foreach_norm_dtype(
+ dtype, tensors[0].scalar_type(), "_foreach_tensor_norm_cuda");
const size_t ntensors = tensors.size();
int max_chunks_per_tensor = -1;
@@ -195,143 +254,101 @@
}
}
const auto options = tensors[0].options();
+ const ScalarType output_dtype =
+ dtype.has_value() ? dtype.value() : tensors[0].scalar_type();
+ const ScalarType output_per_tensor_dtype = toOpMathType(output_dtype);
auto output_per_tensor = at::zeros(
{static_cast<int64_t>(ntensors) * max_chunks_per_tensor},
- options.dtype(toOpMathType(tensors[0].scalar_type())));
+ options.dtype(output_per_tensor_dtype));
std::vector<at::Tensor> vec_res;
vec_res.reserve(ntensors);
+ const auto res_option = options.dtype(output_dtype);
for (const auto i : c10::irange(ntensors)) {
- vec_res.push_back(at::empty({}, options));
+ vec_res.push_back(at::empty({}, res_option));
}
auto tensor_lists = std::vector<std::vector<Tensor>>{tensors.vec()};
- if (p == static_cast<double>(1)) {
- AT_DISPATCH_FLOATING_TYPES_AND2(
- kHalf,
- kBFloat16,
- tensor_lists[0][0].scalar_type(),
- "foreach_tensor_norm_cuda",
- [&]() {
- using opmath_t = typename at::opmath_type<scalar_t>;
- multi_tensor_apply<1>(
- tensor_lists,
- LpNormFunctor<scalar_t, NormType::L1>(),
- output_per_tensor.mutable_data_ptr<opmath_t>(),
- max_chunks_per_tensor);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- const at::cuda::OptionalCUDAGuard device_guard(
- device_of(output_per_tensor));
- auto stream = at::cuda::getCurrentCUDAStream();
- const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
- for (const auto i : c10::irange(num_kernels)) {
- const size_t num_tensors_this_kernel =
- (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
- ? MAX_TENSORS_PER_KERNEL
- : (ntensors % MAX_TENSORS_PER_KERNEL);
-
- TensorListAddresses addr_struct;
- for (const auto j : c10::irange(num_tensors_this_kernel)) {
- addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
- .mutable_data_ptr<scalar_t>();
- }
-
- lpnorm_cleanup<scalar_t, NormType::L1>
- <<<num_tensors_this_kernel, 512, 0, stream>>>(
- output_per_tensor.const_data_ptr<opmath_t>() +
- i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
- addr_struct,
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ kHalf,
+ c10::kBFloat16,
+ tensor_lists[0][0].scalar_type(),
+ "foreach_tensor_norm_cuda_scalar_type",
+ [&]() {
+ // using opmath_t = typename at::opmath_type<scalar_t>;
+ AT_DISPATCH_OUT_DTYPES(
+ output_dtype, "foreach_tensor_norm_cuda_out_dtype", [&]() {
+ using out_opmath_t = typename at::opmath_type<out_t>;
+ if (p == static_cast<double>(1)) {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ LpNormFunctor<scalar_t, NormType::L1, out_t>(),
+ output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
- });
- } else if (p == static_cast<double>(2)) {
- AT_DISPATCH_FLOATING_TYPES_AND2(
- kHalf,
- kBFloat16,
- tensor_lists[0][0].scalar_type(),
- "foreach_tensor_norm_cuda",
- [&]() {
- using opmath_t = typename at::opmath_type<scalar_t>;
- multi_tensor_apply<1>(
- tensor_lists,
- LpNormFunctor<scalar_t, NormType::L2>(),
- output_per_tensor.mutable_data_ptr<opmath_t>(),
- max_chunks_per_tensor);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- const at::cuda::OptionalCUDAGuard device_guard(
- device_of(output_per_tensor));
- auto stream = at::cuda::getCurrentCUDAStream();
-
- const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
- for (const auto i : c10::irange(num_kernels)) {
- const size_t num_tensors_this_kernel =
- (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
- ? MAX_TENSORS_PER_KERNEL
- : (ntensors % MAX_TENSORS_PER_KERNEL);
-
- TensorListAddresses addr_struct;
- for (const auto j : c10::irange(num_tensors_this_kernel)) {
- addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
- .mutable_data_ptr<scalar_t>();
- }
-
- lpnorm_cleanup<scalar_t, NormType::L2>
- <<<num_tensors_this_kernel, 512, 0, stream>>>(
- output_per_tensor.const_data_ptr<opmath_t>() +
- i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
- addr_struct,
+ } else if (p == static_cast<double>(2)) {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ LpNormFunctor<scalar_t, NormType::L2, out_t>(),
+ output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
- });
- } else if (p == std::numeric_limits<double>::infinity()) {
- AT_DISPATCH_FLOATING_TYPES_AND2(
- kHalf,
- kBFloat16,
- tensor_lists[0][0].scalar_type(),
- "foreach_tensor_norm_cuda",
- [&]() {
- using opmath_t = typename at::opmath_type<scalar_t>;
- multi_tensor_apply<1>(
- tensor_lists,
- LpNormFunctor<scalar_t, NormType::LInf>(),
- output_per_tensor.mutable_data_ptr<opmath_t>(),
- max_chunks_per_tensor);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- const at::cuda::OptionalCUDAGuard device_guard(
- device_of(output_per_tensor));
- auto stream = at::cuda::getCurrentCUDAStream();
-
- const size_t num_kernels = ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
- for (const auto i : c10::irange(num_kernels)) {
- const size_t num_tensors_this_kernel =
- (i < num_kernels - 1 || ntensors % MAX_TENSORS_PER_KERNEL == 0)
- ? MAX_TENSORS_PER_KERNEL
- : (ntensors % MAX_TENSORS_PER_KERNEL);
-
- TensorListAddresses addr_struct;
- for (const auto j : c10::irange(num_tensors_this_kernel)) {
- addr_struct.addresses[j] = vec_res[i * MAX_TENSORS_PER_KERNEL + j]
- .mutable_data_ptr<scalar_t>();
- }
-
- lpnorm_cleanup<scalar_t, NormType::LInf>
- <<<num_tensors_this_kernel, 512, 0, stream>>>(
- output_per_tensor.const_data_ptr<opmath_t>() +
- i * MAX_TENSORS_PER_KERNEL * max_chunks_per_tensor,
- addr_struct,
+ } else if (p == std::numeric_limits<double>::infinity()) {
+ multi_tensor_apply<1>(
+ tensor_lists,
+ LpNormFunctor<scalar_t, NormType::LInf, out_t>(),
+ output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
- });
- } else {
- TORCH_CHECK(
- false,
- "foreach_tensor_norm_cuda fast path got unexpected ord value: ",
- p);
- }
+ }
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ const at::cuda::OptionalCUDAGuard device_guard(
+ device_of(output_per_tensor));
+ auto stream = at::cuda::getCurrentCUDAStream();
+
+ const size_t num_kernels =
+ ceil_div(ntensors, MAX_TENSORS_PER_KERNEL);
+ for (const auto i : c10::irange(num_kernels)) {
+ const size_t num_tensors_this_kernel =
+ (i < num_kernels - 1 ||
+ ntensors % MAX_TENSORS_PER_KERNEL == 0)
+ ? MAX_TENSORS_PER_KERNEL
+ : (ntensors % MAX_TENSORS_PER_KERNEL);
+
+ TensorListAddresses addr_struct;
+ for (const auto j : c10::irange(num_tensors_this_kernel)) {
+ addr_struct.addresses[j] =
+ vec_res[i * MAX_TENSORS_PER_KERNEL + j]
+ .mutable_data_ptr<out_t>();
+ }
+
+ if (p == static_cast<double>(1)) {
+ lpnorm_cleanup<scalar_t, NormType::L1, out_t>
+ <<<num_tensors_this_kernel, 512, 0, stream>>>(
+ output_per_tensor.const_data_ptr<out_opmath_t>() +
+ i * MAX_TENSORS_PER_KERNEL *
+ max_chunks_per_tensor,
+ addr_struct,
+ max_chunks_per_tensor);
+ } else if (p == static_cast<double>(2)) {
+ lpnorm_cleanup<scalar_t, NormType::L2, out_t>
+ <<<num_tensors_this_kernel, 512, 0, stream>>>(
+ output_per_tensor.const_data_ptr<out_opmath_t>() +
+ i * MAX_TENSORS_PER_KERNEL *
+ max_chunks_per_tensor,
+ addr_struct,
+ max_chunks_per_tensor);
+ } else if (p == std::numeric_limits<double>::infinity()) {
+ lpnorm_cleanup<scalar_t, NormType::LInf, out_t>
+ <<<num_tensors_this_kernel, 512, 0, stream>>>(
+ output_per_tensor.const_data_ptr<out_opmath_t>() +
+ i * MAX_TENSORS_PER_KERNEL *
+ max_chunks_per_tensor,
+ addr_struct,
+ max_chunks_per_tensor);
+ }
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ });
+ });
// correctly assign values to only non-empty slots, as the empty slots should
// get skipped
@@ -343,10 +360,12 @@
result.emplace_back(vec_res[i]);
i++;
} else {
- result.emplace_back(at::zeros({}, options));
+ result.emplace_back(at::zeros({}, res_option));
}
}
return result;
}
+#undef AT_DISPATCH_OUT_DTYPES
+
} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 10d8b1a..6226ca1 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11134,7 +11134,7 @@
CUDA: foreach_tensor_neg_cuda_
autogen: _foreach_neg.out
-- func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
+- func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
variants: function
dispatch:
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 1fb92aa..f9e8705 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -3154,6 +3154,6 @@
# note(crcrpar): forward-mode AD is tricky for a simple string replace to handle:
# formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)`
-- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
+- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[]
self: norm_backward(grads[i], self[i], ord, result[i])
result: norm_jvp(self_p, self_t, ord, result[i])
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 93e45bf..89b452b 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -3177,7 +3177,6 @@
aten._foreach_log1p,
aten._foreach_log2,
aten._foreach_neg,
- aten._foreach_norm,
aten._foreach_reciprocal,
aten._foreach_round,
aten._foreach_sigmoid,
@@ -3307,6 +3306,30 @@
return [torch.empty_like(e) for e in exponent]
+@register_meta([aten._foreach_norm])
+def meta__foreach_norm(self, ord=2, dtype=None):
+ torch._check(
+ isinstance(self, list),
+ lambda: f"self must be a tensor list but got {type(self)}",
+ )
+ torch._check(
+ isinstance(ord, Number),
+ lambda: f"ord must be an integer but got {type(ord)}",
+ )
+ torch._check(
+ dtype is None or isinstance(dtype, torch.dtype),
+ lambda: f"dtype must be either None or torch.dtype but got {type(dtype)}",
+ )
+ return [
+ torch.empty(
+ (),
+ device=t.device,
+ dtype=t.dtype.to_real() if dtype is None else dtype.to_real(),
+ )
+ for t in self
+ ]
+
+
def _check_foreach_binop_tensor_lists(self, other):
torch._check(
isinstance(self, List) and isinstance(other, List),
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index d456bb5..cec308b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9373,12 +9373,16 @@
_foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
_foreach_inputs_kwargs["requires_grad"] = requires_grad
- for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2, float('inf'), float('-inf'))):
+ for num_tensors, ord, out_dtype in product(
+ num_input_tensors,
+ (0, 1, 2, -1, -2, float('inf'), float('-inf')),
+ (None,) + ((torch.complex128,) if dtype in complex_types() else (torch.float64,)),
+ ):
input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
disable_fastpath = True
if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
disable_fastpath = False
- yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
+ yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath, dtype=out_dtype)
# Also test nan propagation with a single tensor, but skip autograd testing
if not requires_grad:
@@ -9398,8 +9402,6 @@
yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath)
-
-
class foreach_lerp_sample_func(foreach_inputs_sample_func):
def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs):
if rightmost_arg_type == ForeachRightmostArgType.TensorList: