add Half support for unary ops on CPU (#98493)

Add Half support for log_sigmoid and some unary ops on CPU, including sinc, acosh, asinh, atanh, digamma, trigamma, rsqrt, acos, asin, atan, ceil, cos, erf, erfc, erfinv, exp, expm1, floor, log, log10, log1p, log2, i0, round, sin, sqrt, tan, tanh, trunc, lgamma.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98493
Approved by: https://github.com/jgong5, https://github.com/mingfeima, https://github.com/ngimel
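
As a quick illustration (not part of the patch), here is a minimal usage sketch. It assumes a CPU build that includes this change; the tensor values and sizes are arbitrary, and only the dtypes of the results are the point.

```python
import torch
import torch.nn.functional as F

x = torch.rand(16, dtype=torch.half)  # float16 tensor on CPU

# Unary ops from the list above now dispatch for Half on CPU.
print(torch.sinc(x).dtype)       # torch.float16
print(torch.log1p(x).dtype)      # torch.float16
print(torch.rsqrt(x + 1).dtype)  # torch.float16

# log_sigmoid (and its backward) also gains a Half CPU kernel.
print(F.logsigmoid(x).dtype)     # torch.float16

# clamp and threshold are covered by the widened dispatch macros as well.
print(torch.clamp(x, min=0.25, max=0.75).dtype)  # torch.float16
print(F.threshold(x, 0.5, 0.0).dtype)            # torch.float16
```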
diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
index dd70708..da9c6b6 100644
--- a/aten/src/ATen/cpu/vec/vec_base.h
+++ b/aten/src/ATen/cpu/vec/vec_base.h
@@ -28,6 +28,7 @@
 #include <ATen/native/Math.h>
 #include <ATen/NumericUtils.h>
 #include <c10/util/C++17.h>
+#include <c10/util/Half.h>
 #include <c10/util/BFloat16.h>
 #include <c10/util/BFloat16-math.h>
 #include <c10/util/copysign.h>
diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp
index f679caf..b5d26a4 100644
--- a/aten/src/ATen/native/cpu/Activation.cpp
+++ b/aten/src/ATen/native/cpu/Activation.cpp
@@ -12,6 +12,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/OpMathType.h>
 #include <ATen/core/TensorBase.h>
+#include <ATen/cpu/vec/functional.h>
 #include <ATen/cpu/vec/vec.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
@@ -23,44 +24,9 @@
 
 namespace {
 
-template <typename scalar_t>
-inline void _vec_log_sigmoid(TensorBase &output, TensorBase &buffer, const TensorBase &input) {
-  if (input.scalar_type() == kBFloat16) {
-    using Vec = Vectorized<BFloat16>;
-    BFloat16* output_data = output.data_ptr<BFloat16>();
-    BFloat16* buffer_data = buffer.data_ptr<BFloat16>();
-    BFloat16* input_data = input.data_ptr<BFloat16>();
-    parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
-      int64_t size = end - begin;
-      int64_t d = 0;
-      for (; d < size - (size % Vec::size()); d += Vec::size()) {
-        Vec data_vec = Vec::loadu(input_data + begin+ d);
-        Vectorized<float> data_vec0, data_vec1;
-        std::tie(data_vec0, data_vec1) = convert_bfloat16_float(data_vec);
-        Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
-        Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
-        Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
-        min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
-        Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
-        Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
-        convert_float_bfloat16(buffer_vec0, buffer_vec1).store(buffer_data + begin + d);
-        convert_float_bfloat16(output_vec0, output_vec1).store(output_data + begin + d);
-      }
-      if (size - d > 0) {
-        Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
-        Vectorized<float> data_vec0, data_vec1;
-        std::tie(data_vec0, data_vec1) = convert_bfloat16_float(data_vec);
-        Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
-        Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
-        Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
-        min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
-        Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
-        Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
-        convert_float_bfloat16(buffer_vec0, buffer_vec1).store(buffer_data + begin + d, size - d);
-        convert_float_bfloat16(output_vec0, output_vec1).store(output_data + begin + d, size - d);
-      }
-    });
-  } else {
+static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const TensorBase &input) {
+  if (at::isReducedFloatingType(input.scalar_type())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&]() {
     using Vec = Vectorized<scalar_t>;
     scalar_t* output_data = output.data_ptr<scalar_t>();
     scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
@@ -70,59 +36,93 @@
       int64_t d = 0;
       for (; d < size - (size % Vec::size()); d += Vec::size()) {
         Vec data_vec = Vec::loadu(input_data + begin+ d);
-        Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
-        Vec buffer_vec = data_vec.abs().neg().exp();
-        Vec output_vec = min_vec - buffer_vec.log1p();
-        buffer_vec.store(buffer_data + begin + d);
-        output_vec.store(output_data + begin + d);
+        Vectorized<float> data_vec0, data_vec1;
+        std::tie(data_vec0, data_vec1) = convert_to_float<scalar_t>(data_vec);
+        Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
+        Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
+        Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
+        min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
+        Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
+        Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
+        convert_from_float<scalar_t>(buffer_vec0, buffer_vec1).store(buffer_data + begin + d);
+        convert_from_float<scalar_t>(output_vec0, output_vec1).store(output_data + begin + d);
       }
       if (size - d > 0) {
         Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
-        Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
-        Vec buffer_vec = data_vec.abs().neg().exp();
-        Vec output_vec = min_vec - buffer_vec.log1p();
-        buffer_vec.store(buffer_data + begin + d, size - d);
-        output_vec.store(output_data + begin + d, size - d);
+        Vectorized<float> data_vec0, data_vec1;
+        std::tie(data_vec0, data_vec1) = convert_to_float<scalar_t>(data_vec);
+        Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
+        Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
+        Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
+        min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
+        Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
+        Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
+        convert_from_float<scalar_t>(buffer_vec0, buffer_vec1).store(buffer_data + begin + d, size - d);
+        convert_from_float<scalar_t>(output_vec0, output_vec1).store(output_data + begin + d, size - d);
       }
     });
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&] {
+      using Vec = Vectorized<scalar_t>;
+      scalar_t* output_data = output.data_ptr<scalar_t>();
+      scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
+      scalar_t* input_data = input.data_ptr<scalar_t>();
+      parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
+        int64_t size = end - begin;
+        int64_t d = 0;
+        for (; d < size - (size % Vec::size()); d += Vec::size()) {
+          Vec data_vec = Vec::loadu(input_data + begin+ d);
+          Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
+          Vec buffer_vec = data_vec.abs().neg().exp();
+          Vec output_vec = min_vec - buffer_vec.log1p();
+          buffer_vec.store(buffer_data + begin + d);
+          output_vec.store(output_data + begin + d);
+        }
+        if (size - d > 0) {
+          Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
+          Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
+          Vec buffer_vec = data_vec.abs().neg().exp();
+          Vec output_vec = min_vec - buffer_vec.log1p();
+          buffer_vec.store(buffer_data + begin + d, size - d);
+          output_vec.store(output_data + begin + d, size - d);
+        }
+      });
+    });
   }
 }
 
-static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const TensorBase &input) {
-  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, input.scalar_type(), "log_sigmoid_cpu", [&] {
-    _vec_log_sigmoid<scalar_t>(output, buffer, input);
-  });
-}
-
 static void log_sigmoid_backward_cpu_kernel(TensorIterator& iter) {
-  if (iter.dtype() == kBFloat16) {
-    using Vec = Vectorized<BFloat16>;
-    auto zero_val = float(0);
-    auto zero_vec = Vectorized<float>(zero_val);
-    auto one_val = float(1);
-    auto one_vec = Vectorized<float>(one_val);
-    cpu_kernel_vec(iter,
-      [=](BFloat16 a, BFloat16 b, BFloat16 c) -> BFloat16 {
-        auto in_negative = float(a) < float(0);
-        auto max_deriv = in_negative ? float(1) : float(0);
-        auto sign = in_negative ? float(1) : -float(1);
-        return (max_deriv - sign * (float(b) / (float(1) + b))) * float(c);
-      },
-      [=](Vec a, Vec b, Vec c) -> Vec {
-        Vectorized<float> a0, a1, b0, b1, c0, c1;
-        std::tie(a0, a1) = convert_bfloat16_float(a);
-        std::tie(b0, b1) = convert_bfloat16_float(b);
-        std::tie(c0, c1) = convert_bfloat16_float(c);
-        auto mask = a0 < zero_vec;
-        auto max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
-        auto sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
-        a0 = (max_deriv_vec - sign_vec * (b0 / (one_vec + b0))) * c0;
-        mask = a1 < zero_vec;
-        max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
-        sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
-        a1 = (max_deriv_vec - sign_vec * (b1 / (one_vec + b1))) * c1;
-        return convert_float_bfloat16(a0, a1);
-      });
+  if (at::isReducedFloatingType(iter.dtype())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "log_sigmoid_backward_cpu", [&]() {
+      using Vec = Vectorized<scalar_t>;
+      auto zero_val = float(0);
+      auto zero_vec = Vectorized<float>(zero_val);
+      auto one_val = float(1);
+      auto one_vec = Vectorized<float>(one_val);
+      cpu_kernel_vec(iter,
+        [=](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
+          auto in_negative = float(a) < float(0);
+          auto max_deriv = in_negative ? float(1) : float(0);
+          auto sign = in_negative ? float(1) : -float(1);
+          return (max_deriv - sign * (float(b) / (float(1) + b))) * float(c);
+        },
+        [=](Vec a, Vec b, Vec c) -> Vec {
+          Vectorized<float> a0, a1, b0, b1, c0, c1;
+          std::tie(a0, a1) = convert_to_float<scalar_t>(a);
+          std::tie(b0, b1) = convert_to_float<scalar_t>(b);
+          std::tie(c0, c1) = convert_to_float<scalar_t>(c);
+          auto mask = a0 < zero_vec;
+          auto max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
+          auto sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
+          a0 = (max_deriv_vec - sign_vec * (b0 / (one_vec + b0))) * c0;
+          mask = a1 < zero_vec;
+          max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
+          sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
+          a1 = (max_deriv_vec - sign_vec * (b1 / (one_vec + b1))) * c1;
+          return convert_from_float<scalar_t>(a0, a1);
+        });
+    });
   } else {
     AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "log_sigmoid_backward_cpu", [&]() {
     using Vec = Vectorized<scalar_t>;
@@ -151,7 +151,7 @@
     TensorIteratorBase& iter,
     const Scalar& threshold_scalar,
     const Scalar& value_scalar) {
-  AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "threshold_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "threshold_cpu", [&] {
     using Vec = Vectorized<scalar_t>;
     scalar_t threshold = threshold_scalar.to<scalar_t>();
     Vec threshold_v = Vec(threshold);
@@ -766,23 +766,25 @@
 }
 
 void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Scalar& max) {
-  if (iter.dtype() == kBFloat16) {
-    auto min_val = min.to<float>();
-    auto max_val = max.to<float>();
-    cpu_kernel_vec(
-        iter,
-        [=](BFloat16 grad_val, BFloat16 self_val) -> BFloat16 {
-          return (float(self_val) <= min_val || float(self_val) >= max_val) ? BFloat16(0) : grad_val;
-        },
-        [=](Vectorized<BFloat16> grad_val, Vectorized<BFloat16> self_val) -> Vectorized<BFloat16> {
-          Vectorized<float> grad_val0, grad_val1, self_val0, self_val1;
-          std::tie(grad_val0, grad_val1) = convert_bfloat16_float(grad_val);
-          std::tie(self_val0, self_val1) = convert_bfloat16_float(self_val);
-          return convert_float_bfloat16(
-            ((self_val0 > min_val) & (self_val0 < max_val)) & grad_val0,
-            ((self_val1 > min_val) & (self_val1 < max_val)) & grad_val1
-          );
-        });
+  if (at::isReducedFloatingType(iter.dtype())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&]() {
+      auto min_val = min.to<float>();
+      auto max_val = max.to<float>();
+      cpu_kernel_vec(
+          iter,
+          [=](scalar_t grad_val, scalar_t self_val) -> scalar_t {
+            return (float(self_val) <= min_val || float(self_val) >= max_val) ? scalar_t(0) : grad_val;
+          },
+          [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
+            Vectorized<float> grad_val0, grad_val1, self_val0, self_val1;
+            std::tie(grad_val0, grad_val1) = convert_to_float<scalar_t>(grad_val);
+            std::tie(self_val0, self_val1) = convert_to_float<scalar_t>(self_val);
+            return convert_from_float<scalar_t>(
+              ((self_val0 > min_val) & (self_val0 < max_val)) & grad_val0,
+              ((self_val1 > min_val) & (self_val1 < max_val)) & grad_val1
+            );
+          });
+    });
   } else {
     AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] {
     auto min_val = min.to<scalar_t>();
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 748393a..f75aaa8 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -875,23 +875,25 @@
           return a * (one_vec - b * b).conj();
         });
     });
-  } else if (iter.dtype() == kBFloat16) {
-    auto one_vec = Vectorized<float>(float{1});
-    cpu_kernel_vec(
-      iter,
-      [=](BFloat16 a, BFloat16 b) -> BFloat16 {
-        float a0 = float(a);
-        float b0 = float(b);
-        return a0 * (float{1} - b0 * b0);
-      },
-      [=](Vectorized<BFloat16> a, Vectorized<BFloat16> b) {
-        Vectorized<float> a0, a1, b0, b1;
-        std::tie(a0, a1) = convert_bfloat16_float(a);
-        std::tie(b0, b1) = convert_bfloat16_float(b);
-        a0 = a0 * (one_vec - b0 * b0);
-        a1 = a1 * (one_vec - b1 * b1);
-        return convert_float_bfloat16(a0, a1);
-      });
+  } else if (at::isReducedFloatingType(iter.dtype())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
+      auto one_vec = Vectorized<float>(float{1});
+      cpu_kernel_vec(
+        iter,
+        [=](scalar_t a, scalar_t b) -> scalar_t {
+          float a0 = float(a);
+          float b0 = float(b);
+          return a0 * (float{1} - b0 * b0);
+        },
+        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
+          Vectorized<float> a0, a1, b0, b1;
+          std::tie(a0, a1) = convert_to_float<scalar_t>(a);
+          std::tie(b0, b1) = convert_to_float<scalar_t>(b);
+          a0 = a0 * (one_vec - b0 * b0);
+          a1 = a1 * (one_vec - b1 * b1);
+          return convert_from_float<scalar_t>(a0, a1);
+        });
+    });
   } else {
     AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
       auto one_vec = Vectorized<scalar_t>(scalar_t{1});
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index c4e4c07..b6c8965 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -354,7 +354,7 @@
 }
 
 static void clamp_scalar_kernel_impl(TensorIteratorBase& iter, const Scalar& min_, const Scalar& max_) {
-  AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.common_dtype(), "clamp_scalar_cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_scalar_cpu", [&]() {
     const auto min = min_.to<scalar_t>();
     const auto max = max_.to<scalar_t>();
     const Vectorized<scalar_t> min_vec(min);
@@ -384,7 +384,7 @@
 }
 
 static void clamp_min_scalar_kernel_impl(TensorIteratorBase& iter, Scalar min_) {
-  AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.common_dtype(), "clamp_min_scalar_cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_min_scalar_cpu", [&]() {
     const auto min = min_.to<scalar_t>();
     const Vectorized<scalar_t> min_vec(min);
     cpu_kernel_vec(iter,
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 31bc288..3165c90 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -337,7 +337,7 @@
 }
 
 static void sinc_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.common_dtype(), "sinc_cpu", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "sinc_cpu", [&]() {
     cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t {
@@ -370,7 +370,7 @@
 }
 
 static void acosh_kernel(TensorIteratorBase& iter) {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "acosh_cpu", [&]() {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "acosh_cpu", [&]() {
       cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return std::acosh(a); });
@@ -378,7 +378,7 @@
 }
 
 static void asinh_kernel(TensorIteratorBase& iter) {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "asinh_cpu", [&]() {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "asinh_cpu", [&]() {
       cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return std::asinh(a); });
@@ -386,7 +386,7 @@
 }
 
 static void atanh_kernel(TensorIteratorBase& iter) {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "atanh_cpu", [&]() {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "atanh_cpu", [&]() {
       cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return std::atanh(a); });
@@ -394,7 +394,7 @@
 }
 
 static void digamma_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.common_dtype(), "digamma", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "digamma", [&]() {
     cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return calc_digamma(a); });
@@ -402,7 +402,7 @@
 }
 
 static void trigamma_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "trigamma", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "trigamma", [&]() {
     cpu_kernel(
         iter,
         [=](scalar_t a) -> scalar_t { return trigamma(a); });
@@ -469,7 +469,7 @@
 }
 
 void rsqrt_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.common_dtype(), "rsqrt_cpu", [&] {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "rsqrt_cpu", [&] {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
@@ -690,7 +690,7 @@
   inline namespace CPU_CAPABILITY {                                                 \
   void op##_kernel(TensorIteratorBase& iter) {                                      \
     TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);                                    \
-    AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), #op "_vml_cpu", [&]() { \
+    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
       constexpr int64_t grain_size = 2048;                                          \
       iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size);                     \
     });                                                                             \
@@ -703,7 +703,7 @@
   inline namespace CPU_CAPABILITY {                                                              \
   void op##_kernel(TensorIteratorBase& iter) {                                                   \
     TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);                                                 \
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), #op "_vml_cpu", [&]() { \
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
         constexpr int64_t grain_size = 2048;                                                     \
         iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size);                                \
     });                                                                                          \
diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h
index e990d55..c02472b 100644
--- a/c10/util/BFloat16-math.h
+++ b/c10/util/BFloat16-math.h
@@ -1,6 +1,7 @@
 #pragma once
 
-#include <c10/util/BFloat16-inl.h>
+#include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
 #include <c10/util/math_compat.h>
 
 C10_CLANG_DIAGNOSTIC_PUSH()
@@ -10,95 +11,192 @@
 
 namespace std {
 
-/// Used by vec256<c10::BFloat16>::map
-inline c10::BFloat16 acos(c10::BFloat16 a) {
+template <typename T>
+struct is_reduced_floating_point
+    : std::integral_constant<
+          bool,
+          std::is_same<T, c10::Half>::value ||
+              std::is_same<T, c10::BFloat16>::value> {};
+
+template <typename T>
+constexpr bool is_reduced_floating_point_v =
+    is_reduced_floating_point<T>::value;
+
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T acos(T a) {
   return std::acos(float(a));
 }
-inline c10::BFloat16 asin(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T asin(T a) {
   return std::asin(float(a));
 }
-inline c10::BFloat16 atan(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T atan(T a) {
   return std::atan(float(a));
 }
-inline c10::BFloat16 erf(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T erf(T a) {
   return std::erf(float(a));
 }
-inline c10::BFloat16 erfc(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T erfc(T a) {
   return std::erfc(float(a));
 }
-inline c10::BFloat16 exp(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T exp(T a) {
   return std::exp(float(a));
 }
-inline c10::BFloat16 expm1(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T expm1(T a) {
   return std::expm1(float(a));
 }
-inline c10::BFloat16 log(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T log(T a) {
   return std::log(float(a));
 }
-inline c10::BFloat16 log10(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T log10(T a) {
   return std::log10(float(a));
 }
-inline c10::BFloat16 log1p(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T log1p(T a) {
   return std::log1p(float(a));
 }
-inline c10::BFloat16 log2(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T log2(T a) {
   return std::log2(float(a));
 }
-inline c10::BFloat16 ceil(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T ceil(T a) {
   return std::ceil(float(a));
 }
-inline c10::BFloat16 cos(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T cos(T a) {
   return std::cos(float(a));
 }
-inline c10::BFloat16 floor(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T floor(T a) {
   return std::floor(float(a));
 }
-inline c10::BFloat16 nearbyint(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T nearbyint(T a) {
   return std::nearbyint(float(a));
 }
-inline c10::BFloat16 sin(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T sin(T a) {
   return std::sin(float(a));
 }
-inline c10::BFloat16 tan(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T tan(T a) {
   return std::tan(float(a));
 }
-inline c10::BFloat16 sinh(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T sinh(T a) {
   return std::sinh(float(a));
 }
-inline c10::BFloat16 cosh(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T cosh(T a) {
   return std::cosh(float(a));
 }
-inline c10::BFloat16 tanh(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T tanh(T a) {
   return std::tanh(float(a));
 }
-inline c10::BFloat16 trunc(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T trunc(T a) {
   return std::trunc(float(a));
 }
-inline c10::BFloat16 lgamma(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T lgamma(T a) {
   return std::lgamma(float(a));
 }
-inline c10::BFloat16 sqrt(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T sqrt(T a) {
   return std::sqrt(float(a));
 }
-inline c10::BFloat16 rsqrt(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T rsqrt(T a) {
   return 1.0 / std::sqrt(float(a));
 }
-inline c10::BFloat16 abs(c10::BFloat16 a) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T abs(T a) {
   return std::abs(float(a));
 }
 #if defined(_MSC_VER) && defined(__CUDACC__)
-inline c10::BFloat16 pow(c10::BFloat16 a, double b) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T pow(T a, double b) {
   return std::pow(float(a), float(b));
 }
 #else
-inline c10::BFloat16 pow(c10::BFloat16 a, double b) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T pow(T a, double b) {
   return std::pow(float(a), b);
 }
 #endif
-inline c10::BFloat16 pow(c10::BFloat16 a, c10::BFloat16 b) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T pow(T a, T b) {
   return std::pow(float(a), float(b));
 }
-inline c10::BFloat16 fmod(c10::BFloat16 a, c10::BFloat16 b) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T fmod(T a, T b) {
   return std::fmod(float(a), float(b));
 }
 
@@ -128,13 +226,14 @@
   SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
   ----------------------------------------------------------------------
  */
-C10_HOST_DEVICE inline c10::BFloat16 nextafter(
-    c10::BFloat16 from,
-    c10::BFloat16 to) {
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+C10_HOST_DEVICE inline T nextafter(T from, T to) {
   // Reference:
   // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
   using int_repr_t = uint16_t;
-  using float_t = c10::BFloat16;
+  using float_t = T;
   constexpr uint8_t bits = 16;
   union {
     float_t f;
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 8446d81..d133027 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -189,11 +189,11 @@
     "nn.functional.avg_pool2d": {i64},
     "nn.functional.adaptive_avg_pool2d": {f16},
     "nn.functional.ctc_loss": {f32, f64},
-    "nn.functional.gaussian_nll_loss": {f32, f64},
+    "nn.functional.gaussian_nll_loss": {f16, f32, f64},
     "nn.functional.local_response_norm": {i64},
     "nn.functional.one_hot": {i64},
     "nn.functional.rrelu": {f32, f64},
-    "nn.functional.triplet_margin_with_distance_loss": {f32, f64, i32, i64},
+    "nn.functional.triplet_margin_with_distance_loss": {f16, f32, f64, i32, i64},
     "nonzero": {b8, f16, f32, f64, i32, i64},
     "normal": {f16, f32, f64},
     ("normal", "number_mean"): {f16, f32, f64},
diff --git a/test/test_meta.py b/test/test_meta.py
index 6a4c9da..100fc93 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -616,7 +616,7 @@
     torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
     torch.multinomial : {f64, bf16, f32},
     torch.nn.functional.ctc_loss : {f64, f32},
-    torch.nn.functional.gaussian_nll_loss : {f64, bf16, f32},
+    torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
     torch.nn.functional.max_pool3d : {f64, f32},
     torch.nn.functional.max_pool3d_with_indices : {f64, f32},
     torch.nn.functional.max_unpool1d : {f64, f32},
@@ -720,7 +720,6 @@
     torch.matrix_exp: {f16},  # aten::linalg_matrix_exp
     torch.median: {f16},  # aten::median, aten::median.dim_values
     torch.multinomial: {f16},  # aten::multinomial, aten::multinomial.out
-    torch.nn.functional.gaussian_nll_loss: {f16},  # aten::_local_scalar_dense
     torch.nn.functional.max_pool3d: {bf16, f16},  # aten::max_pool3d_with_indices
     torch.nn.functional.max_pool3d_with_indices: {bf16, f16},  # aten::max_pool3d_with_indices
     torch.nn.functional.max_unpool1d: {f16},  # aten::max_unpool2d
diff --git a/test/test_mps.py b/test/test_mps.py
index e8c8879..60ef643 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -89,9 +89,6 @@
         'floor_divide': [torch.float16, torch.float32],
         # derivative for aten::narrow_copy is not implemented on CPU
         'narrow_copy': [torch.float16, torch.float32],
-        # RuntimeError: "log_vml_cpu" not implemented for 'Half'
-        '__rpow__': [torch.float16],
-        'pow': [torch.float16],
         # 'bool' object is not iterable
         'allclose': [torch.float16, torch.float32],
         'equal': [torch.float16, torch.float32],
@@ -119,6 +116,9 @@
         # trunc_tensor not working properly for float16
         'divtrunc_rounding': [torch.float16],
         'fmod': [torch.float16],
+
+        # round not working properly for float16
+        'round': [torch.float16],
     }
 
     MACOS_12_3_XFAILLIST_GRAD = {
@@ -644,6 +644,9 @@
         # trunc_tensor not working properly for float16
         'divtrunc_rounding': [torch.float16],
         'fmod': [torch.float16],
+
+        # round not working properly for float16
+        'round': [torch.float16],
     }
 
     UNDEFINED_XFAILLIST = {
@@ -10364,6 +10367,12 @@
         'linalg.vector_norm',
         'addr', 'var_mean',
         'var_mean_unbiased',
+        'acosh', 'asinh', 'asin',
+        'masked.std',
+        'nn.functional.normalize',
+        'nn.functional.triplet_margin_loss',
+        'nn.functional.triplet_margin_with_distance_loss',
+        'round', 'xlogy',
 
         # for macOS 12
         'masked.normalize', 'masked.sum', 'masked.var',
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index f47674c..79e1b30 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6104,7 +6104,7 @@
                           variant_test_name=variant_test_name,
                           domain=domain,
                           decorators=(precisionOverride({torch.float16: 5e-2}),),
-                          dtypes=all_types_and(torch.bfloat16),
+                          dtypes=all_types_and(torch.half, torch.bfloat16),
                           dtypesIfCUDA=all_types_and(torch.float16),
                           sample_inputs_func=sample_inputs_mvlgamma,
                           supports_forward_ad=True,
@@ -8945,7 +8945,7 @@
                    aliases=('arccos', ),
                    ref=np.arccos,
                    domain=(-1, 1),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
@@ -8985,7 +8985,7 @@
                    aliases=('arccosh', ),
                    ref=np.arccosh,
                    domain=(1, None),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    supports_inplace_autograd=False,
@@ -9590,7 +9590,7 @@
                    supports_sparse_bsc=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    decorators=[
@@ -9617,7 +9617,7 @@
     UnaryUfuncInfo('asinh',
                    aliases=('arcsinh', ),
                    ref=np.arcsinh,
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                    supports_inplace_autograd=False,
@@ -9649,7 +9649,7 @@
     UnaryUfuncInfo('atan',
                    aliases=('arctan', ),
                    ref=np.arctan,
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
@@ -9696,7 +9696,7 @@
                    aliases=('arctanh', ),
                    ref=np.arctanh,
                    domain=(-1, 1),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                    supports_inplace_autograd=False,
@@ -9871,8 +9871,7 @@
            sample_inputs_func=sample_inputs_cdist),
     UnaryUfuncInfo('ceil',
                    ref=np.ceil,
-                   dtypes=all_types_and(torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    skips=(
@@ -10104,7 +10103,7 @@
            supports_out=False),
     UnaryUfuncInfo('cos',
                    ref=np.cos,
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    handles_large_floats=False,
@@ -10338,7 +10337,7 @@
            )),
     UnaryUfuncInfo('exp',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
                    skips=(
                        # Reference: https://github.com/pytorch/pytorch/pull/50093#pullrequestreview-561791547
@@ -10569,8 +10568,7 @@
            )),
     UnaryUfuncInfo('floor',
                    ref=np.floor,
-                   dtypes=all_types_and(torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    skips=(
@@ -10704,8 +10702,7 @@
                    aliases=('special.i0',),
                    decorators=(precisionOverride({torch.bfloat16: 3e-1,
                                                   torch.float16: 5e-1}),),
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    backward_dtypes=floating_types(),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
@@ -10768,8 +10765,7 @@
                    ref=np.log1p,
                    aliases=('special.log1p',),
                    domain=(-1, None),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
@@ -10922,7 +10918,7 @@
     UnaryUfuncInfo('log',
                    ref=np.log,
                    domain=(0, None),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
                    backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf),
                    assert_autodiffed=True,
@@ -10940,9 +10936,8 @@
                    ref=np.log10,
                    domain=(0, None),
                    decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    skips=(
@@ -10955,8 +10950,7 @@
     UnaryUfuncInfo('log2',
                    ref=np.log2,
                    domain=(0, None),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
@@ -11629,8 +11623,7 @@
         )
     ),
     OpInfo('nn.functional.normalize',
-           dtypes=floating_and_complex_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_normalize,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True),
@@ -11806,8 +11799,7 @@
            ),
     OpInfo('nn.functional.cosine_similarity',
            aten_name="cosine_similarity",
-           dtypes=floating_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -11954,8 +11946,7 @@
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
         supports_out=False,
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
         gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
         sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits,
         skips=(
@@ -11977,8 +11968,7 @@
         supports_sparse_csc=True,
         supports_sparse_bsr=True,
         supports_sparse_bsc=True,
-        dtypes=all_types_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
+        dtypes=all_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=sample_inputs_nn_activation_relu,
         supports_out=False,
         supports_fwgrad_bwgrad=True,
@@ -12519,8 +12509,7 @@
            )),
     OpInfo(
         "nn.functional.soft_margin_loss",
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
         supports_out=False,
         supports_forward_ad=True,
         # doesn't support grad on target
@@ -12544,8 +12533,7 @@
            supports_out=False),
     OpInfo(
         "nn.functional.margin_ranking_loss",
-        dtypes=all_types_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
+        dtypes=all_types_and(torch.half, torch.bfloat16),
         supports_out=False,
         sample_inputs_func=sample_inputs_margin_ranking_loss,
         error_inputs_func=error_inputs_margin_ranking_loss,
@@ -12587,8 +12575,7 @@
     OpInfo(
         "nn.functional.multilabel_soft_margin_loss",
         supports_out=False,
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
@@ -13190,8 +13177,7 @@
         aten_name="log_sigmoid",
         aten_backward_name='log_sigmoid_backward',
         ref=reference_logsigmoid,
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
         supports_autograd=True,
         assert_autodiffed=False,
         supports_forward_ad=True,
@@ -13252,8 +13238,7 @@
     UnaryUfuncInfo(
         'nn.functional.tanhshrink',
         ref=lambda x: x - np.tanh(x),
-        dtypes=all_types_and_complex_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
         supports_autograd=True,
@@ -13285,8 +13270,7 @@
     UnaryUfuncInfo(
         'nn.functional.threshold',
         ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype),
-        dtypes=all_types_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypes=all_types_and(torch.half, torch.bfloat16),
         inplace_variant=lambda x, threshold, value:
             torch.nn.functional.threshold(x, threshold, value, inplace=True),
         supports_forward_ad=True,
@@ -13306,8 +13290,7 @@
         "nn.functional.triplet_margin_loss",
         sample_inputs_func=sample_inputs_triplet_margin_loss,
         error_inputs_func=error_inputs_triplet_margin_loss,
-        dtypes=all_types_and_complex_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
         supports_out=False,
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
@@ -13316,8 +13299,7 @@
         "nn.functional.triplet_margin_with_distance_loss",
         sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
         error_inputs_func=error_inputs_triplet_margin_loss,
-        dtypes=all_types_and_complex_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
         supports_out=False,
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
@@ -13585,10 +13567,8 @@
     UnaryUfuncInfo('nn.functional.hardtanh',
                    aten_name="hardtanh",
                    aten_backward_name='hardtanh_backward',
-                   dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16),
-                   backward_dtypes=all_types_and(torch.bfloat16),
-                   dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.float16,
-                                                   torch.bfloat16),
+                   dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.half, torch.bfloat16),
+                   backward_dtypes=all_types_and(torch.half, torch.bfloat16),
                    backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
                    assert_autodiffed=True,
                    sample_inputs_func=sample_inputs_hardtanh,
@@ -13618,10 +13598,8 @@
            )),
     UnaryUfuncInfo('nn.functional.relu6',
                    aten_name="relu6",
-                   dtypes=all_types_and(torch.bfloat16),
-                   backward_dtypes=floating_types_and(torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
-                   backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
+                   backward_dtypes=floating_types_and(torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_out=False,
                    supports_forward_ad=True,
@@ -13801,7 +13779,7 @@
                     # Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled
                     # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently
                     # unsupported on CPU.
-                    backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
+                    backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
                     backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
                     # https://github.com/pytorch/pytorch/issues/80411
                     gradcheck_fast_mode=True,
@@ -13952,8 +13930,7 @@
     UnaryUfuncInfo('round',
                    ref=np.round,
                    aliases=('special.round',),
-                   dtypes=all_types_and(torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    skips=(
@@ -13982,8 +13959,7 @@
                    ref=np.round,
                    variant_test_name='decimals_0',
                    aliases=('special.round',),
-                   dtypes=floating_types_and(torch.bfloat16),
-                   dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+                   dtypes=floating_types_and(torch.half, torch.bfloat16),
                    sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}),
                    sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}),
                    supports_forward_ad=True,
@@ -14038,7 +14014,7 @@
                    supports_sparse_csr=False),
     UnaryUfuncInfo('sin',
                    ref=np.sin,
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    handles_large_floats=False,
@@ -14064,8 +14040,7 @@
     UnaryUfuncInfo('sinc',
                    ref=np_sinc_with_fp16_as_fp32,
                    aliases=('special.sinc',),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    handles_large_floats=False,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
@@ -14312,8 +14287,7 @@
                     dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
                     # Reference: https://github.com/pytorch/pytorch/issues/54774
                     # "log2" "_vml_cpu" not implemented for Half
-                    backward_dtypes=all_types_and_complex_and(torch.bfloat16),
-                    backward_dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half),
+                    backward_dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
                     supports_out=False,
                     supports_forward_ad=True,
                     supports_fwgrad_bwgrad=True,
@@ -14390,7 +14364,7 @@
                    supports_autograd=False,),
     UnaryUfuncInfo('tan',
                    ref=np.tan,
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
@@ -14424,7 +14398,7 @@
                    aten_backward_name='tanh_backward',
                    aliases=('nn.functional.tanh',),
                    decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    assert_jit_shape_analysis=True,
@@ -14516,8 +14490,7 @@
     UnaryUfuncInfo('trunc',
                    aliases=('fix', ),
                    ref=np.trunc,
-                   dtypes=all_types_and(torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_sparse=True,
@@ -14555,8 +14528,7 @@
     UnaryUfuncInfo('expm1',
                    aliases=('special.expm1', ),
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_sparse=True,
@@ -14617,7 +14589,7 @@
     UnaryUfuncInfo('rsqrt',
                    ref=lambda x: np.reciprocal(np.sqrt(x)),
                    domain=(0, None),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    decorators=(precisionOverride({torch.half: 5e-2}),),
                    assert_autodiffed=True,
@@ -14636,7 +14608,7 @@
                    ref=np.sqrt,
                    supports_sparse=True,
                    domain=(0, None),
-                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
@@ -16456,8 +16428,7 @@
            sample_inputs_func=sample_inputs_zero_),
     OpInfo('logsumexp',
            aliases=('special.logsumexp',),
-           dtypes=all_types_and(torch.bool, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half),
+           dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
            assert_autodiffed=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
@@ -16686,7 +16657,7 @@
                    ref=scipy.special.digamma if TEST_SCIPY else None,
                    aliases=('special.psi', 'special.digamma',),
                    decorators=(precisionOverride({torch.float16: 5e-1}),),
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True),
@@ -16700,8 +16671,7 @@
                                     'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
 
                    ),
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    assert_jit_shape_analysis=True,
                    supports_sparse=True,
@@ -16716,8 +16686,7 @@
                    aliases=('special.erfc', ),
                    decorators=(precisionOverride({torch.float16: 1e-2,
                                                   torch.bfloat16: 1e-2}),),
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    assert_autodiffed=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True),
@@ -16727,7 +16696,7 @@
                    decorators=(precisionOverride({torch.float16: 1e-2,
                                                   torch.bfloat16: 1e-2,
                                                   torch.float32: 1e-4}),),
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half),
                    supports_sparse_csr=True,
                    supports_sparse_csc=True,
@@ -16783,7 +16752,7 @@
                    ref=reference_lgamma if TEST_SCIPY else None,
                    aliases=('special.gammaln', ),
                    decorators=(precisionOverride({torch.float16: 7e-1}),),
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.bool, torch.half),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
@@ -17825,8 +17794,7 @@
     ),
     OpInfo(
         "nn.functional.cosine_embedding_loss",
-        dtypes=all_types_and(torch.bfloat16, torch.bool),
-        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool),
         supports_out=False,
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
@@ -17853,8 +17821,7 @@
     ),
     OpInfo(
         "nn.functional.gaussian_nll_loss",
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
         # Runs very slowly on slow gradcheck - alternatively reduce input sizes
         gradcheck_fast_mode=True,
         supports_out=False,
@@ -17882,8 +17849,7 @@
     ),
     OpInfo(
         "nn.functional.hinge_embedding_loss",
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
         supports_out=False,
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
@@ -17920,8 +17886,7 @@
     ),
     OpInfo(
         "nn.functional.poisson_nll_loss",
-        dtypes=all_types_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypes=all_types_and(torch.half, torch.bfloat16),
         supports_out=False,
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index 5d69358..6fce19a 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -1012,8 +1012,7 @@
         # See https://github.com/pytorch/pytorch/pull/78358
         check_batched_forward_grad=False,
         promotes_int_to_float=True,
-        dtypes=all_types_and_complex_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
         skips=(
             # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
             DecorateInfo(
@@ -1158,14 +1157,6 @@
             DecorateInfo(
                 unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
             ),
-            # RuntimeError: "clamp_min_cpu" not implemented for 'Half'
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestMasked",
-                "test_reference_masked",
-                device_type="cpu",
-                dtypes=[torch.half],
-            ),
         ),
         gradcheck_wrapper=gradcheck_wrapper_masked_operation,
         # Runs very slowly on slow gradcheck - alternatively reduce input sizes
@@ -1204,8 +1195,7 @@
     ),
     ReductionOpInfo(
         "masked.logsumexp",
-        dtypes=all_types_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypes=all_types_and(torch.half, torch.bfloat16),
         method_variant=None,
         nan_policy="propagate",
         supports_out=False,
diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py
index 9570ce4..0e68df6 100644
--- a/torch/testing/_internal/opinfo/definitions/special.py
+++ b/torch/testing/_internal/opinfo/definitions/special.py
@@ -157,8 +157,7 @@
         aten_name="special_ndtr",
         decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),),
         ref=scipy.special.ndtr if TEST_SCIPY else None,
-        dtypes=all_types_and(torch.bool, torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.float16),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
         skips=(