Remove fp16 (Half) support from CPU linalg functions (dot, addmm, bmm/baddbmm, and the corresponding test/OpInfo dtype lists).
fp16 on CPU produces slow and inaccurate results; see #69969.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75647
Approved by: https://github.com/Lezcano, https://github.com/mruberry
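For context, a minimal Python sketch of the intended user-visible effect (a hypothetical session, not part of this diff; the exact error type and message are assumptions): Half inputs are now rejected by the CPU dispatch for the affected ops, bfloat16 and the float/complex dtypes keep working, and the CUDA fp16 paths are untouched.

```python
import torch

a = torch.randn(8).half()            # torch.float16 tensors on CPU
b = torch.randn(8).half()

# After this change the CPU dispatch no longer covers Half, so the call
# is expected to raise (previously it ran a slow, inaccurate kernel).
try:
    torch.dot(a, b)
except RuntimeError as err:
    print("fp16 dot on CPU rejected:", err)

# bfloat16 stays in the CPU dispatch and continues to work.
x = torch.randn(8).bfloat16()
y = torch.randn(8).bfloat16()
print(torch.dot(x, y).dtype)         # torch.bfloat16

# CUDA fp16 is unaffected by this patch.
if torch.cuda.is_available():
    print(torch.dot(a.cuda(), b.cuda()).dtype)  # torch.float16
```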
diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp
index 04a12cb..26c3804 100644
--- a/aten/src/ATen/native/Blas.cpp
+++ b/aten/src/ATen/native/Blas.cpp
@@ -165,7 +165,7 @@
return r;
}
- return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "dot", [&] {
+ return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::BFloat16, self.scalar_type(), "dot", [&] {
Tensor result = at::empty({}, self.options());
result.fill_(dot_impl<scalar_t>(self.numel(), self.data_ptr<scalar_t>(), self.stride(0), other.data_ptr<scalar_t>(), other.stride(0)));
return result;
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index e7a6782..9a73f80 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1223,7 +1223,7 @@
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj());
// Apply BLAS routine
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
result.scalar_type(), "addmm_impl_cpu_",
[&]{
at::native::cpublas::gemm(
@@ -1428,20 +1428,6 @@
// is_bmm_out: true for bmm_out, false for baddbmm_
// self_or_result is "self" for baddbmm_ and "result" for bmm_out
Tensor& self_or_result = const_cast<Tensor&>(self_or_result_);
- CheckedFrom c = (is_bmm_out ? "bmm" : "baddbmm");
-
- auto checkOnCPU = [](const Tensor& t, CheckedFrom c) {
- TORCH_CHECK(
- !t.is_cuda(),
- "Expect tensor to have CPU backend, but got tensor with ",
- toString(t.options().backend()),
- " Backend (while checking arguments for ",
- c);
- };
-
- checkOnCPU(self_or_result, c);
- checkOnCPU(batch1, c);
- checkOnCPU(batch2, c);
const auto batch1_sizes = batch1.sizes();
const auto batch2_sizes = batch2.sizes();
@@ -1478,16 +1464,15 @@
if (contraction_size * res_rows * res_cols < 400) {
if (is_bmm_out) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "bmm", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, batch1.scalar_type(), "bmm", [&] {
baddbmm_cpu_kernel<scalar_t, true>(self_or_result, batch1, batch2, beta, alpha);
});
} else {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "baddbmm", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, batch1.scalar_type(), "baddbmm", [&] {
baddbmm_cpu_kernel<scalar_t, false>(self_or_result, batch1, batch2, beta, alpha);
});
}
} else if (at::hasMKL() && ((
- self_or_result.scalar_type() != kHalf &&
self_or_result.scalar_type() != kBFloat16 &&
at::native::is_floating_point(self_or_result)) ||
at::native::is_complex(self_or_result))
diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp
index 6c2c977..d14e7cd 100644
--- a/aten/src/ATen/test/basic.cpp
+++ b/aten/src/ATen/test/basic.cpp
@@ -41,7 +41,9 @@
Tensor b = ones({3, 4}, type);
ASSERT_EQ_RESOLVED((b + b).sum().item<double>(), 24);
ASSERT_EQ_RESOLVED(b.numel(), 12);
- ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item<double>(), 12);
+ if (type.backend() != Backend::CPU || type.scalarType() != kHalf) {
+ ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item<double>(), 12);
+ }
}
void TestSort(DeprecatedTypeProperties& type) {
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 0fcc300..e5136af 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -5831,7 +5831,7 @@
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_and_complex_types_and(
*[torch.bfloat16] if TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater) else []))
- @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
+ @dtypes(*floating_and_complex_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
def test_addmm(self, device, dtype):
self._test_addmm_impl(torch.addmm, None, device, dtype)
@@ -6043,9 +6043,8 @@
self.compare_with_numpy(torch_fn, np_fn, sx[0])
@precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
- @skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
@onlyNativeDeviceTypes
- @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
+ @dtypes(*floating_and_complex_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
def test_bmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
@@ -6157,7 +6156,7 @@
@precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
@onlyNativeDeviceTypes
- @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
+ @dtypes(*floating_and_complex_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
def test_addbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
@@ -6230,7 +6229,7 @@
@precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
@onlyNativeDeviceTypes
- @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
+ @dtypes(*floating_and_complex_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
def test_baddbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
diff --git a/test/test_ops.py b/test/test_ops.py
index cbf6629..bd51331 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -110,6 +110,7 @@
# NOTE: some ops will fail in forward if their inputs
# require grad but they don't support computing the gradient
# in that type! This is a bug in the op!
+ print("dtype", dtype, e)
unsupported(dtype)
# Short-circuits testing this dtype -- it doesn't work
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index cf6b5dc..31743ec 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8695,7 +8695,7 @@
# This addmm OpInfo is for when alpha and beta are not both equal to 1.
# alpha=beta=1 is tested in the following opinfo, because that special case will
# trigger addmm being decomposed by a jit pass.
- dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
assert_autodiffed=True,
@@ -8707,7 +8707,7 @@
OpInfo('addmm',
# When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
variant_test_name='decomposed',
- dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if(CUDA11OrLater or TEST_WITH_ROCM) else []),
assert_autodiffed=True,
@@ -8736,7 +8736,7 @@
ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M),
np.multiply(np.asarray(alpha, dtype=batch1.dtype),
np.sum(np.matmul(batch1, batch2), axis=0))),
- dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@@ -8759,7 +8759,7 @@
),
sample_inputs_func=sample_inputs_addbmm),
OpInfo('baddbmm',
- dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
*[torch.bfloat16] if CUDA11OrLater or TEST_WITH_ROCM else []),
backward_dtypesIfCUDA=floating_types_and(torch.float16,
@@ -8776,7 +8776,7 @@
'TestMathBits', 'test_conj_view', device_type='cuda')],
sample_inputs_func=sample_inputs_baddbmm),
OpInfo('dot',
- dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
assert_autodiffed=True,
@@ -8785,7 +8785,7 @@
supports_fwgrad_bwgrad=True,
),
OpInfo('vdot',
- dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
sample_inputs_func=sample_inputs_dot_vdot,
@@ -8793,7 +8793,7 @@
supports_fwgrad_bwgrad=True,
),
OpInfo('bmm',
- dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM)else []),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
@@ -9330,7 +9330,7 @@
dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
)),
OpInfo('cov',
- dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.half,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
backward_dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16]
@@ -10312,7 +10312,7 @@
OpInfo('linalg.multi_dot',
# Need this lambda because gradcheck does not work with TensorList inputs
aten_name='linalg_multi_dot',
- dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
supports_inplace_autograd=False,
@@ -11842,7 +11842,7 @@
aten_name='linear',
supports_autograd=True,
sample_inputs_func=sample_inputs_linear,
- dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
if (CUDA11OrLater or TEST_WITH_ROCM) else []),
@@ -12387,7 +12387,7 @@
supports_fwgrad_bwgrad=True,
autodiff_nonfusible_nodes=["aten::relu6"]),
OpInfo('mm',
- dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
if (CUDA11OrLater or TEST_WITH_ROCM) else []),
assert_autodiffed=True,
@@ -13431,7 +13431,7 @@
# we need this lambda because SampleInput expects tensor input as the first argument
# TODO(@heitorschueroff) update SampleInput to handle such cases
op=lambda tensors, equation: torch.einsum(equation, tensors),
- dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half,
*[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16]
@@ -14867,7 +14867,7 @@
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_kron),
OpInfo('inner',
- dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
if (CUDA11OrLater or TEST_WITH_ROCM) else []),
dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
@@ -14876,7 +14876,7 @@
sample_inputs_func=sample_inputs_inner,
),
OpInfo('tensordot',
- dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
if (CUDA11OrLater or TEST_WITH_ROCM) else []),
dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),