Remove fp16 support from CPU linalg functions

fp16 on CPU produces slow and inaccurate results; see #69969.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75647
Approved by: https://github.com/Lezcano, https://github.com/mruberry
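
For context, a minimal sketch (not part of this patch) of the user-visible effect on a CPU-only build after this change: fp16 inputs to the affected linalg ops now fail at dispatch instead of silently taking the slow, inaccurate fallback, and the usual workaround is to widen to float32 and cast back. The exact error text below is an assumption about the dispatcher's message format.

```python
import torch

# Hypothetical repro: names and shapes are illustrative, not from the patch.
a = torch.randn(4, 4).half()
b = torch.randn(4, 4).half()

try:
    torch.mm(a, b)  # CPU fp16 GEMM is no longer dispatched after this change
except RuntimeError as err:
    # Expected to read roughly: "addmm_impl_cpu_" not implemented for 'Half'
    print(err)

# Workaround: compute in a wider dtype on CPU, cast back if fp16 storage is needed.
out = torch.mm(a.float(), b.float()).half()
print(out.dtype)  # torch.float16
```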
diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp
index 04a12cb..26c3804 100644
--- a/aten/src/ATen/native/Blas.cpp
+++ b/aten/src/ATen/native/Blas.cpp
@@ -165,7 +165,7 @@
     return r;
   }
 
-  return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "dot", [&] {
+  return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::BFloat16, self.scalar_type(), "dot", [&] {
     Tensor result = at::empty({}, self.options());
     result.fill_(dot_impl<scalar_t>(self.numel(), self.data_ptr<scalar_t>(), self.stride(0), other.data_ptr<scalar_t>(), other.stride(0)));
     return result;
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index e7a6782..9a73f80 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1223,7 +1223,7 @@
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj());
 
   // Apply BLAS routine
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16,
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
       result.scalar_type(), "addmm_impl_cpu_",
       [&]{
         at::native::cpublas::gemm(
@@ -1428,20 +1428,6 @@
   // is_bmm_out: true for bmm_out, false for baddbmm_
   // self_or_result is "self" for baddbmm_ and "result" for bmm_out
   Tensor& self_or_result = const_cast<Tensor&>(self_or_result_);
-  CheckedFrom c = (is_bmm_out ? "bmm" : "baddbmm");
-
-  auto checkOnCPU = [](const Tensor& t, CheckedFrom c) {
-    TORCH_CHECK(
-        !t.is_cuda(),
-        "Expect tensor to have CPU backend, but got tensor with ",
-        toString(t.options().backend()),
-        " Backend (while checking arguments for ",
-        c);
-  };
-
-  checkOnCPU(self_or_result, c);
-  checkOnCPU(batch1, c);
-  checkOnCPU(batch2, c);
 
   const auto batch1_sizes = batch1.sizes();
   const auto batch2_sizes = batch2.sizes();
@@ -1478,16 +1464,15 @@
 
   if (contraction_size * res_rows * res_cols < 400) {
     if (is_bmm_out) {
-      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "bmm", [&] {
+      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, batch1.scalar_type(), "bmm", [&] {
           baddbmm_cpu_kernel<scalar_t, true>(self_or_result, batch1, batch2, beta, alpha);
         });
     } else {
-      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "baddbmm", [&] {
+      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, batch1.scalar_type(), "baddbmm", [&] {
           baddbmm_cpu_kernel<scalar_t, false>(self_or_result, batch1, batch2, beta, alpha);
         });
     }
   } else if (at::hasMKL() && ((
-            self_or_result.scalar_type() != kHalf &&
             self_or_result.scalar_type() != kBFloat16 &&
             at::native::is_floating_point(self_or_result)) ||
             at::native::is_complex(self_or_result))
diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp
index 6c2c977..d14e7cd 100644
--- a/aten/src/ATen/test/basic.cpp
+++ b/aten/src/ATen/test/basic.cpp
@@ -41,7 +41,9 @@
   Tensor b = ones({3, 4}, type);
   ASSERT_EQ_RESOLVED((b + b).sum().item<double>(), 24);
   ASSERT_EQ_RESOLVED(b.numel(), 12);
-  ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item<double>(), 12);
+  if (type.backend() != Backend::CPU || type.scalarType() != kHalf) {
+    ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item<double>(), 12);
+  }
 }
 
 void TestSort(DeprecatedTypeProperties& type) {
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 0fcc300..e5136af 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -5831,7 +5831,7 @@
                         torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
     @dtypesIfCUDA(*floating_and_complex_types_and(
                   *[torch.bfloat16] if TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater) else []))
-    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
+    @dtypes(*floating_and_complex_types_and(torch.bfloat16))
     @tf32_on_and_off(0.05)
     def test_addmm(self, device, dtype):
         self._test_addmm_impl(torch.addmm, None, device, dtype)
@@ -6043,9 +6043,8 @@
         self.compare_with_numpy(torch_fn, np_fn, sx[0])
 
     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
-    @skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
     @onlyNativeDeviceTypes
-    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
+    @dtypes(*floating_and_complex_types_and(torch.bfloat16))
     @tf32_on_and_off(0.05)
     def test_bmm(self, device, dtype):
         if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
@@ -6157,7 +6156,7 @@
 
     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
     @onlyNativeDeviceTypes
-    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
+    @dtypes(*floating_and_complex_types_and(torch.bfloat16))
     @tf32_on_and_off(0.05)
     def test_addbmm(self, device, dtype):
         if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
@@ -6230,7 +6229,7 @@
 
     @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
     @onlyNativeDeviceTypes
-    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
+    @dtypes(*floating_and_complex_types_and(torch.bfloat16))
     @tf32_on_and_off(0.05)
     def test_baddbmm(self, device, dtype):
         if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater:
diff --git a/test/test_ops.py b/test/test_ops.py
index cbf6629..bd51331 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -110,6 +110,7 @@
                     # NOTE: some ops will fail in forward if their inputs
                     #   require grad but they don't support computing the gradient
                     #   in that type! This is a bug in the op!
+                    print("dtype", dtype, e)
                     unsupported(dtype)
 
                 # Short-circuits testing this dtype -- it doesn't work
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index cf6b5dc..31743ec 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8695,7 +8695,7 @@
            # This addmm OpInfo is for when alpha and beta are not both equal to 1.
            # alpha=beta=1 is tested in the following opinfo, because that special case will
            # trigger addmm being decomposed by a jit pass.
-           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
            assert_autodiffed=True,
@@ -8707,7 +8707,7 @@
     OpInfo('addmm',
            # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
            variant_test_name='decomposed',
-           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
                                                        *[torch.bfloat16] if(CUDA11OrLater or TEST_WITH_ROCM) else []),
            assert_autodiffed=True,
@@ -8736,7 +8736,7 @@
            ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M),
                                                                  np.multiply(np.asarray(alpha, dtype=batch1.dtype),
                                                                              np.sum(np.matmul(batch1, batch2), axis=0))),
-           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
                                                        *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
@@ -8759,7 +8759,7 @@
            ),
            sample_inputs_func=sample_inputs_addbmm),
     OpInfo('baddbmm',
-           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
                                            *[torch.bfloat16] if CUDA11OrLater or TEST_WITH_ROCM else []),
            backward_dtypesIfCUDA=floating_types_and(torch.float16,
@@ -8776,7 +8776,7 @@
                    'TestMathBits', 'test_conj_view', device_type='cuda')],
            sample_inputs_func=sample_inputs_baddbmm),
     OpInfo('dot',
-           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
                                                        *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            assert_autodiffed=True,
@@ -8785,7 +8785,7 @@
            supports_fwgrad_bwgrad=True,
            ),
     OpInfo('vdot',
-           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
                                                        *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            sample_inputs_func=sample_inputs_dot_vdot,
@@ -8793,7 +8793,7 @@
            supports_fwgrad_bwgrad=True,
            ),
     OpInfo('bmm',
-           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
                                                        *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM)else []),
            backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
@@ -9330,7 +9330,7 @@
                                     dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
                    )),
     OpInfo('cov',
-           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=all_types_and_complex_and(torch.half,
                                                   *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            backward_dtypesIfCUDA=all_types_and_complex_and(torch.half, *[torch.bfloat16]
@@ -10312,7 +10312,7 @@
     OpInfo('linalg.multi_dot',
            # Need this lambda because gradcheck does not work with TensorList inputs
            aten_name='linalg_multi_dot',
-           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.half,
                                                        *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            supports_inplace_autograd=False,
@@ -11842,7 +11842,7 @@
            aten_name='linear',
            supports_autograd=True,
            sample_inputs_func=sample_inputs_linear,
-           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
                                                        if (CUDA11OrLater or TEST_WITH_ROCM) else []),
@@ -12387,7 +12387,7 @@
            supports_fwgrad_bwgrad=True,
            autodiff_nonfusible_nodes=["aten::relu6"]),
     OpInfo('mm',
-           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
                                                        if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            assert_autodiffed=True,
@@ -13431,7 +13431,7 @@
            # we need this lambda because SampleInput expects tensor input as the first argument
            # TODO(@heitorschueroff) update SampleInput to handle such cases
            op=lambda tensors, equation: torch.einsum(equation, tensors),
-           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.half,
                                                        *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16]
@@ -14867,7 +14867,7 @@
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_kron),
     OpInfo('inner',
-           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
                                                        if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
@@ -14876,7 +14876,7 @@
            sample_inputs_func=sample_inputs_inner,
            ),
     OpInfo('tensordot',
-           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16]
                                                        if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),