add Half support for unary ops on CPU (#98493)
Add Half support for log_sigmoid and some unary ops on CPU, including sinc, acosh, asinh, atanh, digamma, trigamma, rsqrt, acos, asin, atan, ceil, cos, erf, erfc, erfinv, exp, expm1, floor, log, log10, log1p, log2, i0, round, sin, sqrt, tan, tanh, trunc, lgamma.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98493
Approved by: https://github.com/jgong5, https://github.com/mingfeima, https://github.com/ngimel
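
Example usage (a minimal sketch, not part of the patch itself): with this change the listed ops should accept float16 tensors directly on CPU, following the same pattern the kernels already use for bfloat16 (upconvert to float32, compute, store back in the reduced precision).

    import torch

    x = torch.linspace(0.1, 0.9, 8, dtype=torch.float16)  # CPU float16 tensor
    # These previously raised "not implemented for 'Half'" on CPU; the kernels
    # now dispatch for Half, compute internally in float32, and return float16.
    y_log = torch.log(x)
    y_erf = torch.erf(x)
    y_ls = torch.nn.functional.logsigmoid(x)
    print(y_log.dtype, y_erf.dtype, y_ls.dtype)  # torch.float16 for each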
diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
index dd70708..da9c6b6 100644
--- a/aten/src/ATen/cpu/vec/vec_base.h
+++ b/aten/src/ATen/cpu/vec/vec_base.h
@@ -28,6 +28,7 @@
#include <ATen/native/Math.h>
#include <ATen/NumericUtils.h>
#include <c10/util/C++17.h>
+#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/copysign.h>
diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp
index f679caf..b5d26a4 100644
--- a/aten/src/ATen/native/cpu/Activation.cpp
+++ b/aten/src/ATen/native/cpu/Activation.cpp
@@ -12,6 +12,7 @@
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/core/TensorBase.h>
+#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
@@ -23,44 +24,9 @@
namespace {
-template <typename scalar_t>
-inline void _vec_log_sigmoid(TensorBase &output, TensorBase &buffer, const TensorBase &input) {
- if (input.scalar_type() == kBFloat16) {
- using Vec = Vectorized<BFloat16>;
- BFloat16* output_data = output.data_ptr<BFloat16>();
- BFloat16* buffer_data = buffer.data_ptr<BFloat16>();
- BFloat16* input_data = input.data_ptr<BFloat16>();
- parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
- int64_t size = end - begin;
- int64_t d = 0;
- for (; d < size - (size % Vec::size()); d += Vec::size()) {
- Vec data_vec = Vec::loadu(input_data + begin+ d);
- Vectorized<float> data_vec0, data_vec1;
- std::tie(data_vec0, data_vec1) = convert_bfloat16_float(data_vec);
- Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
- Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
- Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
- min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
- Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
- Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
- convert_float_bfloat16(buffer_vec0, buffer_vec1).store(buffer_data + begin + d);
- convert_float_bfloat16(output_vec0, output_vec1).store(output_data + begin + d);
- }
- if (size - d > 0) {
- Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
- Vectorized<float> data_vec0, data_vec1;
- std::tie(data_vec0, data_vec1) = convert_bfloat16_float(data_vec);
- Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
- Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
- Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
- min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
- Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
- Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
- convert_float_bfloat16(buffer_vec0, buffer_vec1).store(buffer_data + begin + d, size - d);
- convert_float_bfloat16(output_vec0, output_vec1).store(output_data + begin + d, size - d);
- }
- });
- } else {
+static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const TensorBase &input) {
+ if (at::isReducedFloatingType(input.scalar_type())) {
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&]() {
using Vec = Vectorized<scalar_t>;
scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
@@ -70,59 +36,93 @@
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(input_data + begin+ d);
- Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
- Vec buffer_vec = data_vec.abs().neg().exp();
- Vec output_vec = min_vec - buffer_vec.log1p();
- buffer_vec.store(buffer_data + begin + d);
- output_vec.store(output_data + begin + d);
+ Vectorized<float> data_vec0, data_vec1;
+ std::tie(data_vec0, data_vec1) = convert_to_float<scalar_t>(data_vec);
+ Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
+ Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
+ Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
+ min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
+ Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
+ Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
+ convert_from_float<scalar_t>(buffer_vec0, buffer_vec1).store(buffer_data + begin + d);
+ convert_from_float<scalar_t>(output_vec0, output_vec1).store(output_data + begin + d);
}
if (size - d > 0) {
Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
- Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
- Vec buffer_vec = data_vec.abs().neg().exp();
- Vec output_vec = min_vec - buffer_vec.log1p();
- buffer_vec.store(buffer_data + begin + d, size - d);
- output_vec.store(output_data + begin + d, size - d);
+ Vectorized<float> data_vec0, data_vec1;
+ std::tie(data_vec0, data_vec1) = convert_to_float<scalar_t>(data_vec);
+ Vectorized<float> min_vec = minimum(data_vec0, Vectorized<float>(float(0)));
+ Vectorized<float> buffer_vec0 = data_vec0.abs().neg().exp();
+ Vectorized<float> output_vec0 = min_vec - buffer_vec0.log1p();
+ min_vec = minimum(data_vec1, Vectorized<float>(float(0)));
+ Vectorized<float> buffer_vec1 = data_vec1.abs().neg().exp();
+ Vectorized<float> output_vec1 = min_vec - buffer_vec1.log1p();
+ convert_from_float<scalar_t>(buffer_vec0, buffer_vec1).store(buffer_data + begin + d, size - d);
+ convert_from_float<scalar_t>(output_vec0, output_vec1).store(output_data + begin + d, size - d);
}
});
+ });
+ } else {
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&] {
+ using Vec = Vectorized<scalar_t>;
+ scalar_t* output_data = output.data_ptr<scalar_t>();
+ scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
+ scalar_t* input_data = input.data_ptr<scalar_t>();
+ parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
+ int64_t size = end - begin;
+ int64_t d = 0;
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
+ Vec data_vec = Vec::loadu(input_data + begin+ d);
+ Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
+ Vec buffer_vec = data_vec.abs().neg().exp();
+ Vec output_vec = min_vec - buffer_vec.log1p();
+ buffer_vec.store(buffer_data + begin + d);
+ output_vec.store(output_data + begin + d);
+ }
+ if (size - d > 0) {
+ Vec data_vec = Vec::loadu(input_data + begin + d, size - d);
+ Vec min_vec = vec::minimum(data_vec, Vec(scalar_t(0)));
+ Vec buffer_vec = data_vec.abs().neg().exp();
+ Vec output_vec = min_vec - buffer_vec.log1p();
+ buffer_vec.store(buffer_data + begin + d, size - d);
+ output_vec.store(output_data + begin + d, size - d);
+ }
+ });
+ });
}
}
-static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const TensorBase &input) {
- AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, input.scalar_type(), "log_sigmoid_cpu", [&] {
- _vec_log_sigmoid<scalar_t>(output, buffer, input);
- });
-}
-
static void log_sigmoid_backward_cpu_kernel(TensorIterator& iter) {
- if (iter.dtype() == kBFloat16) {
- using Vec = Vectorized<BFloat16>;
- auto zero_val = float(0);
- auto zero_vec = Vectorized<float>(zero_val);
- auto one_val = float(1);
- auto one_vec = Vectorized<float>(one_val);
- cpu_kernel_vec(iter,
- [=](BFloat16 a, BFloat16 b, BFloat16 c) -> BFloat16 {
- auto in_negative = float(a) < float(0);
- auto max_deriv = in_negative ? float(1) : float(0);
- auto sign = in_negative ? float(1) : -float(1);
- return (max_deriv - sign * (float(b) / (float(1) + b))) * float(c);
- },
- [=](Vec a, Vec b, Vec c) -> Vec {
- Vectorized<float> a0, a1, b0, b1, c0, c1;
- std::tie(a0, a1) = convert_bfloat16_float(a);
- std::tie(b0, b1) = convert_bfloat16_float(b);
- std::tie(c0, c1) = convert_bfloat16_float(c);
- auto mask = a0 < zero_vec;
- auto max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
- auto sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
- a0 = (max_deriv_vec - sign_vec * (b0 / (one_vec + b0))) * c0;
- mask = a1 < zero_vec;
- max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
- sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
- a1 = (max_deriv_vec - sign_vec * (b1 / (one_vec + b1))) * c1;
- return convert_float_bfloat16(a0, a1);
- });
+ if (at::isReducedFloatingType(iter.dtype())) {
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "log_sigmoid_backward_cpu", [&]() {
+ using Vec = Vectorized<scalar_t>;
+ auto zero_val = float(0);
+ auto zero_vec = Vectorized<float>(zero_val);
+ auto one_val = float(1);
+ auto one_vec = Vectorized<float>(one_val);
+ cpu_kernel_vec(iter,
+ [=](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
+ auto in_negative = float(a) < float(0);
+ auto max_deriv = in_negative ? float(1) : float(0);
+ auto sign = in_negative ? float(1) : -float(1);
+ return (max_deriv - sign * (float(b) / (float(1) + b))) * float(c);
+ },
+ [=](Vec a, Vec b, Vec c) -> Vec {
+ Vectorized<float> a0, a1, b0, b1, c0, c1;
+ std::tie(a0, a1) = convert_to_float<scalar_t>(a);
+ std::tie(b0, b1) = convert_to_float<scalar_t>(b);
+ std::tie(c0, c1) = convert_to_float<scalar_t>(c);
+ auto mask = a0 < zero_vec;
+ auto max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
+ auto sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
+ a0 = (max_deriv_vec - sign_vec * (b0 / (one_vec + b0))) * c0;
+ mask = a1 < zero_vec;
+ max_deriv_vec = Vectorized<float>::blendv(zero_vec, one_vec, mask);
+ sign_vec = Vectorized<float>::blendv(one_vec.neg(), one_vec, mask);
+ a1 = (max_deriv_vec - sign_vec * (b1 / (one_vec + b1))) * c1;
+ return convert_from_float<scalar_t>(a0, a1);
+ });
+ });
} else {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "log_sigmoid_backward_cpu", [&]() {
using Vec = Vectorized<scalar_t>;
@@ -151,7 +151,7 @@
TensorIteratorBase& iter,
const Scalar& threshold_scalar,
const Scalar& value_scalar) {
- AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "threshold_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "threshold_cpu", [&] {
using Vec = Vectorized<scalar_t>;
scalar_t threshold = threshold_scalar.to<scalar_t>();
Vec threshold_v = Vec(threshold);
@@ -766,23 +766,25 @@
}
void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Scalar& max) {
- if (iter.dtype() == kBFloat16) {
- auto min_val = min.to<float>();
- auto max_val = max.to<float>();
- cpu_kernel_vec(
- iter,
- [=](BFloat16 grad_val, BFloat16 self_val) -> BFloat16 {
- return (float(self_val) <= min_val || float(self_val) >= max_val) ? BFloat16(0) : grad_val;
- },
- [=](Vectorized<BFloat16> grad_val, Vectorized<BFloat16> self_val) -> Vectorized<BFloat16> {
- Vectorized<float> grad_val0, grad_val1, self_val0, self_val1;
- std::tie(grad_val0, grad_val1) = convert_bfloat16_float(grad_val);
- std::tie(self_val0, self_val1) = convert_bfloat16_float(self_val);
- return convert_float_bfloat16(
- ((self_val0 > min_val) & (self_val0 < max_val)) & grad_val0,
- ((self_val1 > min_val) & (self_val1 < max_val)) & grad_val1
- );
- });
+ if (at::isReducedFloatingType(iter.dtype())) {
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&]() {
+ auto min_val = min.to<float>();
+ auto max_val = max.to<float>();
+ cpu_kernel_vec(
+ iter,
+ [=](scalar_t grad_val, scalar_t self_val) -> scalar_t {
+ return (float(self_val) <= min_val || float(self_val) >= max_val) ? scalar_t(0) : grad_val;
+ },
+ [=](Vectorized<scalar_t> grad_val, Vectorized<scalar_t> self_val) -> Vectorized<scalar_t> {
+ Vectorized<float> grad_val0, grad_val1, self_val0, self_val1;
+ std::tie(grad_val0, grad_val1) = convert_to_float<scalar_t>(grad_val);
+ std::tie(self_val0, self_val1) = convert_to_float<scalar_t>(self_val);
+ return convert_from_float<scalar_t>(
+ ((self_val0 > min_val) & (self_val0 < max_val)) & grad_val0,
+ ((self_val1 > min_val) & (self_val1 < max_val)) & grad_val1
+ );
+ });
+ });
} else {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] {
auto min_val = min.to<scalar_t>();
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 748393a..f75aaa8 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -875,23 +875,25 @@
return a * (one_vec - b * b).conj();
});
});
- } else if (iter.dtype() == kBFloat16) {
- auto one_vec = Vectorized<float>(float{1});
- cpu_kernel_vec(
- iter,
- [=](BFloat16 a, BFloat16 b) -> BFloat16 {
- float a0 = float(a);
- float b0 = float(b);
- return a0 * (float{1} - b0 * b0);
- },
- [=](Vectorized<BFloat16> a, Vectorized<BFloat16> b) {
- Vectorized<float> a0, a1, b0, b1;
- std::tie(a0, a1) = convert_bfloat16_float(a);
- std::tie(b0, b1) = convert_bfloat16_float(b);
- a0 = a0 * (one_vec - b0 * b0);
- a1 = a1 * (one_vec - b1 * b1);
- return convert_float_bfloat16(a0, a1);
- });
+ } else if (at::isReducedFloatingType(iter.dtype())) {
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
+ auto one_vec = Vectorized<float>(float{1});
+ cpu_kernel_vec(
+ iter,
+ [=](scalar_t a, scalar_t b) -> scalar_t {
+ float a0 = float(a);
+ float b0 = float(b);
+ return a0 * (float{1} - b0 * b0);
+ },
+ [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
+ Vectorized<float> a0, a1, b0, b1;
+ std::tie(a0, a1) = convert_to_float<scalar_t>(a);
+ std::tie(b0, b1) = convert_to_float<scalar_t>(b);
+ a0 = a0 * (one_vec - b0 * b0);
+ a1 = a1 * (one_vec - b1 * b1);
+ return convert_from_float<scalar_t>(a0, a1);
+ });
+ });
} else {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
auto one_vec = Vectorized<scalar_t>(scalar_t{1});
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index c4e4c07..b6c8965 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -354,7 +354,7 @@
}
static void clamp_scalar_kernel_impl(TensorIteratorBase& iter, const Scalar& min_, const Scalar& max_) {
- AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.common_dtype(), "clamp_scalar_cpu", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_scalar_cpu", [&]() {
const auto min = min_.to<scalar_t>();
const auto max = max_.to<scalar_t>();
const Vectorized<scalar_t> min_vec(min);
@@ -384,7 +384,7 @@
}
static void clamp_min_scalar_kernel_impl(TensorIteratorBase& iter, Scalar min_) {
- AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.common_dtype(), "clamp_min_scalar_cpu", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "clamp_min_scalar_cpu", [&]() {
const auto min = min_.to<scalar_t>();
const Vectorized<scalar_t> min_vec(min);
cpu_kernel_vec(iter,
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 31bc288..3165c90 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -337,7 +337,7 @@
}
static void sinc_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.common_dtype(), "sinc_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "sinc_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t {
@@ -370,7 +370,7 @@
}
static void acosh_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "acosh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "acosh_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return std::acosh(a); });
@@ -378,7 +378,7 @@
}
static void asinh_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "asinh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "asinh_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return std::asinh(a); });
@@ -386,7 +386,7 @@
}
static void atanh_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "atanh_cpu", [&]() {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "atanh_cpu", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return std::atanh(a); });
@@ -394,7 +394,7 @@
}
static void digamma_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.common_dtype(), "digamma", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "digamma", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return calc_digamma(a); });
@@ -402,7 +402,7 @@
}
static void trigamma_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "trigamma", [&]() {
+ AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "trigamma", [&]() {
cpu_kernel(
iter,
[=](scalar_t a) -> scalar_t { return trigamma(a); });
@@ -469,7 +469,7 @@
}
void rsqrt_kernel(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.common_dtype(), "rsqrt_cpu", [&] {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "rsqrt_cpu", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
@@ -690,7 +690,7 @@
inline namespace CPU_CAPABILITY { \
void op##_kernel(TensorIteratorBase& iter) { \
TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); \
- AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), #op "_vml_cpu", [&]() { \
+ AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
constexpr int64_t grain_size = 2048; \
iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size); \
}); \
@@ -703,7 +703,7 @@
inline namespace CPU_CAPABILITY { \
void op##_kernel(TensorIteratorBase& iter) { \
TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); \
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), #op "_vml_cpu", [&]() { \
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
constexpr int64_t grain_size = 2048; \
iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size); \
}); \
diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h
index e990d55..c02472b 100644
--- a/c10/util/BFloat16-math.h
+++ b/c10/util/BFloat16-math.h
@@ -1,6 +1,7 @@
#pragma once
-#include <c10/util/BFloat16-inl.h>
+#include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
#include <c10/util/math_compat.h>
C10_CLANG_DIAGNOSTIC_PUSH()
@@ -10,95 +11,192 @@
namespace std {
-/// Used by vec256<c10::BFloat16>::map
-inline c10::BFloat16 acos(c10::BFloat16 a) {
+template <typename T>
+struct is_reduced_floating_point
+ : std::integral_constant<
+ bool,
+ std::is_same<T, c10::Half>::value ||
+ std::is_same<T, c10::BFloat16>::value> {};
+
+template <typename T>
+constexpr bool is_reduced_floating_point_v =
+ is_reduced_floating_point<T>::value;
+
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T acos(T a) {
return std::acos(float(a));
}
-inline c10::BFloat16 asin(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T asin(T a) {
return std::asin(float(a));
}
-inline c10::BFloat16 atan(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T atan(T a) {
return std::atan(float(a));
}
-inline c10::BFloat16 erf(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T erf(T a) {
return std::erf(float(a));
}
-inline c10::BFloat16 erfc(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T erfc(T a) {
return std::erfc(float(a));
}
-inline c10::BFloat16 exp(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T exp(T a) {
return std::exp(float(a));
}
-inline c10::BFloat16 expm1(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T expm1(T a) {
return std::expm1(float(a));
}
-inline c10::BFloat16 log(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T log(T a) {
return std::log(float(a));
}
-inline c10::BFloat16 log10(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T log10(T a) {
return std::log10(float(a));
}
-inline c10::BFloat16 log1p(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T log1p(T a) {
return std::log1p(float(a));
}
-inline c10::BFloat16 log2(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T log2(T a) {
return std::log2(float(a));
}
-inline c10::BFloat16 ceil(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T ceil(T a) {
return std::ceil(float(a));
}
-inline c10::BFloat16 cos(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T cos(T a) {
return std::cos(float(a));
}
-inline c10::BFloat16 floor(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T floor(T a) {
return std::floor(float(a));
}
-inline c10::BFloat16 nearbyint(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T nearbyint(T a) {
return std::nearbyint(float(a));
}
-inline c10::BFloat16 sin(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T sin(T a) {
return std::sin(float(a));
}
-inline c10::BFloat16 tan(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T tan(T a) {
return std::tan(float(a));
}
-inline c10::BFloat16 sinh(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T sinh(T a) {
return std::sinh(float(a));
}
-inline c10::BFloat16 cosh(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T cosh(T a) {
return std::cosh(float(a));
}
-inline c10::BFloat16 tanh(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T tanh(T a) {
return std::tanh(float(a));
}
-inline c10::BFloat16 trunc(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T trunc(T a) {
return std::trunc(float(a));
}
-inline c10::BFloat16 lgamma(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T lgamma(T a) {
return std::lgamma(float(a));
}
-inline c10::BFloat16 sqrt(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T sqrt(T a) {
return std::sqrt(float(a));
}
-inline c10::BFloat16 rsqrt(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T rsqrt(T a) {
return 1.0 / std::sqrt(float(a));
}
-inline c10::BFloat16 abs(c10::BFloat16 a) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T abs(T a) {
return std::abs(float(a));
}
#if defined(_MSC_VER) && defined(__CUDACC__)
-inline c10::BFloat16 pow(c10::BFloat16 a, double b) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T pow(T a, double b) {
return std::pow(float(a), float(b));
}
#else
-inline c10::BFloat16 pow(c10::BFloat16 a, double b) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T pow(T a, double b) {
return std::pow(float(a), b);
}
#endif
-inline c10::BFloat16 pow(c10::BFloat16 a, c10::BFloat16 b) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T pow(T a, T b) {
return std::pow(float(a), float(b));
}
-inline c10::BFloat16 fmod(c10::BFloat16 a, c10::BFloat16 b) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T fmod(T a, T b) {
return std::fmod(float(a), float(b));
}
@@ -128,13 +226,14 @@
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
----------------------------------------------------------------------
*/
-C10_HOST_DEVICE inline c10::BFloat16 nextafter(
- c10::BFloat16 from,
- c10::BFloat16 to) {
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+C10_HOST_DEVICE inline T nextafter(T from, T to) {
// Reference:
// https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
using int_repr_t = uint16_t;
- using float_t = c10::BFloat16;
+ using float_t = T;
constexpr uint8_t bits = 16;
union {
float_t f;
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 8446d81..d133027 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -189,11 +189,11 @@
"nn.functional.avg_pool2d": {i64},
"nn.functional.adaptive_avg_pool2d": {f16},
"nn.functional.ctc_loss": {f32, f64},
- "nn.functional.gaussian_nll_loss": {f32, f64},
+ "nn.functional.gaussian_nll_loss": {f16, f32, f64},
"nn.functional.local_response_norm": {i64},
"nn.functional.one_hot": {i64},
"nn.functional.rrelu": {f32, f64},
- "nn.functional.triplet_margin_with_distance_loss": {f32, f64, i32, i64},
+ "nn.functional.triplet_margin_with_distance_loss": {f16, f32, f64, i32, i64},
"nonzero": {b8, f16, f32, f64, i32, i64},
"normal": {f16, f32, f64},
("normal", "number_mean"): {f16, f32, f64},
diff --git a/test/test_meta.py b/test/test_meta.py
index 6a4c9da..100fc93 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -616,7 +616,7 @@
torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
torch.multinomial : {f64, bf16, f32},
torch.nn.functional.ctc_loss : {f64, f32},
- torch.nn.functional.gaussian_nll_loss : {f64, bf16, f32},
+ torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
torch.nn.functional.max_pool3d : {f64, f32},
torch.nn.functional.max_pool3d_with_indices : {f64, f32},
torch.nn.functional.max_unpool1d : {f64, f32},
@@ -720,7 +720,6 @@
torch.matrix_exp: {f16}, # aten::linalg_matrix_exp
torch.median: {f16}, # aten::median, aten::median.dim_values
torch.multinomial: {f16}, # aten::multinomial, aten::multinomial.out
- torch.nn.functional.gaussian_nll_loss: {f16}, # aten::_local_scalar_dense
torch.nn.functional.max_pool3d: {bf16, f16}, # aten::max_pool3d_with_indices
torch.nn.functional.max_pool3d_with_indices: {bf16, f16}, # aten::max_pool3d_with_indices
torch.nn.functional.max_unpool1d: {f16}, # aten::max_unpool2d
diff --git a/test/test_mps.py b/test/test_mps.py
index e8c8879..60ef643 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -89,9 +89,6 @@
'floor_divide': [torch.float16, torch.float32],
# derivative for aten::narrow_copy is not implemented on CPU
'narrow_copy': [torch.float16, torch.float32],
- # RuntimeError: "log_vml_cpu" not implemented for 'Half'
- '__rpow__': [torch.float16],
- 'pow': [torch.float16],
# 'bool' object is not iterable
'allclose': [torch.float16, torch.float32],
'equal': [torch.float16, torch.float32],
@@ -119,6 +116,9 @@
# trunc_tensor not working properly for float16
'divtrunc_rounding': [torch.float16],
'fmod': [torch.float16],
+
+ # round not working properly for float16
+ 'round': [torch.float16],
}
MACOS_12_3_XFAILLIST_GRAD = {
@@ -644,6 +644,9 @@
# trunc_tensor not working properly for float16
'divtrunc_rounding': [torch.float16],
'fmod': [torch.float16],
+
+ # round not working properly for float16
+ 'round': [torch.float16],
}
UNDEFINED_XFAILLIST = {
@@ -10364,6 +10367,12 @@
'linalg.vector_norm',
'addr', 'var_mean',
'var_mean_unbiased',
+ 'acosh', 'asinh', 'asin',
+ 'masked.std',
+ 'nn.functional.normalize',
+ 'nn.functional.triplet_margin_loss',
+ 'nn.functional.triplet_margin_with_distance_loss',
+ 'round', 'xlogy',
# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index f47674c..79e1b30 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6104,7 +6104,7 @@
variant_test_name=variant_test_name,
domain=domain,
decorators=(precisionOverride({torch.float16: 5e-2}),),
- dtypes=all_types_and(torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.float16),
sample_inputs_func=sample_inputs_mvlgamma,
supports_forward_ad=True,
@@ -8945,7 +8945,7 @@
aliases=('arccos', ),
ref=np.arccos,
domain=(-1, 1),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
@@ -8985,7 +8985,7 @@
aliases=('arccosh', ),
ref=np.arccosh,
domain=(1, None),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
supports_inplace_autograd=False,
@@ -9590,7 +9590,7 @@
supports_sparse_bsc=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
decorators=[
@@ -9617,7 +9617,7 @@
UnaryUfuncInfo('asinh',
aliases=('arcsinh', ),
ref=np.arcsinh,
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
supports_inplace_autograd=False,
@@ -9649,7 +9649,7 @@
UnaryUfuncInfo('atan',
aliases=('arctan', ),
ref=np.arctan,
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
@@ -9696,7 +9696,7 @@
aliases=('arctanh', ),
ref=np.arctanh,
domain=(-1, 1),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
supports_inplace_autograd=False,
@@ -9871,8 +9871,7 @@
sample_inputs_func=sample_inputs_cdist),
UnaryUfuncInfo('ceil',
ref=np.ceil,
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
@@ -10104,7 +10103,7 @@
supports_out=False),
UnaryUfuncInfo('cos',
ref=np.cos,
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
handles_large_floats=False,
@@ -10338,7 +10337,7 @@
)),
UnaryUfuncInfo('exp',
ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
skips=(
# Reference: https://github.com/pytorch/pytorch/pull/50093#pullrequestreview-561791547
@@ -10569,8 +10568,7 @@
)),
UnaryUfuncInfo('floor',
ref=np.floor,
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
@@ -10704,8 +10702,7 @@
aliases=('special.i0',),
decorators=(precisionOverride({torch.bfloat16: 3e-1,
torch.float16: 5e-1}),),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
backward_dtypes=floating_types(),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -10768,8 +10765,7 @@
ref=np.log1p,
aliases=('special.log1p',),
domain=(-1, None),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -10922,7 +10918,7 @@
UnaryUfuncInfo('log',
ref=np.log,
domain=(0, None),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf),
assert_autodiffed=True,
@@ -10940,9 +10936,8 @@
ref=np.log10,
domain=(0, None),
decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
- dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
@@ -10955,8 +10950,7 @@
UnaryUfuncInfo('log2',
ref=np.log2,
domain=(0, None),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -11629,8 +11623,7 @@
)
),
OpInfo('nn.functional.normalize',
- dtypes=floating_and_complex_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+ dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_normalize,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True),
@@ -11806,8 +11799,7 @@
),
OpInfo('nn.functional.cosine_similarity',
aten_name="cosine_similarity",
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -11954,8 +11946,7 @@
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_out=False,
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits,
skips=(
@@ -11977,8 +11968,7 @@
supports_sparse_csc=True,
supports_sparse_bsr=True,
supports_sparse_bsc=True,
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_nn_activation_relu,
supports_out=False,
supports_fwgrad_bwgrad=True,
@@ -12519,8 +12509,7 @@
)),
OpInfo(
"nn.functional.soft_margin_loss",
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
# doesn't support grad on target
@@ -12544,8 +12533,7 @@
supports_out=False),
OpInfo(
"nn.functional.margin_ranking_loss",
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
supports_out=False,
sample_inputs_func=sample_inputs_margin_ranking_loss,
error_inputs_func=error_inputs_margin_ranking_loss,
@@ -12587,8 +12575,7 @@
OpInfo(
"nn.functional.multilabel_soft_margin_loss",
supports_out=False,
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -13190,8 +13177,7 @@
aten_name="log_sigmoid",
aten_backward_name='log_sigmoid_backward',
ref=reference_logsigmoid,
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
supports_autograd=True,
assert_autodiffed=False,
supports_forward_ad=True,
@@ -13252,8 +13238,7 @@
UnaryUfuncInfo(
'nn.functional.tanhshrink',
ref=lambda x: x - np.tanh(x),
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_autograd=True,
@@ -13285,8 +13270,7 @@
UnaryUfuncInfo(
'nn.functional.threshold',
ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype),
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
inplace_variant=lambda x, threshold, value:
torch.nn.functional.threshold(x, threshold, value, inplace=True),
supports_forward_ad=True,
@@ -13306,8 +13290,7 @@
"nn.functional.triplet_margin_loss",
sample_inputs_func=sample_inputs_triplet_margin_loss,
error_inputs_func=error_inputs_triplet_margin_loss,
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -13316,8 +13299,7 @@
"nn.functional.triplet_margin_with_distance_loss",
sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
error_inputs_func=error_inputs_triplet_margin_loss,
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -13585,10 +13567,8 @@
UnaryUfuncInfo('nn.functional.hardtanh',
aten_name="hardtanh",
aten_backward_name='hardtanh_backward',
- dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16),
- backward_dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.float16,
- torch.bfloat16),
+ dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.half, torch.bfloat16),
+ backward_dtypes=all_types_and(torch.half, torch.bfloat16),
backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
assert_autodiffed=True,
sample_inputs_func=sample_inputs_hardtanh,
@@ -13618,10 +13598,8 @@
)),
UnaryUfuncInfo('nn.functional.relu6',
aten_name="relu6",
- dtypes=all_types_and(torch.bfloat16),
- backward_dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
- backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
+ backward_dtypes=floating_types_and(torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_out=False,
supports_forward_ad=True,
@@ -13801,7 +13779,7 @@
# Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled
# for Float16, causing this test to fail. pow's autograd for Float16 is thus currently
# unsupported on CPU.
- backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
+ backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
# https://github.com/pytorch/pytorch/issues/80411
gradcheck_fast_mode=True,
@@ -13952,8 +13930,7 @@
UnaryUfuncInfo('round',
ref=np.round,
aliases=('special.round',),
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
@@ -13982,8 +13959,7 @@
ref=np.round,
variant_test_name='decimals_0',
aliases=('special.round',),
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}),
sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}),
supports_forward_ad=True,
@@ -14038,7 +14014,7 @@
supports_sparse_csr=False),
UnaryUfuncInfo('sin',
ref=np.sin,
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
handles_large_floats=False,
@@ -14064,8 +14040,7 @@
UnaryUfuncInfo('sinc',
ref=np_sinc_with_fp16_as_fp32,
aliases=('special.sinc',),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
handles_large_floats=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -14312,8 +14287,7 @@
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
# Reference: https://github.com/pytorch/pytorch/issues/54774
# "log2" "_vml_cpu" not implemented for Half
- backward_dtypes=all_types_and_complex_and(torch.bfloat16),
- backward_dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half),
+ backward_dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -14390,7 +14364,7 @@
supports_autograd=False,),
UnaryUfuncInfo('tan',
ref=np.tan,
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
@@ -14424,7 +14398,7 @@
aten_backward_name='tanh_backward',
aliases=('nn.functional.tanh',),
decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
assert_jit_shape_analysis=True,
@@ -14516,8 +14490,7 @@
UnaryUfuncInfo('trunc',
aliases=('fix', ),
ref=np.trunc,
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse=True,
@@ -14555,8 +14528,7 @@
UnaryUfuncInfo('expm1',
aliases=('special.expm1', ),
ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse=True,
@@ -14617,7 +14589,7 @@
UnaryUfuncInfo('rsqrt',
ref=lambda x: np.reciprocal(np.sqrt(x)),
domain=(0, None),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
decorators=(precisionOverride({torch.half: 5e-2}),),
assert_autodiffed=True,
@@ -14636,7 +14608,7 @@
ref=np.sqrt,
supports_sparse=True,
domain=(0, None),
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
@@ -16456,8 +16428,7 @@
sample_inputs_func=sample_inputs_zero_),
OpInfo('logsumexp',
aliases=('special.logsumexp',),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -16686,7 +16657,7 @@
ref=scipy.special.digamma if TEST_SCIPY else None,
aliases=('special.psi', 'special.digamma',),
decorators=(precisionOverride({torch.float16: 5e-1}),),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True),
@@ -16700,8 +16671,7 @@
'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
assert_jit_shape_analysis=True,
supports_sparse=True,
@@ -16716,8 +16686,7 @@
aliases=('special.erfc', ),
decorators=(precisionOverride({torch.float16: 1e-2,
torch.bfloat16: 1e-2}),),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True),
@@ -16727,7 +16696,7 @@
decorators=(precisionOverride({torch.float16: 1e-2,
torch.bfloat16: 1e-2,
torch.float32: 1e-4}),),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
supports_sparse_csr=True,
supports_sparse_csc=True,
@@ -16783,7 +16752,7 @@
ref=reference_lgamma if TEST_SCIPY else None,
aliases=('special.gammaln', ),
decorators=(precisionOverride({torch.float16: 7e-1}),),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -17825,8 +17794,7 @@
),
OpInfo(
"nn.functional.cosine_embedding_loss",
- dtypes=all_types_and(torch.bfloat16, torch.bool),
- dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+ dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -17853,8 +17821,7 @@
),
OpInfo(
"nn.functional.gaussian_nll_loss",
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
gradcheck_fast_mode=True,
supports_out=False,
@@ -17882,8 +17849,7 @@
),
OpInfo(
"nn.functional.hinge_embedding_loss",
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -17920,8 +17886,7 @@
),
OpInfo(
"nn.functional.poisson_nll_loss",
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index 5d69358..6fce19a 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -1012,8 +1012,7 @@
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
promotes_int_to_float=True,
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
DecorateInfo(
@@ -1158,14 +1157,6 @@
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
- # RuntimeError: "clamp_min_cpu" not implemented for 'Half'
- DecorateInfo(
- unittest.expectedFailure,
- "TestMasked",
- "test_reference_masked",
- device_type="cpu",
- dtypes=[torch.half],
- ),
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
@@ -1204,8 +1195,7 @@
),
ReductionOpInfo(
"masked.logsumexp",
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and(torch.half, torch.bfloat16),
method_variant=None,
nan_policy="propagate",
supports_out=False,
diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py
index 9570ce4..0e68df6 100644
--- a/torch/testing/_internal/opinfo/definitions/special.py
+++ b/torch/testing/_internal/opinfo/definitions/special.py
@@ -157,8 +157,7 @@
aten_name="special_ndtr",
decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),),
ref=scipy.special.ndtr if TEST_SCIPY else None,
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.float16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(