Revert "`_foreach_copy` with different src/dst dtypes (#121717)"
This reverts commit da2a9a05127c2b44e447e734d99e727d856cb36f.
Reverted https://github.com/pytorch/pytorch/pull/121717 on behalf of https://github.com/janeyx99 because it was causing illegal memory accesses (IMAs) on V100s internally :C ([comment](https://github.com/pytorch/pytorch/pull/121717#issuecomment-2025553295))
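For readers skimming the patch, here is a minimal sketch of the user-facing effect, assuming the private `torch._foreach_copy_` entry point and an available CUDA device. With the fast-path support reverted, a copy between tensor lists of different dtypes still works; it simply falls back to the per-tensor slow path instead of the fused multi_tensor_apply kernel, and the numerics match a plain `Tensor.copy_` loop.

    import torch

    # Sketch only: `_foreach_copy_` is a private API, and the fallback
    # behavior shown here is what the diff below restores.
    self_tensors = [torch.zeros(4, device="cuda") for _ in range(3)]
    src_tensors = [torch.randn(4, device="cuda", dtype=torch.float16) for _ in range(3)]

    # Dtypes differ, so after this revert the fused fast path is skipped and
    # the per-tensor slow path runs; results match a loop of Tensor.copy_.
    torch._foreach_copy_(self_tensors, src_tensors)
    ref = [torch.empty_like(t).copy_(s) for t, s in zip(self_tensors, src_tensors)]
    assert all(torch.equal(o, r) for o, r in zip(self_tensors, ref))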
diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h
index d7a1449..9c22c35 100644
--- a/aten/src/ATen/native/ForeachUtils.h
+++ b/aten/src/ATen/native/ForeachUtils.h
@@ -102,13 +102,12 @@
// corresponding tensors (aligning in index across the tensorLists) share the
// same device and dtype.
inline bool _check_tensors_share_device_and_dtype(
- ArrayRef<TensorList> tensorLists,
- const bool skip_dtype_check = false) {
+ ArrayRef<TensorList> tensorLists) {
const auto expected_dtype = tensorLists[0][0].dtype();
const auto expected_device = tensorLists[0][0].device();
auto is_tensor_okay = [&](const Tensor& tensor) {
- return (skip_dtype_check || tensor.dtype() == expected_dtype) &&
+ return tensor.dtype() == expected_dtype &&
tensor.device() == expected_device && tensor.layout() == at::kStrided &&
tensor.is_non_overlapping_and_dense();
};
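For reference, a sketch (in Python, not ATen's actual C++ API) of what the restored `_check_tensors_share_device_and_dtype` enforces: every tensor across the lists must match the first tensor's dtype and device, be strided, and be non-overlapping and dense. The `skip_dtype_check` escape hatch used by the mixed-dtype copy fast path is gone. The `_non_overlapping_and_dense` helper below is a hypothetical stand-in for the ATen-internal check.

    import torch

    def _non_overlapping_and_dense(t):
        # Stand-in for ATen's is_non_overlapping_and_dense(); a contiguity
        # check is a conservative approximation for this sketch.
        return t.is_contiguous()

    def tensors_share_device_and_dtype(tensor_lists):
        """Python sketch of the restored C++ predicate (illustrative only)."""
        first = tensor_lists[0][0]

        def tensor_ok(t):
            return (
                t.dtype == first.dtype        # dtype check can no longer be skipped
                and t.device == first.device
                and t.layout == torch.strided
                and _non_overlapping_and_dense(t)
            )

        return all(tensor_ok(t) for tl in tensor_lists for t in tl)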
diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
index 035cb8a..366049a 100644
--- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
+++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
@@ -4,7 +4,6 @@
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/ForeachMinMaxFunctors.cuh>
#include <functional>
-#include <type_traits>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
@@ -251,152 +250,20 @@
power_functor,
/*division_op*/ true);
-template <typename dst_t, typename src_t = dst_t>
-struct Copy {
- __device__ __forceinline__ dst_t operator()(const src_t& x) {
- return static_cast<dst_t>(x);
+template <typename T>
+struct Identity {
+ __device__ __forceinline__ T operator()(const T& x) {
+ return x;
}
};
-template <typename dst_t>
-struct Copy<dst_t, c10::complex<double>> {
- __device__ __forceinline__ dst_t operator()(const c10::complex<double>& x) {
- if constexpr (!(std::is_same_v<dst_t, c10::complex<double>> ||
- std::is_same_v<dst_t, c10::complex<float>>)) {
- return static_cast<dst_t>(x.real());
- } else {
- return static_cast<dst_t>(x);
- }
- }
-};
-
-template <typename dst_t>
-struct Copy<dst_t, c10::complex<float>> {
- __device__ __forceinline__ dst_t operator()(const c10::complex<float>& x) {
- if constexpr (!(std::is_same_v<dst_t, c10::complex<double>> ||
- std::is_same_v<dst_t, c10::complex<float>>)) {
- return static_cast<dst_t>(x.real());
- } else {
- return static_cast<dst_t>(x);
- }
- }
-};
-
-#define AT_DISPATCH_SOURCE_TYPES(TYPE, NAME, ...) \
- AT_DISPATCH_SWITCH( \
- TYPE, \
- NAME, \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::Byte, src_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::Char, src_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::Long, src_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::Short, src_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::Double, src_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::Float, src_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::ComplexDouble, \
- src_t, \
- __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::ComplexFloat, \
- src_t, \
- __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::Half, \
- src_t, \
- __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::BFloat16, \
- src_t, \
- __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE_USING_HINT( \
- at::ScalarType::Bool, \
- src_t, \
- __VA_ARGS__))
-
-namespace {
-
-template <
- typename T,
- typename src_t,
- int depth,
- int r_args_depth,
- int res_arg_index>
-struct CopyFunctor {
- static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);
- template <typename Op>
- __device__ __forceinline__ void operator()(
- int chunk_size,
- TensorListMetadata<depth>& tl,
- Op op) {
- const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
- const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
- auto n = tl.numel_for_tensor[tensor_loc];
-
- src_t* src_ptr = (src_t*)tl.addresses[0][tensor_loc];
- src_ptr += chunk_idx * chunk_size;
- T* self_ptr = (T*)tl.addresses[1][tensor_loc];
- self_ptr += chunk_idx * chunk_size;
-
- const bool all_aligned{is_aligned(src_ptr) && is_aligned(self_ptr)};
-
- n -= chunk_idx * chunk_size;
- src_t src_args[kILP];
- T r_args[kILP];
-
- // to make things simple, we put aligned case in a different code path
- if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
- for (int64_t i_start = threadIdx.x;
- i_start * kILP < n && i_start * kILP < chunk_size;
- i_start += blockDim.x) {
- // load
- load_store(src_args, src_ptr, 0, i_start);
-#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
- r_args[ii] = static_cast<T>(op(src_args[ii]));
- }
- // store
- load_store(self_ptr, r_args, i_start, 0);
- }
- } else {
- for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
- i_start += blockDim.x * kILP) {
-#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
- const auto i = i_start + threadIdx.x + ii * blockDim.x;
- src_args[ii] = src_ptr[i];
- }
-#pragma unroll
- for (int ii = 0; ii < kILP; ii++) {
- r_args[ii] = static_cast<T>(op(src_args[ii]));
- }
- store_args(self_ptr, r_args, i_start, chunk_size, n);
- }
- }
- }
-};
-
-} // anonymous namespace
-
void foreach_tensor_copy_list_kernel_cuda_(
TensorList self,
TensorList src,
const bool non_blocking) {
check_foreach_api_restrictions(self, src);
- if (!(_check_tensors_share_device_and_dtype(
- {self, src}, /* skip_dtype_check */ true) &&
- std::all_of(
- src.cbegin(),
- src.cend(),
- [&](const auto& t) -> bool {
- return t.dtype() == src[0].dtype();
- }) &&
- _check_tensors_share_sizes_and_strides({self, src}))) {
+ if (!can_use_fast_route(
+ self, src, /* does_op_promote_integer_inputs_to_float */ false)) {
return at::native::foreach_tensor_copy_list_kernel_slow_(
self, src, non_blocking);
}
@@ -411,38 +278,16 @@
"foreach_tensor_copy",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
- AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
- if constexpr (std::is_same_v<scalar_t, src_t>) {
- multi_tensor_apply<2>(
- tensor_lists,
- UnaryOpFunctor<
- scalar_t,
- /* depth */ 2,
- /* r_args_depth */ 1,
- /* res_arg_index */ 1>(),
- Copy<opmath_t, opmath_t>());
- } else {
- // Ref:
- // https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
- if (!self[0].is_complex() && src[0].is_complex()) {
- TORCH_WARN_ONCE(
- "Casting complex values to real discards the imaginary part");
- }
- multi_tensor_apply<2>(
- tensor_lists,
- CopyFunctor<
- scalar_t,
- src_t,
- /* depth */ 2,
- /* r_args_depth */ 1,
- /* res_arg_index */ 1>(),
- Copy<scalar_t, src_t>());
- }
- });
+ multi_tensor_apply<2>(
+ tensor_lists,
+ UnaryOpFunctor<
+ scalar_t,
+ /* depth */ 2,
+ /* r_args_depth */ 1,
+ /* res_arg_index */ 1>(),
+ Identity<opmath_t>());
});
increment_version(self);
}
-#undef AT_DISPATCH_SOURCE_TYPES
-
} // namespace at::native
diff --git a/test/test_foreach.py b/test/test_foreach.py
index 11b3215..6a024e5 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -838,20 +838,6 @@
copy_(t, s, non_blocking)
self.assertEqual(ref_input, sample.input)
- @onlyCUDA
- @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
- def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
- # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
- foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
- for sample in op.sample_inputs(device, dtype, noncontiguous=False):
- for src_dtype in floating_types_and(torch.half, torch.bfloat16):
- if src_dtype == dtype:
- continue
- self_tensors = [t.clone() for t in sample.input]
- src_tensors = [t.to(src_dtype) for t in self_tensors]
- out = foreach_copy_((self_tensors, src_tensors), is_cuda=True, expect_fastpath=True)
- self.assertEqual(out, [torch.empty_like(t).copy_(s) for t, s in zip(self_tensors, src_tensors)])
-
# Test reverse-mode & forward-mode AD if supported.
@onlyCUDA
@ops(